package org.apache.sysds.runtime.controlprogram;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.api.jmlc.JMLCUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.ParseInfo;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
import org.apache.sysds.utils.Statistics;
import org.apache.sysds.utils.stats.RecompileStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/ProgramBlock.class */
/**
 * Base class of all runtime program blocks.
 *
 * <p>A program block holds a reference to its {@link Program} and an optional
 * {@link StatementBlock}. It also holds an optional exit instruction and a thread id,
 * where 0 means "no thread id". Parse information is kept for error messages.
 *
 * <p>The class provides shared helpers that subclasses use to:
 * <ul>
 *   <li>execute instruction lists and predicates;</li>
 *   <li>prepare and reset update-in-place variables.</li>
 * </ul>
 */
public abstract class ProgramBlock implements ParseInfo {
    /**
     * Name of the temporary scalar variable holding a predicate result. It is read back
     * after the predicate instructions have executed and is then removed.
     */
    public static final String PRED_VAR = "__pred";
    protected static final Log LOG = LogFactory.getLog(ProgramBlock.class.getName());

    /**
     * Developer switch for expensive consistency checks. When {@code true}, every executed
     * instruction is followed by a check of all live variables:
     * <ul>
     *   <li>nnz metadata and sparse/dense representation of matrices;</li>
     *   <li>non-empty federation maps of federated data.</li>
     * </ul>
     * Keep {@code false} outside of debugging.
     */
    private static final boolean CHECK_MATRIX_PROPERTIES = false;

    /** Program this block belongs to. */
    protected Program _prog;
    /** Optional instruction executed by {@link #executeExitInstructions}; may be {@code null}. */
    protected Instruction _exitInstruction = null;
    /** Statement block this program block was generated from; may be {@code null}. */
    protected StatementBlock _sb = null;
    /** Thread id; 0 means "no thread id" (see {@link #hasThreadID()}). */
    protected long _tid = 0;

    // Parse information (see ParseInfo), used for runtime error messages.
    public String _filename;
    public int _beginLine;
    public int _beginColumn;
    public int _endLine;
    public int _endColumn;
    public String _text;

    public ProgramBlock(Program program) {
        this._prog = program;
    }

    public Program getProgram() {
        return this._prog;
    }

    public void setProgram(Program program) {
        this._prog = program;
    }

    public StatementBlock getStatementBlock() {
        return this._sb;
    }

    public void setStatementBlock(StatementBlock statementBlock) {
        this._sb = statementBlock;
    }

    public void setThreadID(long tid) {
        this._tid = tid;
    }

    /** @return {@code true} if a non-zero thread id has been assigned */
    public boolean hasThreadID() {
        return this._tid != 0;
    }

    /** @return {@code true} if {@code tid} denotes a real thread id, i.e. is non-zero */
    public static boolean isThreadID(long tid) {
        return tid != 0;
    }

    public long getThreadID() {
        return this._tid;
    }

    public void setExitInstruction(Instruction instruction) {
        this._exitInstruction = instruction;
    }

    public Instruction getExitInstruction() {
        return this._exitInstruction;
    }

    public abstract ArrayList<ProgramBlock> getChildBlocks();

    public abstract boolean isNested();

    public abstract void execute(ExecutionContext executionContext);

    /**
     * Executes a predicate and returns its scalar result.
     *
     * <p>The predicate is first recompiled from {@code hops} against the current variables
     * when both of these hold:
     * <ul>
     *   <li>dynamic recompilation is enabled;</li>
     *   <li>{@code requiresRecompile} is set.</li>
     * </ul>
     *
     * @param inst              compiled predicate instructions
     * @param hops              predicate HOP DAG, used only for recompilation
     * @param requiresRecompile whether the predicate needs recompilation
     * @param retType           requested result value type; {@code null} keeps the produced type
     * @param ec                execution context
     * @return the predicate result, converted to {@code retType} if necessary
     * @throws DMLRuntimeException if recompilation fails ("Unable to recompile predicate
     *         instructions."), or if executing the predicate instructions fails
     */
    public ScalarObject executePredicate(ArrayList<Instruction> inst, Hop hops, boolean requiresRecompile,
            Types.ValueType retType, ExecutionContext ec) {
        ArrayList<Instruction> tmp = inst;

        // Only recompilation is wrapped here. Execution failures keep their own, more precise
        // exceptions. This includes DMLScriptException, which executeSingleInstruction
        // deliberately rethrows unwrapped.
        try {
            long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            if (ConfigurationManager.isDynamicRecompilation() && requiresRecompile) {
                tmp = JMLCUtils.cleanupRuntimeInstructions(
                    Recompiler.recompileHopsDag(hops, ec.getVariables(), null, false, true, this._tid), PRED_VAR);
            }
            if (DMLScript.STATISTICS) {
                RecompileStatistics.incrementRecompileTime(System.nanoTime() - t0);
                if (tmp != inst) {
                    RecompileStatistics.incrementRecompilePred();
                }
            }
        } catch (Exception e) {
            throw new DMLRuntimeException("Unable to recompile predicate instructions.", e);
        }

        return executePredicateInstructions(tmp, retType, ec);
    }

    /**
     * Executes the exit instruction, if one is set.
     *
     * @param ctx short description of the block kind, used in the error message
     * @param ec  execution context
     * @throws DMLScriptException rethrown unchanged
     * @throws DMLRuntimeException wrapping any other failure, with the block's error location
     */
    public void executeExitInstructions(String ctx, ExecutionContext ec) {
        try {
            if (this._exitInstruction != null) {
                executeSingleInstruction(this._exitInstruction, ec);
            }
        } catch (DMLScriptException e) {
            throw e;
        } catch (Exception e) {
            throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating " + ctx + " exit instructions ", e);
        }
    }

    /**
     * Executes the given instructions in order.
     *
     * @param inst instructions to execute
     * @param ec   execution context
     */
    public void executeInstructions(ArrayList<Instruction> inst, ExecutionContext ec) {
        for (int i = 0; i < inst.size(); i++) {
            executeSingleInstruction(inst.get(i), ec);
        }
    }

    /**
     * Executes predicate instructions and returns the result.
     *
     * <p>The result is read from {@link #PRED_VAR}, converted to {@code retType} when the
     * types differ and {@code retType} is BOOLEAN, INT64, FP64 or STRING, and then removed
     * from the symbol table.
     *
     * @param inst    predicate instructions
     * @param retType requested value type; {@code null} keeps the produced type
     * @param ec      execution context
     * @return the (possibly converted) predicate result
     */
    protected ScalarObject executePredicateInstructions(ArrayList<Instruction> inst, Types.ValueType retType,
            ExecutionContext ec) {
        for (Instruction in : inst) {
            executeSingleInstruction(in, ec);
        }

        ScalarObject ret = ec.getScalarInput(PRED_VAR, retType, false);
        if (retType != null && retType != ret.getValueType()) {
            switch (retType) {
                case BOOLEAN:
                    ret = new BooleanObject(ret.getBooleanValue());
                    break;
                case INT64:
                    ret = new IntObject(ret.getLongValue());
                    break;
                case FP64:
                    ret = new DoubleObject(ret.getDoubleValue());
                    break;
                case STRING:
                    ret = new StringObject(ret.getStringValue());
                    break;
                default:
                    // other value types are returned as produced
                    break;
            }
        }

        ec.removeVariable(PRED_VAR);
        return ret;
    }

    /**
     * Executes a single instruction. The steps are:
     * <ol>
     *   <li>preprocess the instruction;</li>
     *   <li>try lineage-cache reuse; otherwise process, cache and postprocess it;</li>
     *   <li>propagate privacy constraints;</li>
     *   <li>optionally trace the runtime.</li>
     * </ol>
     *
     * @throws DMLScriptException rethrown unchanged
     * @throws DMLRuntimeException wrapping any other failure, with the block's error location
     */
    private void executeSingleInstruction(Instruction currInst, ExecutionContext ec) {
        try {
            long t0 = (DMLScript.STATISTICS || LOG.isTraceEnabled()) ? System.nanoTime() : 0L;

            Instruction tmp = currInst.preprocessInstruction(ec);

            if (!LineageCache.reuse(tmp, ec)) {
                // start time handed to LineageCache.putValue; only measured if lineage needs it
                long et0 = (!LineageCacheConfig.ReuseCacheType.isNone() || DMLScript.LINEAGE_ESTIMATE)
                    ? System.nanoTime() : 0L;
                tmp.processInstruction(ec);
                LineageCache.putValue(tmp, ec, et0);
                tmp.postprocessInstruction(ec);
                if (DMLScript.STATISTICS) {
                    Statistics.maintainCPHeavyHitters(tmp.getExtendedOpcode(), System.nanoTime() - t0);
                }
            }

            PrivacyPropagator.postProcessInstruction(tmp, ec);

            if (LOG.isTraceEnabled()) {
                String time = String.format("%.3f", (System.nanoTime() - t0) / 1.0E9d);
                LOG.trace("Instruction: " + tmp + " (executed in " + time + "s).");
            }

            // Developer-only consistency checks; compiled away while the flag is false.
            if (CHECK_MATRIX_PROPERTIES) {
                checkSparsity(tmp, ec.getVariables(), ec);
                checkFederated(ec.getVariables());
            }
        } catch (DMLScriptException e) {
            throw e;
        } catch (Exception e) {
            throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating instruction: " + currInst.toString(), e);
        }
    }

    /**
     * Prepares the statement block's update-in-place variables.
     *
     * <p>For each such variable that is currently a matrix object, the current update type
     * is recorded. A variable is replaced by a deep copy when both of these hold:
     * <ul>
     *   <li>its update type is COPY;</li>
     *   <li>its exact size estimate is below half of the local memory budget.</li>
     * </ul>
     * The copy's file name gets the suffix {@code "_uip" + tid}, and the copy is marked
     * INPLACE. The original is removed from the symbol table and cleaned up.
     *
     * @param ec  execution context
     * @param tid id appended to the copy's file name
     * @return update types recorded before modification, aligned with
     *         {@code _sb.getUpdateInPlaceVars()} ({@code null} entries for non-matrix
     *         variables); {@code null} if there is no statement block or no such variable
     */
    public MatrixObject.UpdateType[] prepareUpdateInPlaceVariables(ExecutionContext ec, long tid) {
        if (this._sb == null || this._sb.getUpdateInPlaceVars().isEmpty()) {
            return null;
        }

        ArrayList<String> varnames = this._sb.getUpdateInPlaceVars();
        MatrixObject.UpdateType[] flags = new MatrixObject.UpdateType[varnames.size()];
        for (int i = 0; i < flags.length; i++) {
            String varname = varnames.get(i);
            if (!ec.isMatrixObject(varname)) {
                continue;
            }
            MatrixObject mo = ec.getMatrixObject(varname);
            flags[i] = mo.getUpdateType();
            if (flags[i] == MatrixObject.UpdateType.COPY
                && OptimizerUtils.getLocalMemBudget() / 2.0d > OptimizerUtils.estimateSizeExactSparsity(mo.getDataCharacteristics())) {
                MatrixObject moNew = new MatrixObject(mo);
                MatrixBlock mb = mo.acquireRead();
                try {
                    moNew.acquireModify(deepCopyForUpdateInPlace(mb));
                    moNew.setFileName(mo.getFileName() + "_uip" + tid);
                } finally {
                    // always give back the read lock, even if copying fails
                    mo.release();
                }
                if (ec.removeVariable(varname) != null) {
                    ec.cleanupCacheableData(mo);
                }
                moNew.release();
                moNew.setUpdateType(MatrixObject.UpdateType.INPLACE);
                ec.setVariable(varname, moNew);
            }
        }
        return flags;
    }

    /**
     * Creates a deep copy of {@code src}. The copy depends on the input:
     * <ul>
     *   <li>compressed blocks are copied as {@link CompressedMatrixBlock};</li>
     *   <li>dense blocks are copied as-is;</li>
     *   <li>sparse blocks are copied into {@code MatrixBlock.DEFAULT_INPLACE_SPARSEBLOCK}.</li>
     * </ul>
     */
    private static MatrixBlock deepCopyForUpdateInPlace(MatrixBlock src) {
        if (src instanceof CompressedMatrixBlock) {
            return new CompressedMatrixBlock((CompressedMatrixBlock) src);
        }
        if (!src.isInSparseFormat()) {
            return new MatrixBlock(src);
        }
        return new MatrixBlock(src, MatrixBlock.DEFAULT_INPLACE_SPARSEBLOCK, true);
    }

    /**
     * Restores the update types previously returned by {@link #prepareUpdateInPlaceVariables}.
     * Variables that are no longer bound, or had no recorded type, are skipped.
     *
     * @param ec    execution context
     * @param flags recorded update types, or {@code null} (no-op)
     */
    public void resetUpdateInPlaceVariableFlags(ExecutionContext ec, MatrixObject.UpdateType[] flags) {
        if (flags == null) {
            return;
        }
        ArrayList<String> varnames = this._sb.getUpdateInPlaceVars();
        for (int i = 0; i < varnames.size(); i++) {
            String varname = varnames.get(i);
            if (ec.getVariable(varname) != null && flags[i] != null) {
                ec.getMatrixObject(varname).setUpdateType(flags[i]);
            }
        }
    }

    /**
     * Debug check for dirty, non-partitioned matrix objects.
     *
     * <p>The nnz count and sparse/dense representation are recomputed. If they differ from
     * the stored metadata, the check fails. Sparse blocks are also validated. Matrices backed
     * by an RDD handle (with non-MetaDataFormat metadata or BINARY format) are additionally
     * checked via {@link SparkUtils#checkSparsity}.
     *
     * @throws DMLRuntimeException on an inconsistency
     */
    private static void checkSparsity(Instruction inst, LocalVariableMap vars, ExecutionContext ec) {
        for (String varname : vars.keySet()) {
            Data dat = vars.get(varname);
            if (!(dat instanceof MatrixObject)) {
                continue;
            }
            MatrixObject mo = (MatrixObject) dat;
            if (mo.isDirty() && !mo.isPartitioned()) {
                MatrixBlock mb = mo.acquireRead();
                boolean sparseBefore;
                long nnzBefore;
                boolean sparseAfter;
                long nnzAfter;
                try {
                    sparseBefore = mb.isInSparseFormat();
                    nnzBefore = mb.getNonZeros();
                    synchronized (mb) {
                        mb.recomputeNonZeros();
                        mb.examSparsity();
                    }
                    if (mb.isInSparseFormat() && mb.isAllocated()) {
                        mb.getSparseBlock().checkValidity(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), true);
                    }
                    sparseAfter = mb.isInSparseFormat();
                    nnzAfter = mb.getNonZeros();
                } finally {
                    // release the read lock even if validation throws
                    mo.release();
                }

                if (nnzBefore != nnzAfter) {
                    throw new DMLRuntimeException("Matrix nnz meta data was incorrect: (" + varname
                        + ", actual=" + nnzBefore + ", expected=" + nnzAfter + ", inst=" + inst + ")");
                }
                if (sparseBefore != sparseAfter && mb.isAllocated()) {
                    throw new DMLRuntimeException("Matrix was in wrong data representation: (" + varname
                        + ", actual=" + sparseBefore + ", expected=" + sparseAfter
                        + ", nrow=" + mb.getNumRows() + ", ncol=" + mb.getNumColumns()
                        + ", nnz=" + nnzBefore + ", inst=" + inst + ")");
                }
            }

            MetaData meta = mo.getMetaData();
            if (mo.getRDDHandle() != null
                && (!(meta instanceof MetaDataFormat) || ((MetaDataFormat) meta).getFileFormat() == Types.FileFormat.BINARY)) {
                SparkUtils.checkSparsity(varname, ec);
            }
        }
    }

    /**
     * Debug check that every federated variable has a non-empty federation map.
     *
     * @throws DMLRuntimeException if a federated variable has an empty federation map
     */
    private static void checkFederated(LocalVariableMap vars) {
        for (String varname : vars.keySet()) {
            Data dat = vars.get(varname);
            if (dat instanceof CacheableData) {
                CacheableData<?> cd = (CacheableData<?>) dat;
                if (cd.isFederated() && cd.getFedMapping().getMap().isEmpty()) {
                    throw new DMLRuntimeException("Invalid empty FederationMap for: " + cd);
                }
            }
        }
    }

    @Override
    public void setFilename(String filename) {
        this._filename = filename;
    }

    @Override
    public void setBeginLine(int line) {
        this._beginLine = line;
    }

    @Override
    public void setBeginColumn(int column) {
        this._beginColumn = column;
    }

    @Override
    public void setEndLine(int line) {
        this._endLine = line;
    }

    @Override
    public void setEndColumn(int column) {
        this._endColumn = column;
    }

    @Override
    public void setText(String text) {
        this._text = text;
    }

    @Override
    public String getFilename() {
        return this._filename;
    }

    @Override
    public int getBeginLine() {
        return this._beginLine;
    }

    @Override
    public int getBeginColumn() {
        return this._beginColumn;
    }

    @Override
    public int getEndLine() {
        return this._endLine;
    }

    @Override
    public int getEndColumn() {
        return this._endColumn;
    }

    @Override
    public String getText() {
        return this._text;
    }

    /** @return error-message prefix naming the statement-block line range of this block */
    public String printBlockErrorLocation() {
        return "ERROR: Runtime error in program block generated from statement block between lines " + this._beginLine + " and " + this._endLine + " -- ";
    }

    /** Copies all parse information (lines, columns, text, file name) from {@code parseInfo}. */
    public void setParseInfo(ParseInfo parseInfo) {
        this._beginLine = parseInfo.getBeginLine();
        this._beginColumn = parseInfo.getBeginColumn();
        this._endLine = parseInfo.getEndLine();
        this._endColumn = parseInfo.getEndColumn();
        this._text = parseInfo.getText();
        this._filename = parseInfo.getFilename();
    }
}
