package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.CompilerConfig;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

/**
 * Applies an ordered set of HOP DAG rewrite rules and statement block rewrite rules to a
 * {@link DMLProgram}, to individual function statement blocks, or to arbitrary statement blocks.
 *
 * <p>Two rule lists are kept: {@code _dagRuleSet} is applied (in list order) to every HOP DAG,
 * including predicate and loop-bound DAGs; {@code _sbRuleSet} is applied (in list order) both to
 * lists of statement blocks and to individual statement blocks.
 */
public class ProgramRewriter {
    /** Ordered HOP DAG rewrite rules; applied in insertion order. */
    private final ArrayList<HopRewriteRule> _dagRuleSet;
    /** Ordered statement block rewrite rules; applied in insertion order. */
    private final ArrayList<StatementBlockRewriteRule> _sbRuleSet;

    /** Creates a rewriter with both the static and the dynamic default rule groups enabled. */
    public ProgramRewriter() {
        this(true, true);
    }

    /**
     * Creates a rewriter populated with the default rule sets.
     *
     * <p>The cleanup rules (cast removal, CSE, constant folding, empty-block removal) are always
     * appended last, regardless of the two flags.
     *
     * @param staticRewrites  whether to register the "static" rule group
     * @param dynamicRewrites whether to register the "dynamic" rule group
     */
    public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) {
        _dagRuleSet = new ArrayList<>();
        _sbRuleSet = new ArrayList<>();
        if (staticRewrites) {
            addStaticRewrites();
        }
        if (dynamicRewrites) {
            addDynamicRewrites();
        }
        addCleanupRewrites();
    }

    /**
     * Creates a rewriter that applies only the given HOP DAG rules (no statement block rules).
     *
     * @param rules HOP DAG rules, applied in the given order
     */
    public ProgramRewriter(HopRewriteRule... rules) {
        _dagRuleSet = new ArrayList<>();
        _sbRuleSet = new ArrayList<>();
        for (HopRewriteRule rule : rules) {
            _dagRuleSet.add(rule);
        }
    }

    /**
     * Creates a rewriter that applies only the given statement block rules (no HOP DAG rules).
     *
     * @param rules statement block rules, applied in the given order
     */
    public ProgramRewriter(StatementBlockRewriteRule... rules) {
        _dagRuleSet = new ArrayList<>();
        _sbRuleSet = new ArrayList<>();
        for (StatementBlockRewriteRule rule : rules) {
            _sbRuleSet.add(rule);
        }
    }

    /**
     * Creates a rewriter from explicit rule lists. The lists are copied, so later changes to the
     * arguments do not affect this rewriter.
     *
     * @param dagRules HOP DAG rules, applied in list order (must not be null)
     * @param sbRules  statement block rules, applied in list order (must not be null)
     */
    public ProgramRewriter(ArrayList<HopRewriteRule> dagRules, ArrayList<StatementBlockRewriteRule> sbRules) {
        _dagRuleSet = new ArrayList<>(dagRules);
        _sbRuleSet = new ArrayList<>(sbRules);
    }

    /**
     * Registers the static rule group. Registration order is execution order; CSE is deliberately
     * registered twice so that it runs both before and after static algebraic simplification.
     */
    private void addStaticRewrites() {
        _dagRuleSet.add(new RewriteTransientWriteParentHandling());
        _dagRuleSet.add(new RewriteRemoveReadAfterWrite());
        _dagRuleSet.add(new RewriteBlockSizeAndReblock());
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteRemoveUnnecessaryCasts());
        }
        if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) {
            _dagRuleSet.add(new RewriteCommonSubexpressionElimination());
        }
        if (OptimizerUtils.ALLOW_CONSTANT_FOLDING) {
            _dagRuleSet.add(new RewriteConstantFolding());
        }
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic());
        }
        if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) {
            _dagRuleSet.add(new RewriteCommonSubexpressionElimination());
        }
        if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION) {
            _dagRuleSet.add(new RewriteIndexingVectorization());
        }
        _dagRuleSet.add(new RewriteInjectSparkPReadCheckpointing());

        if (OptimizerUtils.ALLOW_BRANCH_REMOVAL) {
            _sbRuleSet.add(new RewriteRemoveUnnecessaryBranches());
        }
        if (OptimizerUtils.ALLOW_FOR_LOOP_REMOVAL) {
            _sbRuleSet.add(new RewriteRemoveForLoopEmptySequence());
        }
        if (OptimizerUtils.ALLOW_BRANCH_REMOVAL || OptimizerUtils.ALLOW_FOR_LOOP_REMOVAL) {
            _sbRuleSet.add(new RewriteMergeBlockSequence());
        }
        if (OptimizerUtils.ALLOW_COMPRESSION_REWRITE) {
            _sbRuleSet.add(new RewriteCompressedReblock());
        }
        if (OptimizerUtils.ALLOW_SPLIT_HOP_DAGS) {
            _sbRuleSet.add(new RewriteSplitDagUnknownCSVRead());
        }
        if (OptimizerUtils.ALLOW_SPLIT_HOP_DAGS
            && ConfigurationManager.getCompilerConfigFlag(CompilerConfig.ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS)) {
            _sbRuleSet.add(new RewriteSplitDagDataDependentOperators());
        }
        if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION) {
            _sbRuleSet.add(new RewriteForLoopVectorization());
        }
        _sbRuleSet.add(new RewriteInjectSparkLoopCheckpointing(true));
        if (OptimizerUtils.ALLOW_CODE_MOTION) {
            _sbRuleSet.add(new RewriteHoistLoopInvariantOperations());
        }
        if (OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE) {
            _sbRuleSet.add(new RewriteMarkLoopVariablesUpdateInPlace());
        }
        if (LineageCacheConfig.getCompAssRW()) {
            _sbRuleSet.add(new MarkForLineageReuse());
        }
    }

    /** Registers the dynamic rule group (HOP DAG rules only). */
    private void addDynamicRewrites() {
        if (DMLScript.USE_ACCELERATOR) {
            _dagRuleSet.add(new RewriteGPUSpecificOps());
        }
        if (OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
            _dagRuleSet.add(new RewriteMatrixMultChainOptimization());
            _dagRuleSet.add(new RewriteElementwiseMultChainOptimization());
        }
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteAlgebraicSimplificationDynamic());
            _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic());
        }
        // The configured planner is read unconditionally, matching the original evaluation order.
        String federatedPlanner = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.FEDERATED_PLANNER);
        if (OptimizerUtils.FEDERATED_COMPILATION || FTypes.FederatedPlanner.isCompiled(federatedPlanner)) {
            _dagRuleSet.add(new RewriteFederatedExecution());
        }
    }

    /** Registers the rules that always run last, after the static and dynamic groups. */
    private void addCleanupRewrites() {
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteRemoveUnnecessaryCasts());
        }
        if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) {
            _dagRuleSet.add(new RewriteCommonSubexpressionElimination(true));
        }
        if (OptimizerUtils.ALLOW_CONSTANT_FOLDING) {
            _dagRuleSet.add(new RewriteConstantFolding());
        }
        _sbRuleSet.add(new RewriteRemoveEmptyBasicBlocks());
    }

    /**
     * Removes every HOP DAG rule whose exact runtime class equals {@code cls}
     * (instances of subclasses are kept).
     */
    public void removeHopRewrite(Class<? extends HopRewriteRule> cls) {
        _dagRuleSet.removeIf(rule -> rule.getClass().equals(cls));
    }

    /**
     * Removes every statement block rule whose exact runtime class equals {@code cls}
     * (instances of subclasses are kept).
     */
    public void removeStatementBlockRewrite(Class<? extends StatementBlockRewriteRule> cls) {
        _sbRuleSet.removeIf(rule -> rule.getClass().equals(cls));
    }

    /** Rewrites the whole program with DAG splitting enabled and a fresh status. */
    public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) {
        return rewriteProgramHopDAGs(dmlp, true);
    }

    /** Rewrites the whole program with a fresh status. */
    public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) {
        return rewriteProgramHopDAGs(dmlp, splitDags, new ProgramRewriteStatus());
    }

    /**
     * Rewrites all functions of all namespaces, then the HOP DAGs of the main program blocks,
     * then (if any statement block rules exist) the main program's statement block list.
     *
     * @param dmlp      program to rewrite in place
     * @param splitDags if false, rules whose {@code createsSplitDag()} is true are skipped
     * @param status    status accumulated across all rules; a fresh one is used if null
     * @return the status object that was used (never null)
     */
    public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags, ProgramRewriteStatus status) {
        // Normalize once so every phase below shares (and the caller receives) the same status.
        if (status == null) {
            status = new ProgramRewriteStatus();
        }
        for (String namespace : dmlp.getNamespaces().keySet()) {
            for (String fname : dmlp.getFunctionStatementBlocks(namespace).keySet()) {
                rewriteHopDAGsFunction(dmlp.getFunctionStatementBlock(namespace, fname), status, splitDags);
            }
        }
        for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
            rRewriteStatementBlockHopDAGs(dmlp.getStatementBlock(i), status);
        }
        if (!_sbRuleSet.isEmpty()) {
            dmlp.setStatementBlocks(rRewriteStatementBlocks(dmlp.getStatementBlocks(), status, splitDags));
        }
        return status;
    }

    /** Rewrites a single function with a fresh status. */
    public void rewriteHopDAGsFunction(FunctionStatementBlock fsb, boolean splitDags) {
        rewriteHopDAGsFunction(fsb, new ProgramRewriteStatus(), splitDags);
    }

    /**
     * Rewrites all HOP DAGs of a function and, if statement block rules exist, its body.
     * The list returned by the statement block pass for the function block itself is discarded.
     *
     * @param fsb       function statement block to rewrite in place
     * @param status    status shared by both passes; a fresh one is used if null
     * @param splitDags if false, rules whose {@code createsSplitDag()} is true are skipped
     */
    public void rewriteHopDAGsFunction(FunctionStatementBlock fsb, ProgramRewriteStatus status, boolean splitDags) {
        if (status == null) {
            status = new ProgramRewriteStatus();
        }
        rRewriteStatementBlockHopDAGs(fsb, status);
        if (!_sbRuleSet.isEmpty()) {
            rRewriteStatementBlock(fsb, status, splitDags);
        }
    }

    /**
     * Recursively applies the HOP DAG rules to every DAG reachable from {@code sb}: function,
     * while and if/else bodies; while/if predicates; for/parfor from, to and increment DAGs;
     * and the HOP roots of any other (basic) statement block.
     *
     * @param sb     statement block to rewrite in place
     * @param status rewrite status; a fresh one is used if null
     */
    public void rRewriteStatementBlockHopDAGs(StatementBlock sb, ProgramRewriteStatus status) {
        if (status == null) {
            status = new ProgramRewriteStatus();
        }
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
            for (StatementBlock child : fstmt.getBody()) {
                rRewriteStatementBlockHopDAGs(child, status);
            }
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), status));
            for (StatementBlock child : wstmt.getBody()) {
                rRewriteStatementBlockHopDAGs(child, status);
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), status));
            for (StatementBlock child : istmt.getIfBody()) {
                rRewriteStatementBlockHopDAGs(child, status);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                rRewriteStatementBlockHopDAGs(child, status);
            }
        } else if (sb instanceof ForStatementBlock) {
            // Also covers ParForStatementBlock.
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), status));
            fsb.setToHops(rewriteHopDAG(fsb.getToHops(), status));
            fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), status));
            for (StatementBlock child : fstmt.getBody()) {
                rRewriteStatementBlockHopDAGs(child, status);
            }
        } else {
            sb.setHops(rewriteHopDAG(sb.getHops(), status));
        }
    }

    /**
     * Applies every HOP DAG rule, in order, to a list of DAG roots, resetting visit status
     * before each rule.
     *
     * @return the roots returned by the last rule (the input if there are no rules)
     */
    public ArrayList<Hop> rewriteHopDAG(ArrayList<Hop> roots, ProgramRewriteStatus status) {
        for (HopRewriteRule rule : _dagRuleSet) {
            Hop.resetVisitStatus(roots);
            roots = rule.rewriteHopDAGs(roots, status);
        }
        return roots;
    }

    /**
     * Applies every HOP DAG rule, in order, to a single DAG root, resetting visit status
     * before each rule.
     *
     * @return the root returned by the last rule, or null if {@code root} is null
     */
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus status) {
        if (root == null) {
            return null;
        }
        for (HopRewriteRule rule : _dagRuleSet) {
            root.resetVisitStatus();
            root = rule.rewriteHopDAG(root, status);
        }
        return root;
    }

    /**
     * Rewrites a list of statement blocks in three steps: apply the list-level rules, recursively
     * rewrite each resulting block, then apply the list-level rules again to the expanded list.
     *
     * @param sbs       input list; it is cleared and refilled with the result
     * @param status    rewrite status; a fresh one is used if null
     * @param splitDags if false, rules whose {@code createsSplitDag()} is true are skipped
     * @return {@code sbs}, now holding the rewritten blocks
     */
    public ArrayList<StatementBlock> rRewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus status, boolean splitDags) {
        if (status == null) {
            status = new ProgramRewriteStatus();
        }
        List<StatementBlock> afterListRules = applyListRules(sbs, status, splitDags);

        List<StatementBlock> expanded = new ArrayList<>();
        for (StatementBlock sb : afterListRules) {
            expanded.addAll(rRewriteStatementBlock(sb, status, splitDags));
        }

        // A second pass sees blocks newly introduced by the per-block rewrites.
        expanded = applyListRules(expanded, status, splitDags);

        sbs.clear();
        sbs.addAll(expanded);
        return sbs;
    }

    /** Applies each enabled statement block rule's list-level rewrite, in rule order. */
    private List<StatementBlock> applyListRules(List<StatementBlock> sbs, ProgramRewriteStatus status, boolean splitDags) {
        List<StatementBlock> result = sbs;
        for (StatementBlockRewriteRule rule : _sbRuleSet) {
            if (splitDags || !rule.createsSplitDag()) {
                result = rule.rewriteStatementBlocks(result, status);
            }
        }
        return result;
    }

    /**
     * Recursively rewrites the nested bodies of {@code sb} (function, while, if/else, for/parfor),
     * then applies each enabled statement block rule to every block produced so far. While a
     * parfor body is rewritten the status is marked as "in parfor context"; the previous flag is
     * restored afterwards, even if a rule throws.
     *
     * @param sb        statement block to rewrite
     * @param status    rewrite status; a fresh one is used if null
     * @param splitDags if false, rules whose {@code createsSplitDag()} is true are skipped
     * @return the block(s) replacing {@code sb}; may contain zero, one or several blocks
     */
    public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status, boolean splitDags) {
        // Previously missing: a null status caused an NPE for for-blocks and was forwarded to rules.
        if (status == null) {
            status = new ProgramRewriteStatus();
        }
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
            fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
            wstmt.setBody(rRewriteStatementBlocks(wstmt.getBody(), status, splitDags));
        } else if (sb instanceof IfStatementBlock) {
            IfStatement istmt = (IfStatement) sb.getStatement(0);
            istmt.setIfBody(rRewriteStatementBlocks(istmt.getIfBody(), status, splitDags));
            istmt.setElseBody(rRewriteStatementBlocks(istmt.getElseBody(), status, splitDags));
        } else if (sb instanceof ForStatementBlock) {
            boolean wasInParforContext = status.isInParforContext();
            if (sb instanceof ParForStatementBlock) {
                status.setInParforContext(true);
            }
            try {
                ForStatement fstmt = (ForStatement) sb.getStatement(0);
                fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
            } finally {
                // Restore the enclosing context so an exception cannot leak the parfor flag.
                status.setInParforContext(wasInParforContext);
            }
        }

        ArrayList<StatementBlock> result = new ArrayList<>();
        result.add(sb);
        for (StatementBlockRewriteRule rule : _sbRuleSet) {
            if (splitDags || !rule.createsSplitDag()) {
                ArrayList<StatementBlock> next = new ArrayList<>();
                for (StatementBlock current : result) {
                    next.addAll(rule.rewriteStatementBlock(current, status));
                }
                result = next;
            }
        }
        return result;
    }
}
