package org.apache.sysds.lops.compile.linearization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Stack;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DataExpression;

/* loaded from: input_file:org/apache/sysds/lops/compile/linearization/CostBasedLinearize.class */
public class CostBasedLinearize {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/lops/compile/linearization/CostBasedLinearize$Order.class */
    /**
     * A (partial) linearization of lops together with running cost estimates.
     *
     * <p>Every instance owns its own operator list: both the list constructor and
     * the copy constructor take a snapshot, so appending to one order never
     * changes another order it was derived from.
     */
    public static class Order {
        /** Operators in linearized order; owned exclusively by this instance. */
        private final List<Lop> _order;
        /** Memory estimate attributed to the most recently added operator. */
        private double _pinnedMemEstimate;
        /** Running sum of output memory estimates, compared against the buffer pool limit. */
        private double _bufferpoolEstimate;
        /** Number of additions after which the buffer pool estimate exceeded the limit. */
        private int _numEvictions;
        /** Sum of the compute estimates of all operators in this order. */
        private double _computeCost;

        /**
         * Creates an order from an initial operator sequence and initial estimates.
         *
         * @param list operators to start with; the list is copied, not retained
         * @param d    initial pinned memory estimate
         * @param d2   initial buffer pool estimate
         * @param d3   initial compute cost
         */
        public Order(List<Lop> list, double d, double d2, double d3) {
            this._order = new ArrayList<>(list);
            this._pinnedMemEstimate = d;
            this._bufferpoolEstimate = d2;
            this._numEvictions = 0;
            this._computeCost = d3;
        }

        /**
         * Creates an order that contains only the given operator.
         *
         * @param lop the single operator of the new order
         */
        public Order(Lop lop) {
            // NOTE(review): DEFAULT_DELIM_FILL_VALUE serves as the initial buffer pool
            // estimate here; it looks like a decompiler substitution for a plain 0.0
            // literal -- confirm against the original source.
            this(Arrays.asList(lop), lop.getOutputMemoryEstimate(), DataExpression.DEFAULT_DELIM_FILL_VALUE, lop.getComputeEstimate());
        }

        /**
         * Copy constructor.
         *
         * <p>The operator list is duplicated rather than shared. Sharing it would make
         * every {@link #addOperator(Lop, boolean)} call on the copy also modify the
         * source order (and all other copies of it).
         *
         * @param order the order to copy
         */
        public Order(Order order) {
            this._order = new ArrayList<>(order._order);
            this._pinnedMemEstimate = order._pinnedMemEstimate;
            this._bufferpoolEstimate = order._bufferpoolEstimate;
            this._numEvictions = order._numEvictions;
            this._computeCost = order._computeCost;
        }

        /**
         * Appends an operator and updates the running estimates.
         *
         * @param lop the operator to append
         * @param z   if {@code true}, the output memory estimates of {@code lop}'s
         *            inputs are subtracted from the buffer pool estimate (clamped at
         *            the floor value) after {@code lop}'s own output has been added
         */
        public void addOperator(Lop lop, boolean z) {
            this._order.add(lop);
            this._computeCost += lop.getComputeEstimate();
            this._bufferpoolEstimate += lop.getOutputMemoryEstimate();
            if (z) {
                for (Lop input : lop.getInputs()) {
                    this._bufferpoolEstimate -= input.getOutputMemoryEstimate();
                }
                // Never let the estimate drop below the floor value.
                if (this._bufferpoolEstimate < DataExpression.DEFAULT_DELIM_FILL_VALUE) {
                    this._bufferpoolEstimate = DataExpression.DEFAULT_DELIM_FILL_VALUE;
                }
            }
            if (this._bufferpoolEstimate > OptimizerUtils.getBufferPoolLimit()) {
                this._numEvictions++;
            }
            this._pinnedMemEstimate = lop.getTotalMemoryEstimate();
        }

        /**
         * Returns the operator list of this order.
         *
         * @return the live internal list (not a copy); changes made through it are
         *         visible to this order
         */
        protected List<Lop> getOrder() {
            return this._order;
        }

        /**
         * @return the accumulated compute cost of all operators in this order
         */
        protected double getComputeCost() {
            return this._computeCost;
        }

        /**
         * @param lop the operator to look for
         * @return {@code true} if the operator is already part of this order
         */
        protected boolean contains(Lop lop) {
            return this._order.contains(lop);
        }

        /**
         * @return the number of operators in this order
         */
        protected int size() {
            return this._order.size();
        }
    }

    /**
     * Computes a linearization of the given lops.
     *
     * <p>Steps: (1) {@code simplifyDag} removes certain source and sink nodes from
     * {@code list} and cuts their edges, (2) candidate orders are enumerated
     * starting from every remaining node without inputs, (3) one candidate is
     * chosen, (4) {@code addRemovedNodes} restores the cut edges and puts the
     * removed nodes back at the front and the end of the result.
     *
     * <p>NOTE(review): despite the method name, the candidate is currently picked
     * at random via {@code Math.random()}; the cost estimates collected in
     * {@code Order} are not consulted here.
     *
     * @param list the lops to linearize; must be modifiable, as nodes removed
     *             during simplification are deleted from it
     * @return the linearized lops, including the nodes removed during simplification
     */
    public static List<Lop> getBestOrder(List<Lop> list) {
        List<Lop> removedSources = new ArrayList<>();
        List<Lop> removedSinks = new ArrayList<>();
        HashMap<Long, ArrayList<Lop>> savedInputs = new HashMap<>();
        HashMap<Long, ArrayList<Lop>> savedOutputs = new HashMap<>();
        simplifyDag(list, removedSources, removedSinks, savedInputs, savedOutputs);

        // Nodes without inputs (after simplification) are the enumeration start points.
        List<Lop> sources = list.stream()
            .filter(lop -> lop.getInputs().isEmpty())
            .collect(Collectors.toList());

        List<Order> candidates = new ArrayList<>();
        for (Lop source : sources) {
            generateOrders(source, sources, candidates, list.size());
        }

        List<Lop> chosen;
        if (candidates.isEmpty()) {
            // No candidate was produced, e.g. because every node was removed by
            // simplifyDag. Indexing into the empty candidate list would throw and
            // leave the DAG with its edges cut, so fall back to whatever nodes
            // remain (in their given order) and still restore the DAG below.
            chosen = new ArrayList<>(list);
        } else {
            chosen = candidates.get((int) (Math.random() * candidates.size())).getOrder();
        }
        addRemovedNodes(chosen, removedSources, removedSinks, savedInputs, savedOutputs);
        return chosen;
    }

    /**
     * Enumerates orders that start with {@code lop} and appends every order that
     * reaches {@code i} operators to {@code list2}.
     *
     * <p>Partial orders are kept on a stack. A popped partial order is extended
     * in two ways: by each output of an already placed operator whose inputs are
     * all placed, and by each node of {@code list} that is not placed yet.
     *
     * @param lop   the first operator of every generated order
     * @param list  nodes that may be appended at any time (passed to
     *              {@code copyAndAdd} with {@code false})
     * @param list2 receives the orders that reached the target size
     * @param i     number of operators an order must contain to be collected
     */
    private static void generateOrders(Lop lop, List<Lop> list, List<Order> list2, int i) {
        Stack<Order> pending = new Stack<>();
        pending.push(new Order(lop));
        while (!pending.isEmpty()) {
            Order current = pending.pop();
            if (current.size() == i) {
                list2.add(current);
                continue;
            }

            // Collect outputs that can be placed next; the visited flag keeps an
            // output reachable from several placed operators from being added twice.
            List<Lop> ready = new ArrayList<>();
            for (Lop placed : current.getOrder()) {
                for (Lop consumer : placed.getOutputs()) {
                    boolean eligible = !consumer.isVisited()
                        && allInputsLinearized(consumer, current)
                        && !current.contains(consumer);
                    if (eligible) {
                        consumer.setVisited();
                        ready.add(consumer);
                    }
                }
            }

            // Clear the temporary flag again while branching on each ready output.
            for (Lop candidate : ready) {
                candidate.resetVisitStatus();
                pending.push(copyAndAdd(current, candidate, true));
            }

            // Additionally branch on every node of 'list' that is not placed yet.
            for (Lop source : list) {
                if (!current.contains(source)) {
                    pending.push(copyAndAdd(current, source, false));
                }
            }
        }
    }

    /**
     * Checks whether every input of {@code lop} is already part of {@code order}.
     *
     * @param lop   the operator whose inputs are checked
     * @param order the order to look the inputs up in
     * @return {@code true} if all inputs are contained (also when there are no
     *         inputs), {@code false} as soon as one input is missing
     */
    private static boolean allInputsLinearized(Lop lop, Order order) {
        return order.getOrder().containsAll(lop.getInputs());
    }

    /**
     * Builds a new order from {@code order} via the copy constructor and appends
     * {@code lop} to it.
     *
     * @param order the order to start from
     * @param lop   the operator to append
     * @param z     forwarded unchanged to {@code Order.addOperator}
     * @return the newly created order with {@code lop} appended
     */
    private static Order copyAndAdd(Order order, Lop lop, boolean z) {
        Order extended = new Order(order);
        extended.addOperator(lop, z);
        return extended;
    }

    /**
     * Removes selected source and sink nodes from {@code list} and cuts the edges
     * that connect them to the rest of the DAG.
     *
     * <p>A node without inputs is removed if it is a transient-read {@code Data}
     * lop or has scalar data type. A node without outputs is removed if it is a
     * transient-write {@code Data} lop. Before an edge is cut, the neighbour's
     * original input (or output) list is saved once under the neighbour's id.
     *
     * @param list     all lops; the removed nodes are deleted from this list
     * @param list2    receives the removed input-less nodes
     * @param list3    receives the removed transient-write nodes
     * @param hashMap  lop id to original input list, for consumers of removed sources
     * @param hashMap2 lop id to original output list, for producers of removed sinks
     */
    private static void simplifyDag(List<Lop> list, List<Lop> list2, List<Lop> list3, HashMap<Long, ArrayList<Lop>> hashMap, HashMap<Long, ArrayList<Lop>> hashMap2) {
        for (Lop node : list) {
            if (node.getInputs().isEmpty()) {
                boolean transientRead = (node instanceof Data) && ((Data) node).isTransientRead();
                if (transientRead || node.getDataType() == Types.DataType.SCALAR) {
                    list2.add(node);
                    for (Lop consumer : node.getOutputs()) {
                        // Save the full input list only the first time this consumer is seen.
                        hashMap.putIfAbsent(Long.valueOf(consumer.getID()), new ArrayList<>(consumer.getInputs()));
                        consumer.removeInput(node);
                    }
                }
            }

            if (node.getOutputs().isEmpty()) {
                if ((node instanceof Data) && ((Data) node).isTransientWrite()) {
                    list3.add(node);
                    for (Lop producer : node.getInputs()) {
                        // Save the full output list only the first time this producer is seen.
                        hashMap2.putIfAbsent(Long.valueOf(producer.getID()), new ArrayList<>(producer.getOutputs()));
                        producer.removeOutput(node);
                    }
                }
            }
        }
        list.removeAll(list2);
        list.removeAll(list3);
    }

    /**
     * Restores the saved input and output lists of the neighbours of removed
     * nodes and inserts the removed nodes back into the linearized list.
     *
     * @param list     the linearized lops; {@code list2} is inserted at the front
     *                 and {@code list3} is appended at the end
     * @param list2    the removed input-less nodes
     * @param list3    the removed transient-write nodes
     * @param hashMap  lop id to saved input list, applied to every output of a
     *                 node in {@code list2}
     * @param hashMap2 lop id to saved output list, applied to every input of a
     *                 node in {@code list3}
     */
    private static void addRemovedNodes(List<Lop> list, List<Lop> list2, List<Lop> list3, HashMap<Long, ArrayList<Lop>> hashMap, HashMap<Long, ArrayList<Lop>> hashMap2) {
        // Give consumers of the removed sources their saved input lists back.
        for (Lop source : list2) {
            for (Lop consumer : source.getOutputs()) {
                consumer.replaceAllInputs(hashMap.get(Long.valueOf(consumer.getID())));
            }
        }
        list.addAll(0, list2);

        // Give producers of the removed sinks their saved output lists back.
        for (Lop sink : list3) {
            for (Lop producer : sink.getInputs()) {
                producer.replaceAllOutputs(hashMap2.get(Long.valueOf(producer.getID())));
            }
        }
        list.addAll(list3);
    }
}
