package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Applies a fixed, ordered sequence of {@link LopRewriteRule}s to the lop DAGs of a
 * {@link DMLProgram}.
 * <p>
 * Every rule is applied in two ways:
 * <ol>
 *   <li>to a whole list of sibling statement blocks via
 *       {@code rewriteLOPinStatementBlocks} (see {@link #rRewriteLops}), and</li>
 *   <li>to each individual statement block via {@code rewriteLOPinStatementBlock}
 *       (see {@link #rRewriteLop}).</li>
 * </ol>
 * Control-flow blocks (function, while, if/else, for) are processed recursively. Their
 * nested bodies are rewritten before the per-block rules run on the enclosing block.
 */
public class LopRewriter {
    /** Rewrite rules in application order; each rule sees the output of the previous one. */
    private final List<LopRewriteRule> _lopSBRuleSet = new ArrayList<>();

    public LopRewriter() {
        // Registration order is application order.
        _lopSBRuleSet.add(new RewriteUpdateGPUPlacements());
        _lopSBRuleSet.add(new RewriteAddPrefetchLop());
        _lopSBRuleSet.add(new RewriteAddBroadcastLop());
        _lopSBRuleSet.add(new RewriteAddChkpointLop());
        _lopSBRuleSet.add(new RewriteAddChkpointInLoop());
        _lopSBRuleSet.add(new RewriteAddGPUEvictLop());
        // NOTE(review): presumably last so it can renumber lops added by the rules above — confirm.
        _lopSBRuleSet.add(new RewriteFixIDs());
    }

    /**
     * Rewrites the lop DAGs of the whole program.
     * <p>
     * First, every function statement block in every namespace is rewritten. Then, if any
     * rules are registered, the program's top-level statement blocks are rewritten and
     * stored back via {@code setStatementBlocks}.
     *
     * @param dMLProgram the program whose lop DAGs are rewritten in place
     */
    public void rewriteProgramLopDAGs(DMLProgram dMLProgram) {
        for (String namespace : dMLProgram.getNamespaces().keySet()) {
            for (String fname : dMLProgram.getFunctionStatementBlocks(namespace).keySet()) {
                rewriteLopDAGsFunction(dMLProgram.getFunctionStatementBlock(namespace, fname));
            }
        }
        if (_lopSBRuleSet.isEmpty()) {
            return;
        }
        dMLProgram.setStatementBlocks(rRewriteLops(dMLProgram.getStatementBlocks()));
    }

    /**
     * Rewrites the lop DAGs inside a single function.
     * <p>
     * The function body is rewritten and written back through {@code setBody}. The list
     * returned by {@link #rRewriteLop} for the function block itself is ignored here. Any
     * replacement blocks a per-block rule produces for the function block are therefore
     * not propagated.
     *
     * @param functionStatementBlock the function block to rewrite; no-op when no rules are registered
     */
    public void rewriteLopDAGsFunction(FunctionStatementBlock functionStatementBlock) {
        if (_lopSBRuleSet.isEmpty()) {
            return;
        }
        rRewriteLop(functionStatementBlock);
    }

    /**
     * Installs {@code arrayList} as the lops of {@code statementBlock}, runs all rules on that
     * block (recursing into nested bodies), and returns the lops of the first resulting block.
     *
     * @param statementBlock the block to attach the lops to
     * @param arrayList      the lops to install before rewriting
     * @return the lops of the first statement block produced by the rewrite
     * @throws IndexOutOfBoundsException if the rules reduce the block to an empty list
     */
    public ArrayList<Lop> rewriteLopDAG(StatementBlock statementBlock, ArrayList<Lop> arrayList) {
        statementBlock.setLops(arrayList);
        return rRewriteLop(statementBlock).get(0).getLops();
    }

    /**
     * Rewrites a list of sibling statement blocks.
     * <p>
     * Every rule is first applied to the list as a whole. Then each resulting block is
     * rewritten individually via {@link #rRewriteLop}, which may expand one block into
     * several.
     * <p>
     * The caller's list instance is reused: it is cleared, refilled with the rewritten
     * blocks, and returned.
     *
     * @param arrayList the sibling statement blocks; mutated in place
     * @return {@code arrayList}, now holding the rewritten blocks
     */
    public ArrayList<StatementBlock> rRewriteLops(ArrayList<StatementBlock> arrayList) {
        // Phase 1: multi-block rules, chained so each sees the previous rule's output.
        List<StatementBlock> rewritten = arrayList;
        for (LopRewriteRule rule : _lopSBRuleSet) {
            rewritten = rule.rewriteLOPinStatementBlocks(rewritten);
        }

        // Phase 2: per-block rules (and recursion into nested bodies).
        // Collect into a separate list first: `rewritten` may alias `arrayList`,
        // which is cleared below.
        List<StatementBlock> expanded = new ArrayList<>();
        for (StatementBlock sb : rewritten) {
            expanded.addAll(rRewriteLop(sb));
        }

        arrayList.clear();
        arrayList.addAll(expanded);
        return arrayList;
    }

    /**
     * Rewrites a single statement block.
     * <p>
     * For control-flow blocks, the nested bodies are rewritten first via {@link #rRewriteLops}
     * and stored back on the block's statement. This covers function bodies, while bodies,
     * if and else bodies, and for bodies. Afterwards, every rule's
     * {@code rewriteLOPinStatementBlock} is applied in order. A rule may turn one block into
     * several, and later rules run on all of them.
     *
     * @param statementBlock the block to rewrite
     * @return the block(s) that replace {@code statementBlock}, in order
     */
    public ArrayList<StatementBlock> rRewriteLop(StatementBlock statementBlock) {
        // Bottom-up: rewrite nested bodies before the enclosing block.
        if (statementBlock instanceof FunctionStatementBlock) {
            FunctionStatement fstmt =
                (FunctionStatement) ((FunctionStatementBlock) statementBlock).getStatement(0);
            fstmt.setBody(rRewriteLops(fstmt.getBody()));
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatement wstmt =
                (WhileStatement) ((WhileStatementBlock) statementBlock).getStatement(0);
            wstmt.setBody(rRewriteLops(wstmt.getBody()));
        } else if (statementBlock instanceof IfStatementBlock) {
            IfStatement istmt =
                (IfStatement) ((IfStatementBlock) statementBlock).getStatement(0);
            istmt.setIfBody(rRewriteLops(istmt.getIfBody()));
            istmt.setElseBody(rRewriteLops(istmt.getElseBody()));
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatement fstmt =
                (ForStatement) ((ForStatementBlock) statementBlock).getStatement(0);
            fstmt.setBody(rRewriteLops(fstmt.getBody()));
        }

        // Apply each rule to every block produced so far; a rule may expand one block into many.
        ArrayList<StatementBlock> result = new ArrayList<>();
        result.add(statementBlock);
        for (LopRewriteRule rule : _lopSBRuleSet) {
            ArrayList<StatementBlock> next = new ArrayList<>();
            for (StatementBlock sb : result) {
                next.addAll(rule.rewriteLOPinStatementBlock(sb));
            }
            result = next;
        }
        return result;
    }
}
