package org.apache.sysds.lops.compile.linearization;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.ReBlock;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;

/* loaded from: input_file:org/apache/sysds/lops/compile/linearization/ILinearize.class */
public class ILinearize {
    /** Class-wide logger; {@code linearize} uses it to warn when it falls back to depth-first ordering. */
    public static Log LOG = LogFactory.getLog(ILinearize.class.getName());

    /* loaded from: input_file:org/apache/sysds/lops/compile/linearization/ILinearize$DagLinearization.class */
    /**
     * Operator ordering strategies that {@code ILinearize.linearize} dispatches on.
     */
    public enum DagLinearization {
        /** Handled by {@code depthFirst(List)}; also the default branch and the fallback on failure. */
        DEPTH_FIRST,
        /** Handled by {@code doBreadthFirstSort}, which sorts lops by their level. */
        BREADTH_FIRST,
        /** Handled by {@code doMinIntermediateSort}, which orders lops using output size estimates. */
        MIN_INTERMEDIATE,
        /** Handled by {@code doMaxParallelizeSort}, which marks Spark/GPU roots asynchronous and places them first. */
        MAX_PARALLELIZE,
        /** Delegated to {@code CostBasedLinearize.getBestOrder}. */
        AUTO
    }

    /**
     * Orders the given lops according to the linearization strategy returned by
     * {@code ConfigurationManager.getLinearizationOrder()}.
     *
     * <p>Any exception raised while reading the configuration or while running the
     * selected strategy is logged (including its cause) and the method falls back to
     * the depth-first ordering.
     *
     * @param list the lops to order
     * @return the lops in execution order
     */
    public static List<Lop> linearize(List<Lop> list) {
        // Read the configured order once so the catch block never has to query the
        // configuration again (a second call could throw out of the handler).
        DagLinearization order = null;
        try {
            order = ConfigurationManager.getLinearizationOrder();
            switch (order) {
                case MAX_PARALLELIZE:
                    return doMaxParallelizeSort(list);
                case AUTO:
                    return CostBasedLinearize.getBestOrder(list);
                case MIN_INTERMEDIATE:
                    return doMinIntermediateSort(list);
                case BREADTH_FIRST:
                    return doBreadthFirstSort(list);
                case DEPTH_FIRST:
                default:
                    return depthFirst(list);
            }
        } catch (Exception e) {
            // Keep the cause in the log; the failure may come from the strategy itself,
            // not only from an invalid configuration value.
            LOG.warn("Invalid DAG_LINEARIZATION " + order + ", fallback to DEPTH_FIRST ordering", e);
            return depthFirst(list);
        }
    }

    /**
     * Default ordering: lops that have at least one output come first, sorted by
     * ascending lop ID; lops without outputs follow in their original relative order.
     *
     * @param list the lops to order
     * @return a new list holding the same lops in the order described above
     */
    private static List<Lop> depthFirst(List<Lop> list) {
        List<Lop> ordered = new ArrayList<>(list.size());
        List<Lop> withoutOutputs = new ArrayList<>();
        for (Lop lop : list) {
            if (lop.getOutputs().isEmpty()) {
                withoutOutputs.add(lop);
            } else {
                ordered.add(lop);
            }
        }
        // List.sort is stable, so lops with equal IDs keep their encounter order.
        ordered.sort(Comparator.comparingLong(Lop::getID));
        ordered.addAll(withoutOutputs);
        return ordered;
    }

    /**
     * Orders the lops by ascending level; lops on the same level keep their
     * original relative order.
     *
     * @param list the lops to order
     * @return a new list sorted by level
     */
    private static List<Lop> doBreadthFirstSort(List<Lop> list) {
        List<Lop> byLevel = new ArrayList<>(list);
        // Stable sort: ties on level preserve input order.
        byLevel.sort(Comparator.comparingInt(Lop::getLevel));
        return byLevel;
    }

    /**
     * Orders the lops bottom-up using {@link #sortRecursive}, starting from the lops
     * without outputs, and finally reverses the collected order.
     *
     * <p>Lops that the first traversal did not reach are handled in additional passes,
     * each seeded with the remaining lops of the highest level.
     *
     * @param list the lops to order
     * @return a new list with the lops in execution order
     * @throws IllegalStateException if a pass schedules no lop, which would otherwise
     *         loop forever (e.g. when {@code list} contains the same lop twice)
     */
    private static List<Lop> doMinIntermediateSort(List<Lop> list) {
        List<Lop> ordered = new ArrayList<>(list.size());
        List<Lop> withoutOutputs = list.stream()
            .filter(lop -> lop.getOutputs().isEmpty())
            .collect(Collectors.toList());

        // sortRecursive removes every lop it schedules from this list.
        List<Lop> remaining = new LinkedList<>(list);
        sortRecursive(ordered, withoutOutputs, remaining);

        while (!remaining.isEmpty()) {
            int maxLevel = remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1);
            List<Lop> seeds = remaining.stream()
                .filter(lop -> lop.getLevel() == maxLevel)
                .collect(Collectors.toList());

            int before = remaining.size();
            sortRecursive(ordered, seeds, remaining);
            if (remaining.size() == before) {
                // No progress: bail out instead of spinning forever.
                throw new IllegalStateException("MIN_INTERMEDIATE linearization made no progress, "
                    + before + " lops left unscheduled");
            }
        }

        // Lops were collected consumers-first; flip to get producers-first order.
        Collections.reverse(ordered);
        return ordered;
    }

    /**
     * Schedules the candidate lops, largest estimated output first, and recurses into
     * the inputs of every lop it schedules.
     *
     * <p>A candidate is scheduled only if it is not yet in {@code list} and either all
     * of its outputs are already in {@code list} or none of its outputs is still in
     * {@code list3}.
     *
     * @param list  lops scheduled so far; scheduled candidates are appended
     * @param list2 candidates to consider (duplicates are ignored)
     * @param list3 lops not yet scheduled; scheduled candidates are removed
     */
    private static void sortRecursive(List<Lop> list, List<Lop> list2, List<Lop> list3) {
        // Pair every distinct candidate with its size estimate (0 for lops without outputs).
        List<Map.Entry<Lop, Long>> bySize = new ArrayList<>();
        for (Lop candidate : list2.stream().distinct().collect(Collectors.toList())) {
            long estimate = candidate.getOutputs().isEmpty()
                ? 0L
                : OptimizerUtils.estimateSizeExactSparsity(
                    candidate.getOutputParameters().getNumRows(),
                    candidate.getOutputParameters().getNumCols(),
                    candidate.getOutputParameters().getNnz());
            bySize.add(new AbstractMap.SimpleEntry<>(candidate, Long.valueOf(estimate)));
        }

        // Stable ascending sort followed by a reversal: largest first, and candidates
        // with equal estimates end up in reverse encounter order.
        bySize.sort(Map.Entry.comparingByValue());
        Collections.reverse(bySize);

        for (Map.Entry<Lop, Long> entry : bySize) {
            Lop candidate = entry.getKey();
            if (list.contains(candidate)) {
                continue;
            }
            boolean ready = list.containsAll(candidate.getOutputs())
                || list3.stream().noneMatch(pending -> candidate.getOutputs().contains(pending));
            if (!ready) {
                continue;
            }
            list.add(candidate);
            list3.remove(candidate);
            sortRecursive(list, candidate.getInputs(), list3);
        }
    }

    /**
     * Orders the lops so that asynchronous (Spark and/or GPU) roots are placed before
     * the rest of the DAG.
     *
     * <p>If the list contains neither a distributed nor a GPU operator, the default
     * {@link #depthFirst(List)} ordering is returned. Otherwise a Spark phase and/or a
     * GPU phase is run; the GPU phase works on the output of the Spark phase when both
     * apply. Each phase marks the roots reported by the matching
     * {@code OperatorOrderingUtils} collector as asynchronous.
     *
     * @param list the lops to order
     * @return the lops in execution order
     */
    private static List<Lop> doMaxParallelizeSort(List<Lop> list) {
        boolean hasDistributedOp = list.stream().anyMatch(ILinearize::isDistributedOp);
        boolean hasGpuOp = list.stream().anyMatch(ILinearize::isGPUOp);
        if (!hasDistributedOp && !hasGpuOp) {
            return depthFirst(list);
        }

        List<Lop> ordered = list;
        if (hasDistributedOp) {
            HashMap<Long, Integer> sparkOpCount = new HashMap<>();
            HashSet<Lop> sparkRoots = new HashSet<>();
            List<Lop> dagRoots = ordered.stream()
                .filter(OperatorOrderingUtils::isLopRoot)
                .collect(Collectors.toList());
            for (Lop root : dagRoots) {
                OperatorOrderingUtils.collectSparkRoots(root, sparkOpCount, sparkRoots);
            }
            ordered = placeAsyncRootsFirst(dagRoots, sparkRoots, sparkOpCount);
        }
        if (hasGpuOp) {
            HashMap<Long, Integer> gpuOpCount = new HashMap<>();
            HashSet<Lop> gpuRoots = new HashSet<>();
            List<Lop> dagRoots = ordered.stream()
                .filter(OperatorOrderingUtils::isLopRoot)
                .collect(Collectors.toList());
            for (Lop root : dagRoots) {
                OperatorOrderingUtils.collectGPURoots(root, gpuOpCount, gpuRoots);
            }
            ordered = placeAsyncRootsFirst(dagRoots, gpuRoots, gpuOpCount);
        }
        return ordered;
    }

    /**
     * Marks the given roots asynchronous, linearizes their sub-DAGs first, then
     * linearizes everything else reachable from the DAG roots, and finally resets the
     * visit status on the DAG roots.
     *
     * @param dagRoots   root lops of the DAG being ordered
     * @param asyncRoots roots to flag as asynchronous and to place first
     * @param opCount    per-lop-ID values used by {@link #depthFirst(Lop, ArrayList, Map, boolean)}
     *                   to order the inputs of each lop
     * @return the resulting operator order
     */
    private static List<Lop> placeAsyncRootsFirst(List<Lop> dagRoots, HashSet<Lop> asyncRoots,
            Map<Long, Integer> opCount) {
        for (Lop asyncRoot : asyncRoots) {
            asyncRoot.setAsynchronous(true);
        }
        ArrayList<Lop> ordered = new ArrayList<>();
        for (Lop asyncRoot : asyncRoots) {
            depthFirst(asyncRoot, ordered, opCount, false);
        }
        // Lops already visited above are skipped here, so this only adds the remainder.
        for (Lop root : dagRoots) {
            depthFirst(root, ordered, opCount, false);
        }
        for (Lop root : dagRoots) {
            root.resetVisitStatus();
        }
        return ordered;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Appends {@code lop} and, before it, all of its not-yet-visited transitive inputs
     * to {@code arrayList}, marking every appended lop as visited.
     *
     * <p>The inputs of each lop are traversed in the order of their value in
     * {@code map}: ascending when {@code z} is {@code false}, descending when it is
     * {@code true}. Inputs with equal values keep their original order.
     *
     * @param lop       the lop to linearize
     * @param arrayList output list that receives the lops in post-order
     * @param map       ordering value per lop ID; must contain the ID of every input
     *                  that is compared (a missing entry raises a NullPointerException)
     * @param z         {@code true} to visit inputs with larger values first
     */
    public static void depthFirst(Lop lop, ArrayList<Lop> arrayList, Map<Long, Integer> map, boolean z) {
        if (lop.isVisited()) {
            return;
        }
        if (!lop.getInputs().isEmpty()) {
            Lop[] inputs = lop.getInputs().toArray(new Lop[0]);
            // Integer.compare-based ordering; avoids the overflow-prone "b - a" idiom.
            Comparator<Lop> byValue = Comparator.comparingInt(in -> map.get(Long.valueOf(in.getID())));
            Arrays.sort(inputs, z ? byValue.reversed() : byValue);
            for (Lop input : inputs) {
                depthFirst(input, arrayList, map, z);
            }
        }
        arrayList.add(lop);
        lop.setVisited();
    }

    /**
     * Tells whether the lop executes on Spark or is a {@code UnaryCP} prefetch/broadcast.
     *
     * @param lop the lop to test
     * @return {@code true} if the lop counts as a distributed operator
     */
    private static boolean isDistributedOp(Lop lop) {
        return lop.isExecSpark() || isPrefetchOrBroadcast(lop);
    }

    /**
     * Tells whether the lop executes on the GPU or is a {@code UnaryCP} prefetch/broadcast.
     *
     * @param lop the lop to test
     * @return {@code true} if the lop counts as a GPU operator
     */
    private static boolean isGPUOp(Lop lop) {
        return lop.isExecGPU() || isPrefetchOrBroadcast(lop);
    }

    /**
     * Tells whether the lop is a {@code UnaryCP} whose opcode is "prefetch" or
     * "broadcast" (case-insensitive). A {@code null} opcode yields {@code false}.
     */
    private static boolean isPrefetchOrBroadcast(Lop lop) {
        if (!(lop instanceof UnaryCP)) {
            return false;
        }
        String opcode = ((UnaryCP) lop).getOpCode();
        // Literal on the left keeps the comparison null-safe.
        return "prefetch".equalsIgnoreCase(opcode) || "broadcast".equalsIgnoreCase(opcode);
    }

    /**
     * Returns a copy of the lop list in which every lop accepted by
     * {@link #isCheckpointNeeded(Lop)} has a new {@code Checkpoint} lop spliced in
     * between itself and each of its Spark inputs.
     *
     * <p>The rewiring mutates the lops: the input's edge to the lop is replaced by
     * input -> checkpoint -> lop. Each new checkpoint is emitted right before the lop
     * it feeds.
     *
     * @param list the lops in execution order
     * @return a new list containing the original lops plus the inserted checkpoints
     */
    private static List<Lop> addAsyncEagerCheckpointLop(List<Lop> list) {
        List<Lop> withCheckpoints = new ArrayList<>();
        for (Lop lop : list) {
            if (isCheckpointNeeded(lop)) {
                // Iterate over a snapshot: replaceInput changes the lop's input list.
                List<Lop> inputsSnapshot = new ArrayList<>(lop.getInputs());
                for (Lop input : inputsSnapshot) {
                    if (input.getExecType() != Types.ExecType.SPARK) {
                        continue;
                    }
                    Checkpoint checkpoint = new Checkpoint(input, input.getDataType(), input.getValueType(),
                        Checkpoint.getDefaultStorageLevelString(), true);
                    checkpoint.addOutput(lop);
                    lop.replaceInput(input, checkpoint);
                    input.removeOutput(lop);
                    withCheckpoints.add(checkpoint);
                }
            }
            withCheckpoints.add(lop);
        }
        return withCheckpoints;
    }

    /**
     * Decides whether a checkpoint should be placed in front of the given lop.
     *
     * <p>All of the following must hold: the lop runs on Spark; it is a single-block
     * aggregate, produces a scalar, or is one of MapMultChain, PickByCount, MMZip,
     * CentralMoment, CoVariance, MMTSJ; it is none of Checkpoint, ReBlock, CSVReBlock,
     * UAggOuterChain, ParameterizedBuiltin, SpoofFused; and none of its outputs is a
     * ParameterizedBuiltin, GroupedAggregate or GroupedAggregateM.
     *
     * @param lop the lop to test
     * @return {@code true} if the lop qualifies for a preceding checkpoint
     */
    private static boolean isCheckpointNeeded(Lop lop) {
        if (lop.getExecType() != Types.ExecType.SPARK) {
            return false;
        }

        boolean qualifies = lop.getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK
            || lop.getDataType() == Types.DataType.SCALAR
            || lop instanceof MapMultChain
            || lop instanceof PickByCount
            || lop instanceof MMZip
            || lop instanceof CentralMoment
            || lop instanceof CoVariance
            || lop instanceof MMTSJ;
        if (!qualifies) {
            return false;
        }

        boolean excluded = lop instanceof Checkpoint
            || lop instanceof ReBlock
            || lop instanceof CSVReBlock
            || lop instanceof UAggOuterChain
            || lop instanceof ParameterizedBuiltin
            || lop instanceof SpoofFused;
        if (excluded) {
            return false;
        }

        // Consumers of these kinds rule out a checkpoint.
        for (Lop consumer : lop.getOutputs()) {
            if (consumer instanceof ParameterizedBuiltin
                || consumer instanceof GroupedAggregate
                || consumer instanceof GroupedAggregateM) {
                return false;
            }
        }
        return true;
    }
}
