package org.apache.sysds.lops.compile.linearization;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;

/* loaded from: input_file:org/apache/sysds/lops/compile/linearization/ILinearize.class */
public interface ILinearize {
    public static final Log LOG = LogFactory.getLog(ILinearize.class.getName());

    /* loaded from: input_file:org/apache/sysds/lops/compile/linearization/ILinearize$DagLinearization.class */
    public enum DagLinearization {
        DEPTH_FIRST,
        BREADTH_FIRST,
        MIN_INTERMEDIATE
    }

    static List<Lop> linearize(List<Lop> list) {
        try {
            switch (DagLinearization.valueOf(ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase())) {
                case MIN_INTERMEDIATE:
                    return doMinIntermediateSort(list);
                case BREADTH_FIRST:
                    return doBreadthFirstSort(list);
                case DEPTH_FIRST:
                default:
                    return depthFirst(list);
            }
        } catch (Exception e) {
            LOG.warn("Invalid or failed DAG_LINEARIZATION, fallback to DEPTH_FIRST ordering");
            return depthFirst(list);
        }
    }

    private static List<Lop> depthFirst(List<Lop> list) {
        return (List) Stream.concat(list.stream().filter(lop -> {
            return !lop.getOutputs().isEmpty();
        }).sorted(Comparator.comparing(lop2 -> {
            return Long.valueOf(lop2.getID());
        })), list.stream().filter(lop3 -> {
            return lop3.getOutputs().isEmpty();
        })).collect(Collectors.toList());
    }

    private static List<Lop> doBreadthFirstSort(List<Lop> list) {
        return (List) list.stream().sorted(Comparator.comparing((v0) -> {
            return v0.getLevel();
        })).collect(Collectors.toList());
    }

    private static List<Lop> doMinIntermediateSort(List<Lop> list) {
        ArrayList arrayList = new ArrayList(list.size());
        List list2 = (List) list.stream().filter(lop -> {
            return lop.getOutputs().isEmpty();
        }).collect(Collectors.toList());
        LinkedList linkedList = new LinkedList(list);
        sortRecursive(arrayList, list2, linkedList);
        while (!linkedList.isEmpty()) {
            int orElse = linkedList.stream().mapToInt((v0) -> {
                return v0.getLevel();
            }).max().orElse(-1);
            sortRecursive(arrayList, (List) linkedList.stream().filter(lop2 -> {
                return lop2.getLevel() == orElse;
            }).collect(Collectors.toList()), linkedList);
        }
        Collections.reverse(arrayList);
        return arrayList;
    }

    private static void sortRecursive(List<Lop> list, List<Lop> list2, List<Lop> list3) {
        List<Map.Entry> list4 = (List) list2.stream().distinct().map(lop -> {
            return new AbstractMap.SimpleEntry(lop, Long.valueOf(lop.getOutputs().isEmpty() ? 0L : OptimizerUtils.estimateSizeExactSparsity(lop.getOutputParameters().getNumRows(), lop.getOutputParameters().getNumCols(), lop.getOutputParameters().getNnz())));
        }).sorted(Comparator.comparing(simpleEntry -> {
            return (Long) simpleEntry.getValue();
        })).collect(Collectors.toList());
        Collections.reverse(list4);
        for (Map.Entry entry : list4) {
            if (!list.contains(entry.getKey()) && (list.containsAll(((Lop) entry.getKey()).getOutputs()) || !list3.stream().anyMatch(lop2 -> {
                return ((Lop) entry.getKey()).getOutputs().contains(lop2);
            }))) {
                list.add((Lop) entry.getKey());
                list3.remove(entry.getKey());
                sortRecursive(list, ((Lop) entry.getKey()).getInputs(), list3);
            }
        }
    }
}
