package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.utils.Explain;

/* loaded from: input_file:org/apache/sysds/hops/ipa/InterProceduralAnalysis.class */
/**
 * Inter-procedural analysis (IPA) over the hop DAGs of a DML program.
 *
 * <p>The analysis propagates size statistics (dimensions and non-zeros) and, optionally, scalar
 * values from call sites into DML function bodies. It then writes the functions' output statistics
 * back into the caller's variable map, walking while/if/for control flow along the way.
 *
 * <p>An instance created from a whole {@link DMLProgram} also registers a fixed, ordered sequence
 * of {@link IPAPass} rewrites. An instance created from a single {@link StatementBlock} registers
 * no passes; {@link #analyzeSubProgram()} operates on the block stored by that constructor.
 */
public class InterProceduralAnalysis {
    // Unused in the visible code; retained as-is from the original class.
    private static final boolean LDEBUG = false;
    private static final Log LOG = LogFactory.getLog(InterProceduralAnalysis.class.getName());

    // Feature switches and limits of the analysis. They are protected rather than private,
    // presumably so the IPA passes in this package can read them -- confirm before changing.
    protected static final boolean INTRA_PROCEDURAL_ANALYSIS = true;
    protected static final boolean PROPAGATE_KNOWN_UDF_STATISTICS = true;
    protected static final boolean ALLOW_MULTIPLE_FUNCTION_CALLS = true;
    protected static final boolean REMOVE_UNUSED_FUNCTIONS = true;
    protected static final boolean FLAG_FUNCTION_RECOMPILE_ONCE = true;
    protected static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true;
    protected static final boolean REMOVE_CONSTANT_BINARY_OPS = true;
    protected static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true;
    protected static final boolean PROPAGATE_SCALAR_LITERALS = true;
    protected static final boolean APPLY_STATIC_REWRITES = true;
    protected static final boolean APPLY_DYNAMIC_REWRITES = true;
    protected static final int INLINING_MAX_NUM_OPS = 10;
    protected static final boolean ELIMINATE_DEAD_CODE = true;
    protected static final boolean FORWARD_SIMPLE_FUN_CALLS = true;
    protected static final boolean FLAG_NONDETERMINISM = true;

    /** Program under analysis (for the sub-program constructor: the program owning the block). */
    private final DMLProgram _prog;
    /** Block analyzed by {@link #analyzeSubProgram()}; {@code null} for whole-program analysis. */
    private final StatementBlock _sb;
    // Not final; no reassignment is visible here, so presumably the missing analyzeProgram(int)
    // body rebuilds it -- confirm when restoring that method.
    private FunctionCallGraph _fgraph;
    /** IPA passes, in registration order (empty for sub-program analysis). */
    private final ArrayList<IPAPass> _passes;

    /**
     * Creates a whole-program analysis, builds its function call graph and registers the default
     * IPA passes.
     *
     * @param program the program to analyze
     */
    public InterProceduralAnalysis(DMLProgram program) {
        _prog = program;
        _sb = null;
        _fgraph = new FunctionCallGraph(program);
        if (LOG.isDebugEnabled()) {
            LOG.debug("IPA: Initial FunctionCallGraph: \n--MAIN PROGRAM\n"
                + Explain.explainFunctionCallGraph(_fgraph, new HashSet<>(), null, 1));
        }
        // NOTE: registration order is preserved exactly; passes may depend on their predecessors.
        _passes = new ArrayList<>();
        _passes.add(new IPAPassRemoveUnusedFunctions());
        _passes.add(new IPAPassFlagFunctionsRecompileOnce());
        _passes.add(new IPAPassRemoveUnnecessaryCheckpoints());
        _passes.add(new IPAPassRemoveConstantBinaryOps());
        _passes.add(new IPAPassPropagateReplaceLiterals());
        _passes.add(new IPAPassInlineFunctions());
        _passes.add(new IPAPassReplaceEvalFunctionCalls());
        _passes.add(new IPAPassEliminateDeadCode());
        _passes.add(new IPAPassFlagNonDeterminism());
        _passes.add(new IPAPassForwardFunctionCalls());
        _passes.add(new IPAPassApplyStaticAndDynamicHopRewrites());
    }

    /**
     * Creates an analysis scoped to a single statement block, without any IPA passes.
     *
     * @param block the block to analyze via {@link #analyzeSubProgram()}
     */
    public InterProceduralAnalysis(StatementBlock block) {
        _prog = block.getDMLProg();
        _sb = block;
        _fgraph = new FunctionCallGraph(block);
        _passes = new ArrayList<>();
    }

    /** Runs {@link #analyzeProgram(int)} with an argument of 1. */
    public void analyzeProgram() {
        analyzeProgram(1);
    }

    /**
     * Whole-program analysis entry point.
     *
     * <p>NOTE(review): this body is a decompiler stub. JADX failed to decompile the original method
     * (about 555 instructions), so the method always throws {@link UnsupportedOperationException}.
     * Every caller, including {@link #analyzeProgram()}, therefore fails at runtime. Restore the
     * real implementation from the upstream source; do not reconstruct it by guesswork.
     *
     * @param r8 decompiler-generated parameter name; {@link #analyzeProgram()} passes 1
     * @throws UnsupportedOperationException always, until the body is restored
     */
    public void analyzeProgram(int r8) {
        throw new UnsupportedOperationException("Method not decompiled: org.apache.sysds.hops.ipa.InterProceduralAnalysis.analyzeProgram(int):void");
    }

    /**
     * Propagates statistics and scalar values across the block given at construction time. It
     * descends into every function that the call-size analysis reports as valid.
     *
     * @return the valid functions reported by the {@link FunctionCallSizeInfo} built for this run
     */
    public Set<String> analyzeSubProgram() {
        DMLTranslator.resetHopsDAGVisitStatus(_sb);
        FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph);
        if (!fcallSizes.getValidFunctions().isEmpty()) {
            propagateStatisticsAcrossBlock(_sb, new LocalVariableMap(), fcallSizes, new HashSet<>(), true);
        }
        return fcallSizes.getValidFunctions();
    }

    /**
     * Decides whether a function is unary over matrices and preserves its input dimensions.
     *
     * <p>A matrix-to-matrix function with exactly one input and one output is probed. A synthetic
     * 7777 x 3333 input is propagated through its body, and the output dimensions are compared with
     * the input's. The row and column counts differ, so a transposed result is not mistaken for a
     * preserved one. Afterwards, the body is propagated again with unknown (-1) dimensions,
     * presumably so the synthetic probe sizes are not left on the function's hops.
     *
     * @param fsb the function to check
     * @return {@code true} only if the probed output is a matrix with the probe's exact dimensions
     */
    private boolean isUnarySizePreservingFunction(FunctionStatementBlock fsb) {
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        boolean unaryMatrixFunction = fstmt.getInputParams().size() == 1
            && fstmt.getInputParams().get(0).getDataType() == Types.DataType.MATRIX
            && fstmt.getOutputParams().size() == 1
            && fstmt.getOutputParams().get(0).getDataType() == Types.DataType.MATRIX;
        if (!unaryMatrixFunction) {
            return false;
        }

        FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph, false);
        Set<String> fnStack = new HashSet<>();
        LocalVariableMap callVars = new LocalVariableMap();
        String inputName = fstmt.getInputParams().get(0).getName();
        String outputName = fstmt.getOutputParams().get(0).getName();

        // probe pass with known, distinct dimensions
        MatrixObject probe = createOutputMatrix(7777L, 3333L, -1L);
        callVars.put(inputName, probe);
        propagateStatisticsAcrossBlocks(fstmt.getBody(), callVars, fcallSizes, fnStack, false);

        // A missing or non-matrix output binding means "not dims-preserving"; previously this
        // case failed with a NullPointerException or ClassCastException.
        Data result = callVars.get(outputName);
        boolean preserved = result instanceof MatrixObject
            && probe.getNumRows() == ((MatrixObject) result).getNumRows()
            && probe.getNumColumns() == ((MatrixObject) result).getNumColumns();

        // reset pass with unknown dimensions (runs regardless of the outcome above)
        probe.getDataCharacteristics().setDimension(-1L, -1L);
        callVars.put(inputName, probe);
        propagateStatisticsAcrossBlocks(fstmt.getBody(), callVars, fcallSizes, fnStack, false);
        return preserved;
    }

    /** Applies {@link #propagateStatisticsAcrossBlock} to each block of {@code blocks}, in order. */
    private void propagateStatisticsAcrossBlocks(Iterable<StatementBlock> blocks, LocalVariableMap callVars,
            FunctionCallSizeInfo fcallSizes, Set<String> fnStack, boolean propagateScalars) {
        for (StatementBlock block : blocks) {
            propagateStatisticsAcrossBlock(block, callVars, fcallSizes, fnStack, propagateScalars);
        }
    }

    /**
     * Propagates statistics through one statement block, recursing into control-flow bodies.
     *
     * <p>For while and for loops, the body is propagated once. If
     * {@code Recompiler.reconcileUpdatedCallVarsLoops} returns {@code true} when comparing against a
     * pre-body snapshot, the body is propagated a second time. While loops also re-propagate their
     * predicate on that second pass. For if blocks, the if-branch works on {@code callVars} and the
     * else-branch on a copy. The two are then merged with {@code Recompiler.reconcileUpdatedCallVarsIf}.
     * Leaf blocks get scalar replacement (if {@code propagateScalars}), hop statistics propagation,
     * and propagation into the DML functions they call.
     *
     * @param sb               the block to process
     * @param callVars         variable map of the current scope, updated in place
     * @param fcallSizes       call-size information deciding which functions may be entered
     * @param fnStack          keys of the functions on the current call path
     * @param propagateScalars whether scalar values are substituted into leaf-block hop DAGs
     */
    private void propagateStatisticsAcrossBlock(StatementBlock sb, LocalVariableMap callVars,
            FunctionCallSizeInfo fcallSizes, Set<String> fnStack, boolean propagateScalars) {
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
            propagateStatisticsAcrossBlocks(fstmt.getBody(), callVars, fcallSizes, fnStack, propagateScalars);
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
            Recompiler.removeUpdatedScalars(callVars, wsb);
            LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
            propagateStatisticsAcrossBlocks(wstmt.getBody(), callVars, fcallSizes, fnStack, propagateScalars);
            if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb)) {
                propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
                propagateStatisticsAcrossBlocks(wstmt.getBody(), callVars, fcallSizes, fnStack, propagateScalars);
            }
            Recompiler.removeUpdatedScalars(callVars, sb);
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);
            LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
            LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone();
            propagateStatisticsAcrossBlocks(istmt.getIfBody(), callVars, fcallSizes, fnStack, propagateScalars);
            propagateStatisticsAcrossBlocks(istmt.getElseBody(), callVarsElse, fcallSizes, fnStack, propagateScalars);
            Recompiler.removeUpdatedScalars(
                Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb), sb);
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            propagateStatisticsAcrossPredicateDAG(fsb.getFromHops(), callVars);
            propagateStatisticsAcrossPredicateDAG(fsb.getToHops(), callVars);
            propagateStatisticsAcrossPredicateDAG(fsb.getIncrementHops(), callVars);
            Recompiler.removeUpdatedScalars(callVars, fsb);
            LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
            propagateStatisticsAcrossBlocks(fstmt.getBody(), callVars, fcallSizes, fnStack, propagateScalars);
            if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb)) {
                propagateStatisticsAcrossBlocks(fstmt.getBody(), callVars, fcallSizes, fnStack, propagateScalars);
            }
            Recompiler.removeUpdatedScalars(callVars, sb);
        } else {
            // leaf block: operate directly on its hop DAG roots
            Recompiler.removeUpdatedScalars(callVars, sb);
            ArrayList<Hop> roots = sb.getHops();
            DMLProgram prog = sb.getDMLProg();
            if (propagateScalars) {
                Hop.resetVisitStatus(roots);
                propagateScalarsAcrossDAG(roots, callVars);
            }
            Hop.resetVisitStatus(roots);
            propagateStatisticsAcrossDAG(roots, callVars);
            Hop.resetVisitStatus(roots);
            propagateStatisticsIntoFunctions(prog, roots, callVars, fcallSizes, fnStack, propagateScalars);
        }
    }

    /**
     * Replaces scalar references in the given DAG roots with literals from {@code callVars}. A
     * {@code null} root list is ignored, matching {@link #propagateStatisticsAcrossDAG}.
     *
     * @throws HopsException wrapping any failure of the literal replacement
     */
    private static void propagateScalarsAcrossDAG(ArrayList<Hop> roots, LocalVariableMap callVars) {
        if (roots == null) {
            return;
        }
        try {
            for (Hop root : roots) {
                Recompiler.rReplaceLiterals(root, callVars, true);
            }
        } catch (Exception e) {
            throw new HopsException("Failed to perform scalar literal replacement.", e);
        }
    }

    /**
     * Updates the statistics of a predicate DAG; a {@code null} predicate is ignored.
     *
     * @throws HopsException wrapping any failure of the statistics update
     */
    private static void propagateStatisticsAcrossPredicateDAG(Hop predicate, LocalVariableMap callVars) {
        if (predicate == null) {
            return;
        }
        predicate.resetVisitStatus();
        try {
            Recompiler.rUpdateStatistics(predicate, callVars);
        } catch (Exception e) {
            throw new HopsException("Failed to update Hop DAG statistics.", e);
        }
    }

    /**
     * Updates the statistics of all DAG roots, then extracts the DAG output statistics into
     * {@code callVars}. A {@code null} root list is ignored.
     *
     * @throws HopsException wrapping any failure of the statistics update or extraction
     */
    private static void propagateStatisticsAcrossDAG(ArrayList<Hop> roots, LocalVariableMap callVars) {
        if (roots == null) {
            return;
        }
        try {
            for (Hop root : roots) {
                Recompiler.rUpdateStatistics(root, callVars);
            }
            Recompiler.extractDAGOutputStatistics(roots, callVars, true);
        } catch (Exception e) {
            throw new HopsException("Failed to update Hop DAG statistics.", e);
        }
    }

    /**
     * Applies {@link #propagateStatisticsIntoFunctions(DMLProgram, Hop, LocalVariableMap,
     * FunctionCallSizeInfo, Set, boolean)} to every root. A {@code null} root list is ignored,
     * consistent with {@link #propagateStatisticsAcrossDAG}.
     */
    private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, LocalVariableMap callVars,
            FunctionCallSizeInfo fcallSizes, Set<String> fnStack, boolean propagateScalars) {
        if (roots == null) {
            return;
        }
        for (Hop root : roots) {
            propagateStatisticsIntoFunctions(prog, root, callVars, fcallSizes, fnStack, propagateScalars);
        }
    }

    /**
     * Visits the DAG below {@code hop} once, in post order, and handles DML function calls.
     *
     * <p>A call to a valid function that is not already on {@code fnStack} is entered. Its body
     * is propagated with a fresh variable map built from the call arguments, and its output
     * statistics are written back to {@code callVars}. A call to a dims-preserving function gets
     * its first input's dimensions as output. Any other DML call gets unknown output statistics.
     */
    private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, LocalVariableMap callVars,
            FunctionCallSizeInfo fcallSizes, Set<String> fnStack, boolean propagateScalars) {
        if (hop.isVisited()) {
            return;
        }
        // post order: function calls among the inputs are handled before this hop
        for (Hop input : hop.getInput()) {
            propagateStatisticsIntoFunctions(prog, input, callVars, fcallSizes, fnStack, propagateScalars);
        }
        if (hop instanceof FunctionOp) {
            FunctionOp fop = (FunctionOp) hop;
            String fkey = fop.getFunctionKey();
            if (fop.getFunctionType() == FunctionOp.FunctionType.DML) {
                FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
                FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
                if (fcallSizes.isValidFunction(fkey) && !fnStack.contains(fkey)) {
                    // fnStack prevents re-entering a function already on the current call path
                    fnStack.add(fkey);
                    LocalVariableMap fcallVars = new LocalVariableMap();
                    populateLocalVariableMapForFunctionCall(fstmt, fop, callVars, fcallVars, fcallSizes);
                    propagateStatisticsAcrossBlock(fsb, fcallVars, fcallSizes, fnStack, propagateScalars);
                    extractFunctionCallReturnStatistics(fstmt, fop, fcallVars, callVars, true);
                    fnStack.remove(fkey);
                } else if (fcallSizes.isDimsPreservingFunction(fkey)) {
                    extractFunctionCallEquivalentReturnStatistics(fstmt, fop, callVars);
                } else {
                    extractFunctionCallUnknownReturnStatistics(fstmt, fop, callVars);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Builds the function-local variable map for one call from the call's argument hops.
     *
     * <p>Matrix arguments become matrices with the argument's dimensions. The argument's nnz is
     * used only when {@code FunctionCallSizeInfo.isSafeNnz} allows it; otherwise nnz is -1.
     * Literal scalar arguments become scalar objects. Scalar arguments that are {@link DataOp}s
     * are forwarded from {@code callVars}, but only for functions with a call count of 1 and only
     * if the bound value is a {@link ScalarObject}.
     *
     * @param callVars  variable map of the calling scope (read only)
     * @param fcallVars function-local map to populate
     * @throws HopsException if an argument name is not an input parameter of the function
     */
    private static void populateLocalVariableMapForFunctionCall(FunctionStatement fstmt, FunctionOp fop,
            LocalVariableMap callVars, LocalVariableMap fcallVars, FunctionCallSizeInfo fcallSizes) {
        String[] argNames = fop.getInputVariableNames();
        ArrayList<Hop> args = fop.getInput();
        String fkey = fop.getFunctionKey();
        for (int i = 0; i < Math.min(args.size(), argNames.length); i++) {
            DataIdentifier param = fstmt.getInputParam(argNames[i]);
            if (param == null) {
                throw new HopsException("Failed IPA: function argument '" + argNames[i] + "' does not exist in function signature of " + fop.getFunctionKey() + ".");
            }
            Hop arg = args.get(i);
            if (arg.getDataType() == Types.DataType.MATRIX) {
                long nnz = fcallSizes.isSafeNnz(fkey, i) ? arg.getNnz() : -1L;
                fcallVars.put(param.getName(), createOutputMatrix(arg.getDim1(), arg.getDim2(), nnz));
            } else if (arg.getDataType() == Types.DataType.SCALAR) {
                if (arg instanceof LiteralOp) {
                    fcallVars.put(param.getName(), ScalarObjectFactory.createScalarObject(arg.getValueType(), (LiteralOp) arg));
                } else if (fcallSizes.getFunctionCallCount(fkey) == 1 && arg instanceof DataOp) {
                    Data value = callVars.get(arg.getName());
                    if (value instanceof ScalarObject) {
                        fcallVars.put(param.getName(), value);
                    }
                }
            }
        }
    }

    /**
     * Writes the output statistics of an entered function call back to the caller's map.
     *
     * <p>Processing steps for each output position:
     * <ol>
     *   <li>A caller binding whose data type differs from the declared output type is removed.</li>
     *   <li>For matrix outputs present in {@code fcallVars}: if the caller has no binding, or
     *       {@code overwrite} is set, a fresh matrix with the function output's dims and nnz is
     *       bound.</li>
     *   <li>Otherwise an existing caller matrix is updated in place, but only if the function
     *       output's size estimate is larger than the size estimate of the current binding.</li>
     * </ol>
     *
     * @param fcallVars function-local variable map after propagating the body
     * @param callVars  caller variable map, updated in place
     * @param overwrite whether existing caller bindings are replaced unconditionally
     * @throws HopsException wrapping any failure while processing an output
     */
    private static void extractFunctionCallReturnStatistics(FunctionStatement fstmt, FunctionOp fop,
            LocalVariableMap fcallVars, LocalVariableMap callVars, boolean overwrite) {
        ArrayList<DataIdentifier> outputParams = fstmt.getOutputParams();
        String[] outputNames = fop.getOutputVariableNames();
        String fkey = fop.getFunctionKey();
        int numOutputs = Math.min(outputParams.size(), outputNames.length);
        for (int i = 0; i < numOutputs; i++) {
            try {
                DataIdentifier param = outputParams.get(i);
                String fvarName = param.getName();
                String pvarName = outputNames[i];
                if (callVars.keySet().contains(pvarName) && param.getDataType() != callVars.get(pvarName).getDataType()) {
                    callVars.remove(pvarName);
                }
                if (param.getDataType() != Types.DataType.MATRIX || !fcallVars.keySet().contains(fvarName)) {
                    continue;
                }
                MatrixObject fout = (MatrixObject) fcallVars.get(fvarName);
                if (!callVars.keySet().contains(pvarName) || overwrite) {
                    callVars.put(pvarName, createOutputMatrix(fout.getNumRows(), fout.getNumColumns(), fout.getNnz()));
                } else {
                    Data current = callVars.get(pvarName);
                    if (current instanceof MatrixObject) {
                        DataCharacteristics dc = ((MatrixObject) current).getDataCharacteristics();
                        double sparsity = dc.getNonZeros() > 0 ? OptimizerUtils.getSparsity(dc) : 1.0d;
                        long currentSize = OptimizerUtils.estimateSizeExactSparsity(dc.getRows(), dc.getCols(), sparsity);
                        if (currentSize < OptimizerUtils.estimateSize(fout.getNumRows(), fout.getNumColumns())) {
                            dc.setDimension(fout.getNumRows(), fout.getNumColumns());
                            dc.setNonZeros(fout.getNnz());
                        }
                    }
                }
            } catch (Exception e) {
                throw new HopsException("Failed to extract output statistics of function " + fkey + ".", e);
            }
        }
    }

    /**
     * Binds every matrix output of a call that was not entered to a matrix with unknown (-1)
     * dimensions and nnz.
     *
     * @throws HopsException wrapping any failure while binding the outputs
     */
    private static void extractFunctionCallUnknownReturnStatistics(FunctionStatement fstmt, FunctionOp fop,
            LocalVariableMap callVars) {
        ArrayList<DataIdentifier> outputParams = fstmt.getOutputParams();
        String[] outputNames = fop.getOutputVariableNames();
        String fkey = fop.getFunctionKey();
        try {
            int numOutputs = Math.min(outputParams.size(), outputNames.length);
            for (int i = 0; i < numOutputs; i++) {
                if (outputParams.get(i).getDataType() == Types.DataType.MATRIX) {
                    callVars.put(outputNames[i], createOutputMatrix(-1L, -1L, -1L));
                }
            }
        } catch (Exception e) {
            throw new HopsException("Failed to extract output statistics of function " + fkey + ".", e);
        }
    }

    /**
     * Binds the first output of a dims-preserving function call to a matrix with the first input's
     * dimensions and unknown nnz. {@code fstmt} is unused but kept for signature symmetry.
     *
     * @throws HopsException wrapping any failure while binding the output
     */
    private static void extractFunctionCallEquivalentReturnStatistics(FunctionStatement fstmt, FunctionOp fop,
            LocalVariableMap callVars) {
        try {
            Hop input = fop.getInput().get(0);
            callVars.put(fop.getOutputVariableNames()[0], createOutputMatrix(input.getDim1(), input.getDim2(), -1L));
        } catch (Exception e) {
            throw new HopsException("Failed to extract output statistics for unary function " + fop.getFunctionKey() + ".", e);
        }
    }

    /**
     * Creates an FP64 matrix placeholder that carries only metadata: the given dimensions and nnz,
     * plus the configured block size.
     *
     * @param rows number of rows (-1 if unknown)
     * @param cols number of columns (-1 if unknown)
     * @param nnz  number of non-zeros (-1 if unknown)
     */
    private static MatrixObject createOutputMatrix(long rows, long cols, long nnz) {
        MatrixObject mo = new MatrixObject(Types.ValueType.FP64, null);
        mo.setMetaData(new MetaDataFormat(new MatrixCharacteristics(rows, cols, ConfigurationManager.getBlocksize(), nnz), null));
        return mo;
    }
}
