package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Compression;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.runtime.matrix.data.Pair;

/**
 * Statement block rewrite that splits a hop DAG at data-dependent operators, i.e., operators whose
 * outputs (in particular their sizes) may only be determined at runtime. The following are
 * considered:
 * <ul>
 * <li>removeEmpty</li>
 * <li>ctable</li>
 * <li>sort with non-literal, non-data parameters</li>
 * <li>eval</li>
 * <li>SQL reads</li>
 * <li>grouped aggregates with an unknown number of groups</li>
 * <li>unique</li>
 * <li>compress operators, or operators that require compression</li>
 * </ul>
 *
 * <p>The candidate operators and everything they depend on are deep-copied into a new statement
 * block that is placed before the original block. Both blocks are connected through transient
 * write/read pairs on cut variables. The rewrite is then applied recursively to the new block.
 */
public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule {

    @Override
    public boolean createsSplitDag() {
        return true;
    }

    /**
     * Splits the given statement block if its hop DAG contains data-dependent operators.
     *
     * <p>No split is attempted in either of these cases:
     * <ul>
     * <li>in single-node execution mode with compression disabled ({@code CompressConfig.FALSE})</li>
     * <li>if the block is not a last-level statement block (per
     * {@code HopRewriteUtils.isLastLevelStatementBlock})</li>
     * </ul>
     *
     * @param statementBlock the block to rewrite; on a split its hops, live-in and variables-read
     *        sets are modified in place and it is marked as split DAG
     * @param programRewriteStatus rewrite status, forwarded to the recursive rewrite of the new block
     * @return the resulting statement blocks in execution order (the new block(s) first, then the
     *         original block), or a single-element list with the input block if nothing was split
     * @throws HopsException if splitting the DAG fails
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        // Locale.ROOT keeps the enum lookup independent of the JVM default locale.
        String compressText = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.COMPRESSED_LINALG);
        Compression.CompressConfig compress = Compression.CompressConfig.valueOf(compressText.toUpperCase(Locale.ROOT));
        boolean singleNodeWithoutCompression = DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE
            && compress == Compression.CompressConfig.FALSE;
        if (singleNodeWithoutCompression || !HopRewriteUtils.isLastLevelStatementBlock(statementBlock)) {
            return Arrays.asList(statementBlock);
        }

        List<StatementBlock> ret = new ArrayList<>();
        ArrayList<Hop> cand = new ArrayList<>();
        collectDataDependentOperators(statementBlock.getHops(), cand);
        Hop.resetVisitStatus(statementBlock.getHops());
        if (cand.isEmpty()) {
            ret.add(statementBlock);
            return ret;
        }

        // The candidates plus every operator below them. Parents in this set are not rewired to
        // read the cut variable.
        HashSet<Hop> candChilds = new HashSet<>();
        collectCandidateChildOperators(cand, candChilds);
        candChilds.addAll(cand);

        try {
            StatementBlock sb1 = new StatementBlock();
            sb1.setDMLProg(statementBlock.getDMLProg());
            sb1.setParseInfo(statementBlock);
            sb1.setLiveIn(new VariableSet());
            sb1.setLiveOut(new VariableSet());
            ArrayList<Hop> sb1hops = new ArrayList<>();

            for (Hop c : cand) {
                // An existing transient write of the candidate is reused as the cut if
                // HopRewriteUtils.rHasSimpleReadChain holds for c and that variable.
                // Otherwise a fresh cut variable is introduced.
                Hop twrite = getFirstTransientWriteParent(c);
                boolean reuseTWrite = twrite != null
                    && HopRewriteUtils.rHasSimpleReadChain(c, twrite.getName());
                String varName = reuseTWrite ? twrite.getName() : createCutVarName(false);
                DataIdentifier cutVar = createCutVarIdentifier(varName, c);

                DataOp tread = HopRewriteUtils.createTransientRead(varName, c);
                // Iterate over a snapshot, since rewiring may modify the parent list of c.
                for (Hop parent : new ArrayList<>(c.getParent())) {
                    if (candChilds.contains(parent)) {
                        continue;
                    }
                    if (reuseTWrite && parent == twrite) {
                        // The reused transient write moves into the new block.
                        statementBlock.getHops().remove(parent);
                    } else {
                        HopRewriteUtils.replaceChildReference(parent, c, tread);
                    }
                }
                sb1hops.add(reuseTWrite ? twrite : HopRewriteUtils.createTransientWrite(varName, c));

                sb1.liveOut().addVariable(varName, new DataIdentifier(cutVar));
                statementBlock.liveIn().addVariable(varName, new DataIdentifier(cutVar));
                statementBlock.variablesRead().addVariable(varName, new DataIdentifier(cutVar));
            }

            handleReplicatedOperators(sb1hops, statementBlock.getHops(), sb1.liveOut(), statementBlock.liveIn());
            sb1.setHops(Recompiler.deepCopyHopsDag(sb1hops));
            sb1.updateRecompilationFlag();
            sb1.setSplitDag(true);

            ret.addAll(rewriteStatementBlock(sb1, programRewriteStatus));
            ret.add(statementBlock);
            statementBlock.setSplitDag(true);
            LOG.debug("Applied splitDagDataDependentOperators (lines " + statementBlock.getBeginLine() + "-" + statementBlock.getEndLine() + ").");
        } catch (Exception e) {
            throw new HopsException("Failed to split hops dag for data dependent operators with unknown size.", e);
        }
        return ret;
    }

    /**
     * Creates a data identifier for a cut variable.
     *
     * <p>The identifier carries the dimensions, block size, data type and value type of the given hop.
     */
    private static DataIdentifier createCutVarIdentifier(String varName, Hop hop) {
        DataIdentifier id = new DataIdentifier(varName);
        id.setDimensions(hop.getDim1(), hop.getDim2());
        id.setBlocksize(hop.getBlocksize());
        id.setDataType(hop.getDataType());
        id.setValueType(hop.getValueType());
        return id;
    }

    /**
     * Collects all data-dependent operators reachable from the given DAG roots into {@code cand}.
     *
     * <p>Does nothing if {@code roots} is null.
     */
    private void collectDataDependentOperators(ArrayList<Hop> roots, ArrayList<Hop> cand) {
        if (roots == null) {
            return;
        }
        Hop.resetVisitStatus(roots);
        for (Hop root : roots) {
            rCollectDataDependentOperators(root, cand);
        }
    }

    /**
     * Recursively collects data-dependent operators into {@code cand}.
     *
     * <p>The traversal does not descend into the inputs of a hop that was collected. As a side
     * effect, removeEmpty and ctable candidates are configured based on their consumers:
     * <ul>
     * <li>whether they output empty blocks</li>
     * <li>for removeEmpty only, whether they produce a permutation matrix</li>
     * </ul>
     */
    private void rCollectDataDependentOperators(Hop hop, ArrayList<Hop> cand) {
        if (hop.isVisited()) {
            return;
        }

        // A split is unnecessary if any of the following holds:
        // - the hop only has write parents (per HopRewriteUtils.hasOnlyWriteParents)
        // - its dimensions are known
        // - execution is single-node
        boolean noSplitRequired = HopRewriteUtils.hasOnlyWriteParents(hop, true, false)
            || hop.dimsKnown()
            || DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE;
        boolean investigateChilds = true;

        // removeEmpty, unless consumed by a single ternary op for which the matrix-ignore-zero
        // rewrite is applicable.
        if (hop instanceof ParameterizedBuiltinOp
            && ((ParameterizedBuiltinOp) hop).getOp() == Types.ParamBuiltinOp.RMEMPTY
            && !noSplitRequired
            && !(hop.getParent().size() == 1
                && hop.getParent().get(0) instanceof TernaryOp
                && ((TernaryOp) hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable())) {
            ParameterizedBuiltinOp rmEmpty = (ParameterizedBuiltinOp) hop;
            cand.add(rmEmpty);
            investigateChilds = false;

            boolean allLeftMatMultOrNRow = true;
            boolean allLeftMatMult = true;
            boolean diagInput = rmEmpty.isTargetDiagInput();
            for (Hop parent : hop.getParent()) {
                boolean leftMatMult = parent instanceof AggBinaryOp && hop == parent.getInput().get(0);
                allLeftMatMultOrNRow &= leftMatMult || HopRewriteUtils.isUnary(parent, Types.OpOp1.NROW);
                allLeftMatMult &= leftMatMult;
            }
            rmEmpty.setOutputEmptyBlocks(!allLeftMatMultOrNRow);
            if (allLeftMatMult && diagInput) {
                if (ConfigurationManager.isDynamicRecompilation()) {
                    rmEmpty.setOutputPermutationMatrix(true);
                }
                // All parents are AggBinaryOps here (guaranteed by allLeftMatMult).
                for (Hop parent : hop.getParent()) {
                    ((AggBinaryOp) parent).setHasLeftPMInput(true);
                }
            }
        }

        // ctable with fewer than four inputs
        if (HopRewriteUtils.isTernary(hop, Types.OpOp3.CTABLE) && hop.getInput().size() < 4 && !noSplitRequired) {
            cand.add(hop);
            investigateChilds = false;

            boolean allLeftMatMult = true;
            for (Hop parent : hop.getParent()) {
                allLeftMatMult &= parent instanceof AggBinaryOp && hop == parent.getInput().get(0);
            }
            if (allLeftMatMult && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0))) {
                hop.setOutputEmptyBlocks(false);
            }
        }

        // sort: inputs 2 and 3 are candidates unless they are literals or data ops.
        if (HopRewriteUtils.isReorg(hop, Types.ReOrgOp.SORT)) {
            for (int i = 2; i <= 3; i++) {
                Hop input = hop.getInput().get(i);
                if (!(input instanceof LiteralOp) && !(input instanceof DataOp)) {
                    cand.add(input);
                    input.setVisited();
                    investigateChilds = false;
                }
            }
        }

        if (isBasicDataDependentOperator(hop, noSplitRequired)) {
            cand.add(hop);
            investigateChilds = false;
        }

        if (investigateChilds && hop.getInput() != null) {
            for (Hop input : hop.getInput()) {
                rCollectDataDependentOperators(input, cand);
            }
        }
        hop.setVisited();
    }

    /**
     * Indicates whether the hop is one of the simple data-dependent operators.
     *
     * <p>The following qualify only if {@code noSplitRequired} is false:
     * <ul>
     * <li>eval</li>
     * <li>SQL read</li>
     * <li>grouped aggregate with an unknown number of groups</li>
     * <li>unique</li>
     * </ul>
     *
     * <p>Compress operators, and hops that require compression, qualify unless they only have write
     * parents.
     */
    private static boolean isBasicDataDependentOperator(Hop hop, boolean noSplitRequired) {
        return (HopRewriteUtils.isNary(hop, Types.OpOpN.EVAL) && !noSplitRequired)
            || (HopRewriteUtils.isData(hop, Types.OpOpData.SQLREAD) && !noSplitRequired)
            || (HopRewriteUtils.isParameterBuiltinOp(hop, Types.ParamBuiltinOp.GROUPEDAGG)
                && !((ParameterizedBuiltinOp) hop).isKnownNGroups()
                && !noSplitRequired)
            || ((HopRewriteUtils.isUnary(hop, Types.OpOp1.COMPRESS) || hop.requiresCompression())
                && !HopRewriteUtils.hasOnlyWriteParents(hop, true, true))
            || (HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.UNIQUE) && !noSplitRequired);
    }

    /** Returns the first parent of the hop that is a transient write, or null if there is none. */
    private static Hop getFirstTransientWriteParent(Hop hop) {
        for (Hop parent : hop.getParent()) {
            if (parent instanceof DataOp && ((DataOp) parent).getOp() == Types.OpOpData.TRANSIENTWRITE) {
                return parent;
            }
        }
        return null;
    }

    /**
     * Cuts references from the second DAG into operators that are also part of the first DAG.
     *
     * <p>Each shared operator is written once to a fresh cut variable by a new transient write. That
     * write is appended to {@code rootsSB1}. All referencing parents in the second DAG are rewired to
     * a single transient read of that variable. The variable is registered in {@code sb1LiveOut} and
     * {@code sb2LiveIn}.
     */
    private void handleReplicatedOperators(ArrayList<Hop> rootsSB1, ArrayList<Hop> rootsSB2, VariableSet sb1LiveOut, VariableSet sb2LiveIn) {
        HashSet<Hop> probeSet = new HashSet<>();
        Hop.resetVisitStatus(rootsSB1);
        for (Hop root : rootsSB1) {
            rAddHopsToProbeSet(root, probeSet);
        }

        HashSet<Pair<Hop, Hop>> candSet = new HashSet<>();
        Hop.resetVisitStatus(rootsSB2);
        for (Hop root : rootsSB2) {
            rProbeAndAddHopsToCandidateSet(root, probeSet, candSet);
        }

        // One transient read per shared operator, keyed by its hop id.
        HashMap<Long, DataOp> treads = new HashMap<>();
        for (Pair<Hop, Hop> link : candSet) {
            Hop parent = link.getKey();
            Hop input = link.getValue();
            DataOp tread = treads.get(input.getHopID());
            if (tread == null) {
                String varName = createCutVarName(false);
                tread = HopRewriteUtils.createTransientRead(varName, input);
                treads.put(input.getHopID(), tread);
                DataOp twrite = HopRewriteUtils.createTransientWrite(varName, input);
                DataIdentifier cutVar = createCutVarIdentifier(varName, input);
                sb1LiveOut.addVariable(varName, new DataIdentifier(cutVar));
                sb2LiveIn.addVariable(varName, new DataIdentifier(cutVar));
                rootsSB1.add(twrite);
            }
            int pos = HopRewriteUtils.getChildReferencePos(parent, input);
            HopRewriteUtils.removeChildReferenceByPos(parent, input, pos);
            HopRewriteUtils.addChildReference(parent, tread, pos);
        }
    }

    /**
     * Adds all reachable hops to {@code probeSet}.
     *
     * <p>Literals and data ops that are not persistent reads/writes are excluded.
     */
    private void rAddHopsToProbeSet(Hop hop, HashSet<Hop> probeSet) {
        if (hop.isVisited()) {
            return;
        }
        boolean nonPersistentDataOp = hop instanceof DataOp && !((DataOp) hop).isPersistentReadWrite();
        if (!nonPersistentDataOp && !(hop instanceof LiteralOp)) {
            probeSet.add(hop);
        }
        if (hop.getInput() != null) {
            for (Hop input : hop.getInput()) {
                rAddHopsToProbeSet(input, probeSet);
            }
        }
        hop.setVisited();
    }

    /**
     * Records every (parent, input) edge whose input is contained in {@code probeSet}.
     *
     * <p>The traversal does not descend into inputs found in the probe set.
     */
    private void rProbeAndAddHopsToCandidateSet(Hop hop, HashSet<Hop> probeSet, HashSet<Pair<Hop, Hop>> candSet) {
        if (hop.isVisited()) {
            return;
        }
        if (hop.getInput() != null) {
            for (Hop input : hop.getInput()) {
                if (probeSet.contains(input)) {
                    candSet.add(new Pair<>(hop, input));
                } else {
                    rProbeAndAddHopsToCandidateSet(input, probeSet, candSet);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Collects all hops below the given operators into {@code candChilds}.
     *
     * <p>The operators themselves are only added if they lie below another operator. Does nothing
     * if {@code operators} is null.
     */
    private void collectCandidateChildOperators(ArrayList<Hop> operators, HashSet<Hop> candChilds) {
        if (operators == null) {
            return;
        }
        Hop.resetVisitStatus(operators);
        // Hash-based membership instead of a linear ArrayList.contains per visited hop.
        HashSet<Hop> operatorSet = new HashSet<>(operators);
        for (Hop hop : operators) {
            rCollectCandidateChildOperators(hop, operatorSet, candChilds, false);
        }
        Hop.resetVisitStatus(operators);
    }

    /**
     * Recursively adds hops to {@code candChilds} once the traversal has passed one of the given
     * operators.
     */
    private void rCollectCandidateChildOperators(Hop hop, HashSet<Hop> operators, HashSet<Hop> candChilds, boolean belowOperator) {
        if (hop.isVisited()) {
            return;
        }
        if (belowOperator) {
            candChilds.add(hop);
        }
        boolean inputsBelowOperator = belowOperator || operators.contains(hop);
        if (hop.getInput() != null) {
            for (Hop input : hop.getInput()) {
                rCollectCandidateChildOperators(input, operators, candChilds, inputsBelowOperator);
            }
        }
        hop.setVisited();
    }

    /** Multi-block rewrite is not used by this rule; returns the input list unchanged. */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }
}
