package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteSplitDagUnknownCSVRead.class */
/**
 * Statement-block rewrite that moves persistent CSV reads with unknown
 * dimensions into a separate statement block placed before the original one.
 *
 * <p>For every qualifying read, the new block receives a transient write of
 * that read, and every parent of the read in the original block is rewired to
 * a transient read of the same variable name.
 */
public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule {
    /**
     * @return always {@code true}; this rule may split a hop DAG into two blocks
     */
    @Override
    public boolean createsSplitDag() {
        return true;
    }

    /**
     * Splits {@code statementBlock} when it contains persistent CSV reads whose
     * dimensions are unknown.
     *
     * @param statementBlock the block to inspect and, if needed, split
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return a list holding only {@code statementBlock} when no split happens,
     *         otherwise the newly created block followed by {@code statementBlock}
     * @throws HopsException if any step of the split fails; the original
     *         exception is attached as the cause
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        // No split for forced single-node execution or blocks that are not last-level.
        if (DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE || !HopRewriteUtils.isLastLevelStatementBlock(statementBlock)) {
            return Arrays.asList(statementBlock);
        }

        List<StatementBlock> result = new ArrayList<>();
        // Concrete ArrayList types are kept where values are handed to project
        // APIs (getHops/setHops/resetVisitStatus) whose signatures are not visible here.
        ArrayList<Hop> csvReads = new ArrayList<>();
        collectCSVReadHopsUnknownSize(statementBlock.getHops(), csvReads);
        if (csvReads.isEmpty()) {
            // Nothing to split: keep the original block as is.
            result.add(statementBlock);
            return result;
        }

        try {
            // New block that will hold the CSV reads; starts with empty live variable sets.
            StatementBlock readBlock = new StatementBlock();
            readBlock.setDMLProg(statementBlock.getDMLProg());
            readBlock.setParseInfo(statementBlock);
            readBlock.setLiveIn(new VariableSet());
            readBlock.setLiveOut(new VariableSet());

            ArrayList<Hop> readBlockRoots = new ArrayList<>();
            for (Hop read : csvReads) {
                // Give the read its own copies of literal inputs so that literal
                // hops are not shared between the two blocks. An index loop is used
                // deliberately: replaceChildReference is invoked on the very hop
                // whose input list is being walked.
                for (int i = 0; i < read.getInput().size(); i++) {
                    if (read.getInput().get(i) instanceof LiteralOp) {
                        HopRewriteUtils.replaceChildReference(read, read.getInput().get(i), new LiteralOp((LiteralOp) read.getInput().get(i)));
                    }
                }

                // Rewire all current parents of the read to a transient read.
                // The parent list is copied first because rewiring is expected
                // to modify the read's parent list.
                DataOp transientRead = HopRewriteUtils.createTransientRead(read.getName(), read);
                List<Hop> parents = new ArrayList<>(read.getParent());
                for (Hop parent : parents) {
                    HopRewriteUtils.replaceChildReference(parent, read, transientRead);
                }

                // The read itself becomes the input of a transient write rooted in the new block.
                readBlockRoots.add(HopRewriteUtils.createTransientWrite(read.getName(), read));

                // Propagate the variable across the block boundary: live-out of the
                // new block, live-in of the original block.
                DataIdentifier variable = statementBlock.variablesRead().getVariable(read.getName());
                if (variable != null) {
                    readBlock.liveOut().addVariable(read.getName(), new DataIdentifier(variable));
                    statementBlock.liveIn().addVariable(read.getName(), new DataIdentifier(variable));
                }
            }

            readBlock.setHops(readBlockRoots);
            readBlock.updateRecompilationFlag();
            result.add(readBlock);
            result.add(statementBlock);
            statementBlock.setSplitDag(true);
            LOG.debug("Applied splitDagUnknownCSVRead.");
        } catch (Exception e) {
            throw new HopsException("Failed to split hops dag for csv read with unknown size.", e);
        }
        return result;
    }

    /**
     * Returns the given list unchanged; this rule does not rewrite across blocks.
     *
     * @param list the statement blocks
     * @param programRewriteStatus rewrite status (not used)
     * @return {@code list} itself
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }

    /**
     * Resets the visit status of {@code roots} and collects qualifying CSV reads
     * reachable from them into {@code found}. Does nothing if {@code roots} is null.
     */
    private void collectCSVReadHopsUnknownSize(ArrayList<Hop> roots, ArrayList<Hop> found) {
        if (roots == null) {
            return;
        }
        Hop.resetVisitStatus(roots);
        for (Hop root : roots) {
            collectCSVReadHopsUnknownSize(root, found);
        }
    }

    /**
     * Recursively walks the inputs of {@code hop} and appends to {@code found}
     * every persistent-read {@link DataOp} with CSV input format and unknown
     * dimensions for which {@code HopRewriteUtils.hasOnlyWriteParents(op, true, false)}
     * is false. Hops already marked visited are skipped; each hop is marked
     * visited after its inputs have been processed.
     */
    private void collectCSVReadHopsUnknownSize(Hop hop, ArrayList<Hop> found) {
        if (hop.isVisited()) {
            return;
        }
        if (hop instanceof DataOp) {
            DataOp dataOp = (DataOp) hop;
            boolean persistentCsvRead = dataOp.getOp() == Types.OpOpData.PERSISTENTREAD
                    && dataOp.getInputFormatType() == Types.FileFormat.CSV;
            if (persistentCsvRead && !dataOp.dimsKnown() && !HopRewriteUtils.hasOnlyWriteParents(dataOp, true, false)) {
                found.add(dataOp);
            }
        }
        if (hop.getInput() != null) {
            for (Hop input : hop.getInput()) {
                collectCSVReadHopsUnknownSize(input, found);
            }
        }
        hop.setVisited();
    }
}
