package org.apache.sysds.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.codegen.opt.PlanSelection;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/codegen/opt/ReachabilityGraph.class */
/**
 * Reachability graph over the materialization points ("interesting points") of a
 * {@link PlanPartition}.
 *
 * <p>Every materialization point becomes a {@link NodeLink}. A synthetic root node (without an
 * interesting point) is linked to the materialization points first reached when walking down
 * from the partition roots. Each materialization point is in turn linked to the materialization
 * points first reached below it.
 *
 * <p>Groups of nodes with identical inputs are evaluated as candidate cut sets. A candidate is
 * accepted when two node sets are both non-empty and disjoint: the nodes reachable from the root
 * without crossing the cut ("left"), and the nodes reachable below the cut ("right"). Accepted
 * cut sets allow a boolean plan vector over the linearized search space to be split into a left
 * and a right {@link SubProblem}.
 */
public class ReachabilityGraph {
    /** Graph node per materialization point, keyed by (from hop ID, to hop ID). */
    private final HashMap<Pair<Long, Long>, NodeLink> _matPoints;
    /** Synthetic root node (no interesting point) linked to the points reached from the partition roots. */
    private final NodeLink _root;
    /** Linearized search space: cut-set points first, then all remaining points by descending size. */
    private InterestingPoint[] _searchSpace;
    /** Accepted cut sets, ordered by ascending estimated cost. */
    private CutSet[] _cutSets;

    /**
     * A cut set together with the points on its two sides.
     *
     * <p>{@code left} and {@code right} both start with the cut points, followed by the points
     * of the respective side. The {@code pos*} arrays hold indexes into the linearized search
     * space; they are filled by {@link #updatePositions(HashMap)}.
     */
    public static class CutSet {
        private final InterestingPoint[] cut;
        // cut points, then the points reachable from the root without crossing the cut
        private final InterestingPoint[] left;
        // cut points, then the points reachable below the cut
        private final InterestingPoint[] right;
        private int[] posCut;
        private int[] posLeft;
        private int[] posRight;

        private CutSet(InterestingPoint[] cutPoints, InterestingPoint[] leftPoints, InterestingPoint[] rightPoints) {
            this.cut = cutPoints;
            this.left = ArrayUtils.addAll(cutPoints, leftPoints);
            this.right = ArrayUtils.addAll(cutPoints, rightPoints);
        }

        /**
         * Resolves the search-space positions of the cut points and of the side points that
         * follow them in {@code left} and {@code right}.
         *
         * @param pos map from interesting point to its index in the linearized search space;
         *            must contain every point of this cut set (a missing point fails on unboxing)
         */
        private void updatePositions(HashMap<InterestingPoint, Integer> pos) {
            int ncut = cut.length;
            posCut = new int[ncut];
            for (int i = 0; i < ncut; i++) {
                posCut[i] = pos.get(cut[i]);
            }
            posLeft = new int[left.length - ncut];
            for (int i = 0; i < posLeft.length; i++) {
                posLeft[i] = pos.get(left[ncut + i]);
            }
            posRight = new int[right.length - ncut];
            for (int i = 0; i < posRight.length; i++) {
                posRight[i] = pos.get(right[ncut + i]);
            }
        }

        @Override
        public String toString() {
            return "Cut : " + Arrays.toString(cut);
        }
    }

    /**
     * Node of the reachability graph.
     *
     * <p>Each node gets a unique, sequentially assigned ID. {@link #equals(Object)} and
     * {@link #compareTo(NodeLink)} only consider the IDs of the inputs, so nodes with identical
     * inputs compare as equal.
     */
    public static class NodeLink implements Comparable<NodeLink> {
        private static final IDSequence _seqID = new IDSequence();
        private final ArrayList<NodeLink> _inputs = new ArrayList<>();
        private final long _ID = _seqID.getNextID();
        // null for the synthetic root node
        private final InterestingPoint _p;

        private NodeLink(InterestingPoint p) {
            _p = p;
        }

        private void addInput(NodeLink in) {
            _inputs.add(in);
        }

        /**
         * Hash over the inputs, the node ID and the interesting point.
         *
         * <p>NOTE(review): unlike {@link #equals(Object)}, this hash also covers {@code _ID}, so
         * nodes with identical inputs are equal but (barring collisions) hash differently. This
         * violates the equals/hashCode contract. However, the hash sets in the enclosing class
         * currently rely on it to keep such nodes (e.g. all leaf points) apart, and a consistent
         * hash would merge them. Consider identity-based sets before changing this.
         */
        @Override
        public int hashCode() {
            // the synthetic root has no interesting point
            int pointHash = (_p != null) ? _p.hashCode() : 0;
            return Arrays.hashCode(new int[] {_inputs.hashCode(), Long.hashCode(_ID), pointHash});
        }

        /** Two nodes are equal if they have the same number of inputs with the same IDs, in order. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof NodeLink)) {
                return false;
            }
            NodeLink that = (NodeLink) obj;
            if (_inputs.size() != that._inputs.size()) {
                return false;
            }
            for (int i = 0; i < _inputs.size(); i++) {
                if (_inputs.get(i)._ID != that._inputs.get(i)._ID) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Orders nodes with more inputs first, then lexicographically by input IDs. This is
         * consistent with {@link #equals(Object)}, so equal nodes end up adjacent after sorting.
         */
        @Override
        public int compareTo(NodeLink that) {
            int bySize = Integer.compare(that._inputs.size(), _inputs.size());
            if (bySize != 0) {
                return bySize;
            }
            for (int i = 0; i < _inputs.size(); i++) {
                int byId = Long.compare(_inputs.get(i)._ID, that._inputs.get(i)._ID);
                if (byId != 0) {
                    return byId;
                }
            }
            return 0;
        }

        /** Renders the node as {@code ID (inputID,inputID,...) point}; the point is "null" for the root. */
        @Override
        public String toString() {
            String inputs = _inputs.stream()
                .map(in -> String.valueOf(in._ID))
                .collect(Collectors.joining(","));
            return _ID + " (" + inputs + ") " + _p;
        }

        /**
         * Renders this node and all not yet visited nodes below it, children first, one node per
         * line.
         *
         * @param visited IDs of nodes already rendered; updated with this node
         * @return the rendering, or the empty string if this node was already visited
         */
        private String explain(HashSet<Long> visited) {
            if (visited.contains(_ID)) {
                return "";
            }
            StringBuilder sb = new StringBuilder();
            for (NodeLink in : _inputs) {
                String child = in.explain(visited);
                if (!child.isEmpty()) {
                    sb.append(child).append("\n");
                }
            }
            sb.append(toString());
            visited.add(_ID);
            return sb.toString();
        }
    }

    /** One side of a cut set: a fixed prefix of {@code offset} cut points plus free points. */
    public static class SubProblem {
        public int offset;
        public int[] freePos;
        public InterestingPoint[] freeMat;

        public SubProblem(int offset, int[] freePos, InterestingPoint[] freeMat) {
            this.offset = offset;
            this.freePos = freePos;
            this.freeMat = freeMat;
        }

        @Override
        public String toString() {
            return "SubProblem: " + Arrays.toString(freeMat) + "; " + offset + "; " + Arrays.toString(freePos);
        }
    }

    /**
     * Builds the reachability graph, derives and costs the cut sets, and linearizes the search
     * space.
     *
     * @param part plan partition providing the roots and the materialization points
     * @param memo memo table used to resolve hop IDs to hops
     * @throws RuntimeException if the linearized search space does not contain exactly the
     *         partition's materialization points
     */
    public ReachabilityGraph(PlanPartition part, CPlanMemoTable memo) {
        // step 1: one graph node per materialization point
        _matPoints = new HashMap<>();
        for (InterestingPoint p : part.getMatPointsExt()) {
            _matPoints.put(Pair.of(Long.valueOf(p._fromHopID), Long.valueOf(p._toHopID)), new NodeLink(p));
        }

        // step 2: link the materialization points reachable from the partition roots
        _root = new NodeLink(null);
        HashSet<PlanSelection.VisitMarkCost> visited = new HashSet<>();
        for (Long rootId : part.getRoots()) {
            addInputNodeLinks(memo.getHopRefs().get(rootId), _root, part, memo, visited);
        }

        // step 3: candidate cut points are non-root nodes with inputs; sorting makes nodes
        // with identical inputs adjacent so they can be grouped
        List<NodeLink> cands = _matPoints.values().stream()
            .filter(n -> !n._inputs.isEmpty() && n._p != null)
            .sorted()
            .collect(Collectors.toList());
        if (cands.isEmpty()) {
            _cutSets = new CutSet[0];
            _searchSpace = sortBySize(part.getMatPointsExt(), memo, false);
            return;
        }

        // step 4: evaluate each group as a cut set; for a small number of failed groups
        // (fewer than 5, i.e. at most 6 pairs) also try their pairwise unions
        ArrayList<ArrayList<NodeLink>> remain = new ArrayList<>();
        ArrayList<Pair<CutSet, Double>> cutSets = evaluateCutSets(groupByInputs(cands), remain);
        if (!remain.isEmpty() && remain.size() < 5) {
            addCompositeCutSets(remain, cutSets);
        }
        _cutSets = cutSets.stream()
            .sorted(Comparator.comparingDouble(p -> p.getRight()))
            .map(p -> p.getLeft())
            .toArray(CutSet[]::new);

        // step 5: linearize the search space: cut points first (in cut-set order),
        // then all remaining materialization points by descending size
        HashMap<InterestingPoint, Integer> pos = new HashMap<>();
        ArrayList<InterestingPoint> space = new ArrayList<>();
        for (CutSet cs : _cutSets) {
            for (InterestingPoint p : cs.cut) {
                space.add(p);
                pos.put(p, pos.size());
            }
        }
        for (InterestingPoint p : sortBySize(part.getMatPointsExt(), memo, false)) {
            if (!pos.containsKey(p)) {
                space.add(p);
                pos.put(p, pos.size());
            }
        }
        _searchSpace = space.toArray(new InterestingPoint[0]);
        for (CutSet cs : _cutSets) {
            cs.updatePositions(pos);
        }

        // sanity check: every materialization point appears exactly once
        if (_searchSpace.length != part.getMatPointsExt().length) {
            throw new RuntimeException("Corrupt linearized search space: "
                + _searchSpace.length + " vs " + part.getMatPointsExt().length);
        }
    }

    /** Returns the linearized search space (cut-set points first, then the rest by descending size). */
    public InterestingPoint[] getSortedSearchSpace() {
        return _searchSpace;
    }

    /** Returns true if {@code plan} selects every cut point of at least one cut set. */
    public boolean isCutSet(boolean[] plan) {
        for (CutSet cs : _cutSets) {
            if (isCutSet(cs, plan)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code plan} is set at every search-space position of {@code cs}'s cut points. */
    public boolean isCutSet(CutSet cs, boolean[] plan) {
        for (int p : cs.posCut) {
            if (!plan[p]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the first (cheapest) cut set whose cut points are all selected in {@code plan}.
     *
     * @throws RuntimeException if no cut set matches
     */
    public CutSet getCutSet(boolean[] plan) {
        for (CutSet cs : _cutSets) {
            if (isCutSet(cs, plan)) {
                return cs;
            }
        }
        throw new RuntimeException("No valid cut set found.");
    }

    /**
     * Returns {@code 2^(plan.length - p - 1)}, where {@code p} is the search-space position of the
     * last cut point of the first cut set fully selected in {@code plan}.
     *
     * @throws RuntimeException if no cut set matches
     */
    public long getNumSkipPlans(boolean[] plan) {
        for (CutSet cs : _cutSets) {
            if (isCutSet(cs, plan)) {
                int lastCutPos = cs.posCut[cs.posCut.length - 1];
                return UtilFunctions.pow(2, plan.length - lastCutPos - 1);
            }
        }
        throw new RuntimeException("Failed to compute number of skip plans for plan without cutset.");
    }

    /**
     * Splits {@code plan} at its matching cut set into a left and a right sub-problem.
     *
     * @throws RuntimeException if no cut set matches
     */
    public SubProblem[] getSubproblems(boolean[] plan) {
        CutSet cs = getCutSet(plan);
        return new SubProblem[] {
            new SubProblem(cs.cut.length, cs.posLeft, cs.left),
            new SubProblem(cs.cut.length, cs.posRight, cs.right)
        };
    }

    @Override
    public String toString() {
        return "ReachabilityGraph(" + _matPoints.size() + "):\n" + _root.explain(new HashSet<>());
    }

    /**
     * Recursively links {@code parent} to the materialization points first reached below
     * {@code current}. When an input edge is a materialization point, its node is linked and the
     * walk continues below it with that node as the new parent.
     *
     * @param visited (hop ID, parent node ID) pairs already processed
     */
    private void addInputNodeLinks(Hop current, NodeLink parent, PlanPartition part,
        CPlanMemoTable memo, HashSet<PlanSelection.VisitMarkCost> visited) {
        if (visited.contains(new PlanSelection.VisitMarkCost(current.getHopID(), parent._ID))) {
            return;
        }
        for (Hop in : current.getInput()) {
            if (InterestingPoint.isMatPoint(part.getMatPointsExt(), current.getHopID(), in.getHopID())) {
                NodeLink child = _matPoints.get(Pair.of(Long.valueOf(current.getHopID()), Long.valueOf(in.getHopID())));
                parent.addInput(child);
                addInputNodeLinks(in, child, part, memo, visited);
            } else {
                addInputNodeLinks(in, parent, part, memo, visited);
            }
        }
        visited.add(new PlanSelection.VisitMarkCost(current.getHopID(), parent._ID));
    }

    /**
     * Collects all nodes reachable below {@code current} without passing through a node in
     * {@code probe}. Each node is added only after its own reachable nodes have been collected.
     *
     * @param probe   nodes that block the traversal (the candidate cut)
     * @param collect output set of reachable nodes
     */
    private void rCollectInputs(NodeLink current, HashSet<NodeLink> probe, HashSet<NodeLink> collect) {
        for (NodeLink in : current._inputs) {
            // Skip already-collected nodes: their reachable nodes are collected too (nodes are
            // added post-order). Without this check, shared sub-DAGs were re-traversed
            // exponentially often.
            if (probe.contains(in) || collect.contains(in)) {
                continue;
            }
            rCollectInputs(in, probe, collect);
            collect.add(in);
        }
    }

    /**
     * Groups a sorted candidate list into runs of adjacent nodes that are equal, i.e. that have
     * identical inputs.
     */
    private static ArrayList<ArrayList<NodeLink>> groupByInputs(List<NodeLink> sorted) {
        ArrayList<ArrayList<NodeLink>> groups = new ArrayList<>();
        ArrayList<NodeLink> current = new ArrayList<>();
        for (NodeLink n : sorted) {
            if (!current.isEmpty() && !current.get(0).equals(n)) {
                groups.add(current);
                current = new ArrayList<>();
            }
            current.add(n);
        }
        if (!current.isEmpty()) {
            groups.add(current);
        }
        return groups;
    }

    /**
     * Evaluates the pairwise unions of the failed candidate groups. Accepted composite cut sets
     * are added to {@code cutSets}, skipping any whose cut points overlap an already added
     * composite cut set.
     *
     * @param remain  failed candidate groups; composite candidates that fail are appended to it
     * @param cutSets accepted cut sets with their costs; extended in place
     */
    private void addCompositeCutSets(ArrayList<ArrayList<NodeLink>> remain,
        ArrayList<Pair<CutSet, Double>> cutSets) {
        ArrayList<ArrayList<NodeLink>> pairs = new ArrayList<>();
        for (int i = 0; i < remain.size() - 1; i++) {
            for (int j = i + 1; j < remain.size(); j++) {
                ArrayList<NodeLink> union = new ArrayList<>(remain.get(i));
                union.addAll(remain.get(j));
                pairs.add(union);
            }
        }
        ArrayList<Pair<CutSet, Double>> composite = evaluateCutSets(pairs, remain);
        HashSet<InterestingPoint> used = new HashSet<>();
        for (Pair<CutSet, Double> cs : composite) {
            if (!CollectionUtils.containsAny(used, Arrays.asList(cs.getLeft().cut))) {
                cutSets.add(cs);
                CollectionUtils.addAll(used, cs.getLeft().cut);
            }
        }
    }

    /**
     * Evaluates candidate cut sets. A candidate is accepted if the nodes reachable from the root
     * without crossing it and the nodes reachable below it are both non-empty and disjoint.
     *
     * <p>For accepted candidates, with {@code base = 2^#matPoints} and
     * {@code numComb = 2^|cand|}, the cost is
     * {@code (numComb-1)/numComb * base + 1/numComb * 2^|left| + 1/numComb * 2^|right|}.
     *
     * @param candCS candidate groups of cut points
     * @param remain output list receiving the rejected candidates
     * @return accepted cut sets with their estimated cost
     */
    private ArrayList<Pair<CutSet, Double>> evaluateCutSets(ArrayList<ArrayList<NodeLink>> candCS,
        ArrayList<ArrayList<NodeLink>> remain) {
        ArrayList<Pair<CutSet, Double>> ret = new ArrayList<>();
        for (ArrayList<NodeLink> cand : candCS) {
            HashSet<NodeLink> probe = new HashSet<>(cand);
            // nodes reachable from the root without crossing the candidate cut
            HashSet<NodeLink> leftNodes = new HashSet<>();
            rCollectInputs(_root, probe, leftNodes);
            // nodes reachable below the candidate cut
            HashSet<NodeLink> rightNodes = new HashSet<>();
            for (NodeLink n : cand) {
                rCollectInputs(n, probe, rightNodes);
            }
            if (CollectionUtils.containsAny(leftNodes, rightNodes) || leftNodes.isEmpty() || rightNodes.isEmpty()) {
                remain.add(cand);
                continue;
            }
            double base = UtilFunctions.pow(2, _matPoints.size());
            double numComb = UtilFunctions.pow(2, cand.size());
            double cost = (numComb - 1) / numComb * base
                + 1 / numComb * UtilFunctions.pow(2, leftNodes.size())
                + 1 / numComb * UtilFunctions.pow(2, rightNodes.size());
            CutSet cs = new CutSet(toPoints(cand), toPoints(leftNodes), toPoints(rightNodes));
            ret.add(Pair.of(cs, cost));
        }
        return ret;
    }

    /** Maps nodes to their interesting points, preserving iteration order. */
    private static InterestingPoint[] toPoints(Iterable<NodeLink> nodes) {
        ArrayList<InterestingPoint> points = new ArrayList<>();
        for (NodeLink n : nodes) {
            points.add(n._p);
        }
        return points.toArray(new InterestingPoint[0]);
    }

    /**
     * Sorts points by the size (rows x cols, each at least 1) of their target hop.
     *
     * @param asc true for ascending size, false for descending size
     */
    private static InterestingPoint[] sortBySize(InterestingPoint[] points, CPlanMemoTable memo, boolean asc) {
        int sign = asc ? 1 : -1;
        return Arrays.stream(points)
            .sorted(Comparator.comparingLong(p -> sign * getSize(memo.getHopRefs().get(Long.valueOf(p.getToHopID())))))
            .toArray(InterestingPoint[]::new);
    }

    /** Returns {@code max(dim1,1) * max(dim2,1)}, treating unknown or zero dimensions as 1. */
    private static long getSize(Hop hop) {
        return Math.max(hop.getDim1(), 1L) * Math.max(hop.getDim2(), 1L);
    }
}
