package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.ReBlock;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.StatementBlock;

/* loaded from: input_file:org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.class */
/**
 * Lop rewrite that inserts a {@code PREFETCH} {@link UnaryCP} lop after selected Spark or GPU
 * matrix lops and reroutes all of their consumers through it.
 *
 * <p>The inserted prefetch lop is flagged asynchronous, while the producing lop's asynchronous
 * flag is cleared. NOTE(review): presumably this lets the result be fetched into CP memory
 * while other work proceeds; confirm against the PREFETCH instruction implementation.
 */
public class RewriteAddPrefetchLop extends LopRewriteRule {
    /** Spark lop types whose outputs are never prefetched, regardless of their consumers. */
    private static final List<Class<? extends Lop>> SPARK_PREFETCH_EXCLUDED = List.of(
        MapMultChain.class, PickByCount.class, MMZip.class, CentralMoment.class,
        CoVariance.class, Checkpoint.class, ReBlock.class, CSVReBlock.class,
        DataGen.class, MMTSJ.class, UAggOuterChain.class, ParameterizedBuiltin.class,
        SpoofFused.class);

    /**
     * Inserts prefetch lops into the lop DAG of {@code statementBlock} when prefetching is enabled.
     *
     * @param statementBlock the statement block whose lops are inspected and rewired in place
     * @return a single-element list containing the same (possibly rewired) statement block
     */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock statementBlock) {
        if (!ConfigurationManager.isPrefetchEnabled()) {
            return List.of(statementBlock);
        }
        List<Lop> lops = OperatorOrderingUtils.getLopList(statementBlock);
        if (lops == null) {
            return List.of(statementBlock);
        }
        for (Lop lop : lops) {
            if (isPrefetchNeeded(lop)) {
                insertPrefetch(lop);
            }
        }
        // The prefetch lops are wired into the lop DAG in place; no new statement blocks are created.
        return Arrays.asList(statementBlock);
    }

    /**
     * This rule performs no cross-block rewrites.
     *
     * @param list the statement blocks
     * @return {@code list}, unchanged
     */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> list) {
        return list;
    }

    /**
     * Creates an asynchronous PREFETCH lop reading from {@code producer} and moves every current
     * consumer of {@code producer} onto it.
     */
    private static void insertPrefetch(Lop producer) {
        // Snapshot the consumers first: the loop below removes entries from the producer's output
        // list, and the UnaryCP constructor may also register the new lop as an output (TODO confirm).
        List<Lop> consumers = new ArrayList<>(producer.getOutputs());
        UnaryCP prefetch = new UnaryCP(producer, Types.OpOp1.PREFETCH,
            producer.getDataType(), producer.getValueType(), Types.ExecType.CP);
        prefetch.setAsynchronous(true);
        // Only the prefetch lop is asynchronous; clear any flag previously set on the producer.
        producer.setAsynchronous(false);
        for (Lop consumer : consumers) {
            // Rewire producer -> consumer into producer -> prefetch -> consumer.
            prefetch.addOutput(consumer);
            consumer.replaceInput(producer, prefetch);
            producer.removeOutput(consumer);
        }
    }

    /** Returns true if {@code lop} qualifies for a prefetch from either Spark or the GPU. */
    private static boolean isPrefetchNeeded(Lop lop) {
        return isPrefetchFromSparkNeeded(lop) || isPrefetchFromGPUNeeded(lop);
    }

    /**
     * Returns true for a Spark matrix lop that is not a single-block aggregate, is not one of the
     * excluded lop types, has no prefetch-blocking consumer, and either has only CP consumers or
     * is a collect for broadcast.
     */
    private static boolean isPrefetchFromSparkNeeded(Lop lop) {
        if (lop.getExecType() != Types.ExecType.SPARK
            || lop.getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK
            || lop.getDataType() != Types.DataType.MATRIX) {
            return false;
        }
        if (SPARK_PREFETCH_EXCLUDED.stream().anyMatch(type -> type.isInstance(lop))) {
            return false;
        }
        if (hasPrefetchBlockingConsumer(lop)) {
            return false;
        }
        return lop.isAllOutputsCP() || OperatorOrderingUtils.isCollectForBroadcast(lop);
    }

    /**
     * Returns true for a GPU matrix lop whose consumers are all CP and none of which blocks a
     * prefetch.
     */
    private static boolean isPrefetchFromGPUNeeded(Lop lop) {
        return lop.getDataType() == Types.DataType.MATRIX
            && lop.isExecGPU()
            && lop.isAllOutputsCP()
            && !hasPrefetchBlockingConsumer(lop);
    }

    /**
     * Returns true if any consumer of {@code lop} rules out inserting a prefetch: a
     * {@link ParameterizedBuiltin}, {@link GroupedAggregate} or {@link GroupedAggregateM} lop, or
     * any consumer whose data type is LIST.
     */
    private static boolean hasPrefetchBlockingConsumer(Lop lop) {
        // NOTE(review): presumably these consumers reference inputs in ways that replaceInput does
        // not rewire (e.g. named parameter maps); confirm before relaxing this check.
        return lop.getOutputs().stream().anyMatch(consumer ->
            consumer instanceof ParameterizedBuiltin
                || consumer instanceof GroupedAggregate
                || consumer instanceof GroupedAggregateM
                || consumer.getDataType() == Types.DataType.LIST);
    }
}
