package org.apache.sysds.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.opt.PlanSelection;
import org.apache.sysds.hops.codegen.opt.ReachabilityGraph;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
import org.apache.sysds.hops.codegen.template.TemplateRow;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;

/* loaded from: input_file:org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.class */
public class PlanSelectionFuseCostBasedV2 extends PlanSelection {
    // Cost-model bandwidth assumptions. Byte sizes (cells * 8) are divided by these
    // values to obtain costs, so presumably bytes per second -- TODO confirm unit.
    // 512 MiB; not referenced in the visible part of this class.
    private static final double WRITE_BANDWIDTH_IO = 5.36870912E8d;
    // 32 GiB; divisor for reading partition inputs and materialized intermediates.
    private static final double READ_BANDWIDTH_MEM = 3.4359738368E10d;
    // 128 MiB; not referenced in the visible part of this class.
    private static final double READ_BANDWIDTH_BROADCAST = 1.34217728E8d;
    // Fallback sparsity used by minOuterSparsity when a hop's dimensions are unknown
    // (and as the result when the partition yields no hops at all).
    private static final double SPARSE_SAFE_SPARSITY_EST = 0.1d;
    // enumPlans skips the remaining enumeration when there are at least
    // COST_MIN_EPS_NUM_POINTS free points and the best opening heuristic is already
    // within a factor (1 + COST_MIN_EPS) of the static minimum costs.
    public static final double COST_MIN_EPS = 0.01d;
    public static final int COST_MIN_EPS_NUM_POINTS = 20;
    // Plan-cache parameters; not referenced in the visible part of this class
    // (presumably consumed by probePlanCache/putPlan -- TODO confirm).
    private static final int PLAN_CACHE_NUM_POINTS = 10;
    private static final int PLAN_CACHE_SIZE = 1024;
    // Source of unique IDs for CostVector instances (per selector instance).
    private final IDSequence COST_ID = new IDSequence();
    private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName());
    // 2 GiB; divisor for writing partition outputs and materialized intermediates.
    private static final double WRITE_BANDWIDTH_MEM = 2.147483648E9d;
    // Compute throughput: WRITE_BANDWIDTH_MEM scaled by the local degree of parallelism;
    // divisor for per-hop compute costs in sumComputeCost.
    private static final double COMPUTE_BANDWIDTH = WRITE_BANDWIDTH_MEM * InfrastructureAnalyzer.getLocalParallelism();
    // Static cache from partition signatures to selected plan assignments, presumably
    // accessed through getPlan/putPlan (called from enumPlans). Shared by all instances;
    // NOTE(review): any synchronization happens outside this view -- confirm.
    private static final LinkedHashMap<PartitionSignature, boolean[]> _planCache = new LinkedHashMap<>();
    // Mutable global switches: branch-and-bound cost pruning and structural pruning are
    // read in enumPlans/selectPlans; PLAN_CACHING is not referenced in the visible part.
    public static boolean COST_PRUNING = true;
    public static boolean STRUCTURAL_PRUNING = true;
    public static boolean PLAN_CACHING = true;
    // Shared row template instance; not referenced in the visible part of this class.
    private static final TemplateRow ROW_TPL = new TemplateRow();

    /**
     * Candidate group of full aggregations for an across-partition multi-aggregate plan.
     * Tracks the aggregate hops of the group, the IDs of multi-aggregate roots found among
     * their inputs, and the IDs of their fused (non-plan-referenced) matrix inputs.
     */
    public static class AggregateInfo {
        public final HashSet<Long> _inputAggs = new HashSet<>();
        public final HashSet<Long> _fusedInputs = new HashSet<>();
        public final HashMap<Long, Hop> _aggregates = new HashMap<>();

        /** Creates a group that initially contains only the given aggregate. */
        public AggregateInfo(Hop hop) {
            _aggregates.put(hop.getHopID(), hop);
        }

        /** Records a multi-aggregate root reachable from this group's inputs. */
        public void addInputAggregate(long j) {
            _inputAggs.add(j);
        }

        /** Records a fused matrix input of this group. */
        public void addFusedInput(long j) {
            _fusedInputs.add(j);
        }

        /**
         * Checks whether this group can be merged with {@code that}: at most three
         * aggregates in total, no aggregate of one group among the input aggregates of
         * the other, at least one shared fused input, and equal sizes of the main inputs.
         */
        public boolean isMergable(AggregateInfo that) {
            // representative aggregate of each group (groups are never empty)
            Hop repThis = _aggregates.values().iterator().next();
            Hop repThat = that._aggregates.values().iterator().next();

            int ownCount = _aggregates.size();
            if (ownCount >= 3 || ownCount + that._aggregates.size() > 3)
                return false;
            for (Long id : that._aggregates.keySet())
                if (_inputAggs.contains(id))
                    return false;
            for (Long id : _aggregates.keySet())
                if (that._inputAggs.contains(id))
                    return false;
            if (!CollectionUtils.containsAny(_fusedInputs, that._fusedInputs))
                return false;
            return HopRewriteUtils.isEqualSize(mainInput(repThis), mainInput(repThat));
        }

        /** Main data input of an aggregate: input 1 for matrix multiply, else input 0. */
        private static Hop mainInput(Hop agg) {
            int pos = HopRewriteUtils.isMatrixMultiply(agg) ? 1 : 0;
            return agg.getInput().get(pos);
        }

        /** Absorbs all aggregates and input bookkeeping of {@code that}; returns this. */
        public AggregateInfo merge(AggregateInfo that) {
            _aggregates.putAll(that._aggregates);
            _inputAggs.addAll(that._inputAggs);
            _fusedInputs.addAll(that._fusedInputs);
            return this;
        }

        @Override
        public String toString() {
            String aggs = Arrays.toString(_aggregates.keySet().toArray(new Long[0]));
            String inAggs = Arrays.toString(_inputAggs.toArray(new Long[0]));
            String fused = Arrays.toString(_fusedInputs.toArray(new Long[0]));
            return "[" + aggs + ": {" + inAggs + "},{" + fused + "}]";
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2$CostVector.class */
    public class CostVector {
        public final long ID;
        public final double outSize;
        public double computeCosts = DataExpression.DEFAULT_DELIM_FILL_VALUE;
        public final HashMap<Long, Double> inSizes = new HashMap<>();

        public CostVector(double d) {
            this.ID = PlanSelectionFuseCostBasedV2.this.COST_ID.getNextID();
            this.outSize = d;
        }

        public void addInputSize(long j, double d) {
            this.inSizes.put(Long.valueOf(j), Double.valueOf(d));
        }

        public double getInputSize() {
            return this.inSizes.values().stream().mapToDouble(d -> {
                return d.doubleValue();
            }).sum();
        }

        public double getSideInputSize() {
            double maxInputSize = getMaxInputSize();
            return this.inSizes.values().stream().filter(d -> {
                return d.doubleValue() < maxInputSize;
            }).mapToDouble(d2 -> {
                return d2.doubleValue();
            }).sum();
        }

        public double getMaxInputSize() {
            return this.inSizes.values().stream().mapToDouble(d -> {
                return d.doubleValue();
            }).max().orElse(DataExpression.DEFAULT_DELIM_FILL_VALUE);
        }

        public long getMaxInputSizeHopID() {
            long j = -1;
            double d = 0.0d;
            for (Map.Entry<Long, Double> entry : this.inSizes.entrySet()) {
                if (d < entry.getValue().doubleValue()) {
                    j = entry.getKey().longValue();
                    d = entry.getValue().doubleValue();
                }
            }
            return j;
        }

        public String toString() {
            return "[" + this.outSize + ", " + this.computeCosts + ", {" + Arrays.toString(this.inSizes.keySet().toArray(new Long[0])) + ", " + Arrays.toString(this.inSizes.values().toArray(new Double[0])) + "}]";
        }
    }

    /**
     * Key of the static plan cache: structural counts of a plan partition (nodes,
     * inputs, roots, interesting points) plus its static and opening-heuristic costs.
     *
     * <p>Doubles are compared by bit pattern via {@link Double#compare}, consistent with
     * {@link Arrays#hashCode(double[])} used in {@link #hashCode()}. This keeps the
     * equals/hashCode contract intact for -0.0/0.0 and NaN.
     */
    public static class PartitionSignature {
        private final int partNodes;
        private final int inputNodes;
        private final int rootNodes;
        private final int matPoints;
        private final double cCompute;
        private final double cRead;
        private final double cWrite;
        private final double cPlan0;
        private final double cPlanN;

        /**
         * @param part plan partition the signature describes
         * @param matPoints number of interesting (materialization) points
         * @param costs static costs of the partition
         * @param cPlan0 costs of the all-false opening assignment
         * @param cPlanN costs of the all-true opening assignment
         */
        public PartitionSignature(PlanPartition part, int matPoints, StaticCosts costs, double cPlan0, double cPlanN) {
            this.partNodes = part.getPartition().size();
            this.inputNodes = part.getInputs().size();
            this.rootNodes = part.getRoots().size();
            this.matPoints = matPoints;
            this.cCompute = costs._compute;
            this.cRead = costs._read;
            this.cWrite = costs._write;
            this.cPlan0 = cPlan0;
            this.cPlanN = cPlanN;
        }

        @Override
        public int hashCode() {
            return UtilFunctions.intHashCode(Arrays.hashCode(new int[]{this.partNodes, this.inputNodes, this.rootNodes, this.matPoints}), Arrays.hashCode(new double[]{this.cCompute, this.cRead, this.cWrite, this.cPlan0, this.cPlanN}));
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (!(obj instanceof PartitionSignature))
                return false;
            PartitionSignature that = (PartitionSignature) obj;
            return partNodes == that.partNodes
                && inputNodes == that.inputNodes
                && rootNodes == that.rootNodes
                && matPoints == that.matPoints
                && Double.compare(cCompute, that.cCompute) == 0
                && Double.compare(cRead, that.cRead) == 0
                && Double.compare(cWrite, that.cWrite) == 0
                && Double.compare(cPlan0, that.cPlan0) == 0
                && Double.compare(cPlanN, that.cPlanN) == 0;
        }
    }

    /**
     * Assignment-independent costs of a plan partition: the per-hop compute costs, the
     * aggregated compute, read and write costs, and the minimum outer-product sparsity.
     */
    public static class StaticCosts {
        public final HashMap<Long, Double> _computeCosts;
        public final double _compute;
        public final double _read;
        public final double _write;
        public final double _minSparsity;

        public StaticCosts(HashMap<Long, Double> computeCosts, double compute, double read, double write, double minSparsity) {
            _computeCosts = computeCosts;
            _compute = compute;
            _read = read;
            _write = write;
            _minSparsity = minSparsity;
        }

        /** @return {@code max(read, compute) + write}, the static lower bound used for pruning */
        public double getMinCosts() {
            final double readOrCompute = Math.max(_read, _compute);
            return readOrCompute + _write;
        }
    }

    /**
     * Selects fusion plans for the DAG rooted at the given hops. For each plan partition,
     * within-partition multi-aggregate plans are added and the partition is enumerated
     * and pruned. Afterwards, across-partition multi-aggregate plans are added, and the
     * collected best plans are installed as distinct plans in the memo table. With
     * statistics enabled, 2^(total interesting points) is recorded as the size of the
     * unpruned search space.
     *
     * @param cPlanMemoTable memo table of candidate plans (modified in place)
     * @param arrayList DAG roots
     */
    @Override
    public void selectPlans(CPlanMemoTable cPlanMemoTable, ArrayList<Hop> arrayList) {
        int sumMatPoints = 0;
        for (PlanPartition part : PlanAnalyzer.analyzePlanPartitions(cPlanMemoTable, arrayList, true)) {
            // composite (multi-aggregate) templates within the partition
            createAndAddMultiAggPlans(cPlanMemoTable, part.getPartition(), part.getRoots());
            // plan enumeration and selection for the partition
            selectPlans(cPlanMemoTable, part);
            // the per-partition selection explicitly tolerates a null point array; so must we
            InterestingPoint[] matPoints = part.getMatPointsExt();
            sumMatPoints += (matPoints != null) ? matPoints.length : 0;
        }
        // composite templates across partitions
        createAndAddMultiAggPlans(cPlanMemoTable, arrayList);
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> e : getBestPlans().entrySet())
            cPlanMemoTable.setDistinct(e.getKey().longValue(), e.getValue());
        if (DMLScript.STATISTICS) {
            if (sumMatPoints >= 63) {
                LOG.warn("Long overflow on maintaining codegen statistics for a DAG with " + sumMatPoints + " interesting points.");
            }
            Statistics.incrementCodegenEnumAll(UtilFunctions.pow(2, sumMatPoints));
        }
    }

    /**
     * Optimizes one plan partition. Invalid and special-case plans are pruned first.
     * Without interesting points, all roots are simply fused. Otherwise the method
     * computes static costs, optionally applies structural pruning, enumerates the
     * assignments, and prunes the memo table to the best assignment before fusing.
     */
    private void selectPlans(CPlanMemoTable memo, PlanPartition part) {
        pruneInvalidAndSpecialCasePlans(memo, part);

        // no interesting points: nothing to enumerate
        if (part.getMatPointsExt() == null || part.getMatPointsExt().length == 0) {
            for (Long rootID : part.getRoots())
                rSelectPlansFuseAll(memo, memo.getHopRefs().get(rootID), null, part.getPartition());
            return;
        }

        // assignment-independent costs of the partition
        HashMap<Long, Double> computeCosts = new HashMap<>();
        for (Long hopID : part.getPartition())
            getComputeCosts(memo.getHopRefs().get(hopID), computeCosts);
        StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts),
            getReadCost(part, memo), getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));

        // structural pruning: reorder search space and drop redundant plans
        ReachabilityGraph rgraph = null;
        if (STRUCTURAL_PRUNING) {
            rgraph = new ReachabilityGraph(part, memo);
            part.setMatPointsExt(rgraph.getSortedSearchSpace());
            for (Long hopID : part.getPartition())
                memo.pruneRedundant(hopID.longValue(), true, part.getMatPointsExt());
        }

        boolean[] bestPlan = enumPlans(memo, part, costs, rgraph, part.getMatPointsExt(), 0);

        // prune memo table w.r.t. the chosen assignment, then fuse
        HashSet<Long> visitedSubopt = new HashSet<>();
        for (Long rootID : part.getRoots())
            rPruneSuboptimalPlans(memo, memo.getHopRefs().get(rootID), visitedSubopt, part, part.getMatPointsExt(), bestPlan);
        HashSet<Long> visitedInvalid = new HashSet<>();
        for (Long rootID : part.getRoots())
            rPruneInvalidPlans(memo, memo.getHopRefs().get(rootID), visitedInvalid, part, bestPlan);
        for (Long rootID : part.getRoots())
            rSelectPlansFuseAll(memo, memo.getHopRefs().get(rootID), null, part.getPartition());
    }

    /**
     * Enumerates assignments of the free interesting points of a partition (or of a
     * subproblem created by structural pruning) and returns the cheapest one found.
     *
     * <p>Assignments are scanned in linearized order. Index 0 (all free points false) and
     * index 2^|M|-1 (all true) are costed upfront as opening heuristics; their minimum
     * serves as upper bound for branch-and-bound pruning. Structural cut sets split the
     * search into subproblems that are solved recursively and written into the plan.
     *
     * @param memo memo table of candidate plans
     * @param part partition whose plans are enumerated
     * @param costs precomputed static costs of the partition
     * @param rgraph reachability graph for structural pruning, or null to disable it
     * @param matPoints interesting points
     * @param off number of leading points fixed to true (offset of the free points)
     * @return best assignment without the first {@code off} fixed entries
     */
    private boolean[] enumPlans(CPlanMemoTable memo, PlanPartition part, StaticCosts costs, ReachabilityGraph rgraph, InterestingPoint[] matPoints, int off) {
        final int mlen = matPoints.length - off;
        final long len = UtilFunctions.pow(2, mlen);
        long numEvalPlans = 2;
        long numEvalPartPlans = 0;

        // opening heuristics: all-false and all-true assignments of the free points
        final boolean[] plan0 = createAssignment(mlen, off, 0);
        final boolean[] planN = createAssignment(mlen, off, len - 1);
        final double c0 = getPlanCost(memo, part, matPoints, plan0, costs._computeCosts, Double.MAX_VALUE);
        final double cN = getPlanCost(memo, part, matPoints, planN, costs._computeCosts, Double.MAX_VALUE);
        boolean[] bestPlan = (c0 <= cN) ? plan0 : planN;
        double bestC = Math.min(c0, cN);

        // skip remaining enumeration for large search spaces whose best heuristic is
        // already within (1 + COST_MIN_EPS) of the static minimum costs
        final boolean evalRemain = mlen < COST_MIN_EPS_NUM_POINTS || !COST_PRUNING
            || bestC > (1 + COST_MIN_EPS) * costs.getMinCosts();
        if (LOG.isTraceEnabled())
            LOG.trace("Enum opening: " + Arrays.toString(bestPlan) + " -> " + bestC);
        if (!evalRemain)
            LOG.warn("Skip enum for |M|=" + mlen + ", C=" + bestC + ", Cmin=" + costs.getMinCosts());

        // probe plan cache for an already optimized, equivalent partition
        PartitionSignature pKey = null;
        if (probePlanCache(matPoints)) {
            pKey = new PartitionSignature(part, matPoints.length, costs, c0, cN);
            boolean[] cached = getPlan(pKey);
            if (cached != null) {
                Statistics.incrementCodegenEnumAllP((rgraph != null || !STRUCTURAL_PRUNING) ? len : 0);
                return cached;
            }
        }

        // evaluate remaining plans, except the already evaluated heuristics
        for (long p = 1; p < len - 1 && evalRemain; p++) {
            boolean[] plan = createAssignment(mlen, off, p);
            long pskip = 0; // structural skip, applied after costing

            if (STRUCTURAL_PRUNING && rgraph != null && rgraph.isCutSet(plan)) {
                // the skip count also acts as boundary of the independent subproblems
                pskip = rgraph.getNumSkipPlans(plan);
                if (LOG.isTraceEnabled())
                    LOG.trace("Enum: Structural pruning for cut set: " + rgraph.getCutSet(plan));
                ReachabilityGraph.SubProblem[] prob = rgraph.getSubproblems(plan);
                for (int k = 0; k < prob.length; k++) {
                    if (LOG.isTraceEnabled())
                        LOG.trace("Enum: Subproblem " + (k + 1) + Lop.FILE_SEPARATOR + prob.length + ": " + prob[k]);
                    boolean[] bestTmp = enumPlans(memo, part, costs, null, prob[k].freeMat, prob[k].offset);
                    LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[k].freePos);
                }
                // the combined plan is costed in full below; the skip is applied afterwards
            }
            else if (COST_PRUNING && getLowerBoundCosts(part, matPoints, memo, costs, plan) >= bestC) {
                // branch and bound: lower bound already reaches bestC, so skip this plan
                // and the next getNumSkipPlans(plan)-1 assignments without costing them
                long skip = getNumSkipPlans(plan);
                if (LOG.isTraceEnabled())
                    LOG.trace("Enum: Skip " + skip + " plans (by cost).");
                p += skip - 1;
                continue;
            }

            // cost the assignment; costing may stop early once it exceeds bestC
            double c = getPlanCost(memo, part, matPoints, plan, costs._computeCosts, COST_PRUNING ? bestC : Double.MAX_VALUE);
            if (LOG.isTraceEnabled())
                LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + c);
            numEvalPartPlans += (c == Double.POSITIVE_INFINITY) ? 1 : 0;
            numEvalPlans++;

            if (bestPlan == null || c < bestC) {
                bestC = c;
                bestPlan = plan;
                if (LOG.isTraceEnabled())
                    LOG.trace("Enum: Found new best plan.");
            }

            // post skipping (structural pruning)
            p += pskip;
            if (pskip != 0 && LOG.isTraceEnabled())
                LOG.trace("Enum: Skip " + pskip + " plans (by structure).");
        }

        if (DMLScript.STATISTICS) {
            Statistics.incrementCodegenEnumAllP((rgraph != null || !STRUCTURAL_PRUNING) ? len : 0);
            Statistics.incrementCodegenEnumEval(numEvalPlans);
            Statistics.incrementCodegenEnumEvalP(numEvalPartPlans);
        }
        if (LOG.isTraceEnabled())
            LOG.trace("Enum: Optimal plan: " + Arrays.toString(bestPlan));

        // keep resulting plan in cache
        if (probePlanCache(matPoints))
            putPlan(pKey, bestPlan);

        // strip the fixed offset entries
        return (bestPlan == null) ? new boolean[mlen] : Arrays.copyOfRange(bestPlan, off, bestPlan.length);
    }

    /**
     * Decodes a linearized assignment index into a plan vector. The first {@code off}
     * entries are fixed to true. The following {@code len} entries hold the binary digits
     * of {@code pos}, most significant first.
     */
    private static boolean[] createAssignment(int len, int off, long pos) {
        boolean[] plan = new boolean[off + len];
        Arrays.fill(plan, 0, off, true);
        long remaining = pos;
        for (int k = 0; k < len; k++) {
            final long digitValue = UtilFunctions.pow(2, len - k - 1);
            plan[off + k] = remaining >= digitValue;
            remaining = remaining % digitValue;
        }
        return plan;
    }

    /**
     * Number of assignments covered by the given plan's prefix, i.e., 2 raised to the
     * number of entries after the last {@code true} entry.
     */
    private static long getNumSkipPlans(boolean[] zArr) {
        final int lastTrue = ArrayUtils.lastIndexOf(zArr, true);
        final int trailing = zArr.length - lastTrue - 1;
        return UtilFunctions.pow(2, trailing);
    }

    /**
     * Lower bound on the costs of an assignment: the static minimum costs plus the
     * materialization costs of the assignment. For partitions with outer products the
     * bound is scaled by the minimum sparsity.
     */
    private static double getLowerBoundCosts(PlanPartition planPartition, InterestingPoint[] interestingPointArr, CPlanMemoTable cPlanMemoTable, StaticCosts staticCosts, boolean[] zArr) {
        final double bound = staticCosts.getMinCosts()
            + getMaterializationCost(planPartition, interestingPointArr, cPlanMemoTable, zArr);
        return planPartition.hasOuter() ? bound * staticCosts._minSparsity : bound;
    }

    /**
     * Costs of materializing intermediates. Each distinct target of an active
     * interesting point pays a write and a read of {@code cells * 8} bytes. Each further
     * hop with consumers outside the partition pays a write only.
     */
    private static double getMaterializationCost(PlanPartition part, InterestingPoint[] matPoints, CPlanMemoTable memo, boolean[] plan) {
        double total = 0;
        HashSet<Long> matTargets = new HashSet<>();
        for (int k = 0; k < plan.length; k++) {
            final long target = matPoints[k].getToHopID();
            if (plan[k] && matTargets.add(target)) {
                final long cells = getSize(memo.getHopRefs().get(target));
                total += cells * 8 / WRITE_BANDWIDTH_MEM + cells * 8 / READ_BANDWIDTH_MEM;
            }
        }
        // points with consumers outside the partition
        for (Long hopID : part.getExtConsumed()) {
            if (matTargets.add(hopID))
                total += getSize(memo.getHopRefs().get(hopID)) * 8 / WRITE_BANDWIDTH_MEM;
        }
        return total;
    }

    /** Read costs of all partition inputs, based on their (safe) memory estimates. */
    private static double getReadCost(PlanPartition planPartition, CPlanMemoTable cPlanMemoTable) {
        double total = 0;
        for (Long inputID : planPartition.getInputs()) {
            Hop input = cPlanMemoTable.getHopRefs().get(inputID);
            total += getSafeMemEst(input) / READ_BANDWIDTH_MEM;
        }
        return total;
    }

    /** Write costs of the given hops, assuming {@code cells * 8} bytes per output. */
    private static double getWriteCost(Collection<Long> collection, CPlanMemoTable cPlanMemoTable) {
        double total = 0;
        for (Long hopID : collection) {
            Hop out = cPlanMemoTable.getHopRefs().get(hopID);
            total += (getSize(out) * 8) / WRITE_BANDWIDTH_MEM;
        }
        return total;
    }

    /** Sum of all per-hop compute costs, each normalized by {@code COMPUTE_BANDWIDTH}. */
    private static double sumComputeCost(HashMap<Long, Double> hashMap) {
        return hashMap.values().stream()
            .mapToDouble(Double::doubleValue)
            .map(c -> c / COMPUTE_BANDWIDTH)
            .sum();
    }

    /**
     * Minimum sparsity over the largest inputs of all partition hops, using
     * {@code SPARSE_SAFE_SPARSITY_EST} for unknown dimensions. Returns 1 for partitions
     * without outer products.
     */
    private static double minOuterSparsity(PlanPartition planPartition, CPlanMemoTable cPlanMemoTable) {
        if (!planPartition.hasOuter())
            return 1.0d;
        return planPartition.getPartition().stream()
            .map(hopID -> HopRewriteUtils.getLargestInput(cPlanMemoTable.getHopRefs().get(hopID)))
            .mapToDouble(in -> in.dimsKnown(true) ? in.getSparsity() : SPARSE_SAFE_SPARSITY_EST)
            .min()
            .orElse(SPARSE_SAFE_SPARSITY_EST);
    }

    /** Output size plus the sizes of all inputs that are not transient reads. */
    private static double sumTmpInputOutputSize(CPlanMemoTable cPlanMemoTable, CostVector costVector) {
        final double tmpInputs = costVector.inSizes.entrySet().stream()
            .filter(e -> !HopRewriteUtils.isData(cPlanMemoTable.getHopRefs().get(e.getKey()), Types.OpOpData.TRANSIENTREAD))
            .mapToDouble(e -> e.getValue())
            .sum();
        return costVector.outSize + tmpInputs;
    }

    /** Sum of the safe memory estimates of all inputs recorded in the cost vector. */
    private static double sumInputMemoryEstimates(CPlanMemoTable cPlanMemoTable, CostVector costVector) {
        return costVector.inSizes.keySet().stream()
            .map(hopID -> cPlanMemoTable.getHopRefs().get(hopID))
            .mapToDouble(PlanSelectionFuseCostBasedV2::getSafeMemEst)
            .sum();
    }

    /**
     * Output memory estimate of a hop. If its dimensions are unknown, falls back to
     * {@code cells * 8} bytes with unknown dimensions treated as 1.
     */
    public static double getSafeMemEst(Hop hop) {
        return hop.dimsKnown() ? hop.getOutputMemEstimate() : getSize(hop) * 8;
    }

    /** Number of cells of a hop's output, treating unknown/non-positive dimensions as 1. */
    private static long getSize(Hop hop) {
        final long rows = Math.max(hop.getDim1(), 1L);
        final long cols = Math.max(hop.getDim2(), 1L);
        return rows * cols;
    }

    /**
     * Adds within-partition multi-aggregate (MAGG) plans. Full aggregations among the
     * partition roots that are not consumed by any existing plan are grouped into chunks
     * of up to three. Each valid chunk of at least two becomes one MAGG entry.
     *
     * @param cPlanMemoTable memo table (modified in place)
     * @param hashSet partition hop IDs (currently unused)
     * @param hashSet2 partition root hop IDs
     */
    private static void createAndAddMultiAggPlans(CPlanMemoTable cPlanMemoTable, HashSet<Long> hashSet, HashSet<Long> hashSet2) {
        // index of hops referenced by existing plans, to avoid circular dependencies
        HashSet<Long> refHops = new HashSet<>();
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> e : cPlanMemoTable.getPlans().entrySet()) {
            if (e.getValue().isEmpty())
                continue;
            for (Hop in : cPlanMemoTable.getHopRefs().get(e.getKey()).getInput())
                refHops.add(in.getHopID());
        }

        // unreferenced full aggregations among the roots
        ArrayList<Long> fullAggs = new ArrayList<>();
        for (Long rootID : hashSet2) {
            Hop root = cPlanMemoTable.getHopRefs().get(rootID);
            if (!refHops.contains(rootID) && isMultiAggregateRoot(root))
                fullAggs.add(rootID);
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
        }

        // chunks of at most three aggregations per MAGG template
        for (int from = 0; from < fullAggs.size(); from += 3) {
            final int to = Math.min(from + 3, fullAggs.size());
            final int count = to - from;
            if (count < 2)
                continue;
            long third = (count == 3) ? fullAggs.get(from + 2).longValue() : -1L;
            CPlanMemoTable.MemoTableEntry me = new CPlanMemoTable.MemoTableEntry(TemplateBase.TemplateType.MAGG,
                fullAggs.get(from).longValue(), fullAggs.get(from + 1).longValue(), third, count);
            if (!isValidMultiAggregate(cPlanMemoTable, me)) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed invalid multiagg plan: " + me);
                }
                continue;
            }
            for (int k = from; k < to; k++) {
                cPlanMemoTable.add(cPlanMemoTable.getHopRefs().get(fullAggs.get(k)), me);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Added multiagg plan: " + fullAggs.get(k) + " " + me);
                }
            }
        }
    }

    /**
     * Adds across-partition multi-aggregate (MAGG) plans.
     *
     * <p>Steps:
     * <ol>
     *   <li>Collect all full aggregations of the DAG that are not already covered by a
     *       MAGG plan.</li>
     *   <li>Extract their input aggregates and fused inputs.</li>
     *   <li>Greedily merge mergable groups, in ascending order of input-aggregate
     *       count, until no more merges apply.</li>
     *   <li>Register every group of at least two aggregates as one MAGG entry, both in
     *       the memo table and as best plan.</li>
     * </ol>
     *
     * @param cPlanMemoTable memo table (modified in place)
     * @param arrayList DAG roots (visit status is reset before and after traversal)
     */
    private void createAndAddMultiAggPlans(CPlanMemoTable cPlanMemoTable, ArrayList<Hop> arrayList) {
        // collect full aggregations as initial candidates
        HashSet<Long> fullAggs = new HashSet<>();
        Hop.resetVisitStatus(arrayList);
        for (Hop root : arrayList)
            rCollectFullAggregates(root, fullAggs);
        Hop.resetVisitStatus(arrayList);

        // drop aggregations already covered by within-partition MAGG plans
        fullAggs.removeIf(hopID -> cPlanMemoTable.contains(hopID.longValue(), TemplateBase.TemplateType.MAGG));
        if (fullAggs.size() <= 1)
            return;
        if (LOG.isTraceEnabled()) {
            LOG.trace("Found across-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
        }

        // extract input aggregates and fused inputs per aggregation
        ArrayList<AggregateInfo> aggInfos = new ArrayList<>();
        for (Long hopID : fullAggs) {
            Hop agg = cPlanMemoTable.getHopRefs().get(hopID);
            AggregateInfo info = new AggregateInfo(agg);
            final boolean isMM = HopRewriteUtils.isMatrixMultiply(agg);
            for (int k = 0; k < agg.getInput().size(); k++) {
                // for t(X) %*% y, skip the transpose and start at X
                Hop in = (isMM && k == 0) ? agg.getInput().get(0).getInput().get(0) : agg.getInput().get(k);
                rExtractAggregateInfo(cPlanMemoTable, in, info, TemplateBase.TemplateType.CELL);
            }
            if (info._fusedInputs.isEmpty()) {
                if (isMM) {
                    info.addFusedInput(agg.getInput().get(0).getInput().get(0).getHopID());
                    info.addFusedInput(agg.getInput().get(1).getHopID());
                } else {
                    info.addFusedInput(agg.getInput().get(0).getHopID());
                }
            }
            aggInfos.add(info);
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
            for (AggregateInfo info : aggInfos)
                LOG.trace(info);
        }

        // greedy grouping; the list must be mutable because merged entries are removed
        // (Collectors.toList() gives no mutability guarantee)
        ArrayList<AggregateInfo> groups = aggInfos.stream()
            .sorted(Comparator.comparingInt(a -> a._inputAggs.size()))
            .collect(Collectors.toCollection(ArrayList::new));
        boolean converged = false;
        while (!converged) {
            AggregateInfo merged = null;
            for (int i = 0; i < groups.size(); i++) {
                AggregateInfo current = groups.get(i);
                for (int j = i + 1; j < groups.size(); j++) {
                    AggregateInfo that = groups.get(j);
                    if (current.isMergable(that)) {
                        merged = current.merge(that);
                        groups.remove(j);
                        j--;
                    }
                }
            }
            converged = (merged == null);
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Merged across-partition ua(RC) aggregation info: ");
            for (AggregateInfo info : groups)
                LOG.trace(info);
        }

        // construct and add MAGG plans for groups of at least two aggregates
        for (AggregateInfo info : groups) {
            if (info._aggregates.size() <= 1)
                continue;
            Long[] ids = info._aggregates.keySet().toArray(new Long[0]);
            CPlanMemoTable.MemoTableEntry me = new CPlanMemoTable.MemoTableEntry(TemplateBase.TemplateType.MAGG,
                ids[0].longValue(), ids[1].longValue(), ids.length > 2 ? ids[2].longValue() : -1L, ids.length);
            for (Long id : ids) {
                cPlanMemoTable.add(cPlanMemoTable.getHopRefs().get(id), me);
                addBestPlan(id.longValue(), me);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Added multiagg* plan: " + id + " " + me);
                }
            }
        }
    }

    /**
     * Checks whether a hop can root a multi-aggregate. It qualifies if it is a row-col
     * sum/sumsq/min/max aggregation, or a 1x1 matrix multiply whose left input is a
     * transpose.
     */
    private static boolean isMultiAggregateRoot(Hop hop) {
        if (HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.SUM, Types.AggOp.SUM_SQ, Types.AggOp.MIN, Types.AggOp.MAX)
            && ((AggUnaryOp) hop).getDirection() == Types.Direction.RowCol) {
            return true;
        }
        return hop instanceof AggBinaryOp
            && hop.getDim1() == 1 && hop.getDim2() == 1
            && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0));
    }

    /**
     * Validates a MAGG entry. The first input of every referenced aggregate must have
     * the same size as that of input 1. No referenced aggregate may (transitively)
     * depend on any of the other entry inputs.
     */
    private static boolean isValidMultiAggregate(CPlanMemoTable cPlanMemoTable, CPlanMemoTable.MemoTableEntry memoTableEntry) {
        final CPlanMemoTable.MemoTableEntry me = memoTableEntry;
        Hop refSize = cPlanMemoTable.getHopRefs().get(me.input1).getInput().get(0);

        // consistent input sizes (otherwise potentially incorrect results)
        for (int pos = 1; pos < 3; pos++) {
            if (me.isPlanRef(pos)
                && !HopRewriteUtils.isEqualSize(refSize, cPlanMemoTable.getHopRefs().get(me.input(pos)).getInput().get(0))) {
                return false;
            }
        }

        // independence of the aggregates from each other
        for (int pos = 0; pos < 3; pos++) {
            if (!me.isPlanRef(pos))
                continue;
            HashSet<Long> others = new HashSet<>();
            for (int other = 0; other < 3; other++) {
                if (other != pos)
                    others.add(me.input(other));
            }
            if (!rCheckMultiAggregate(cPlanMemoTable.getHopRefs().get(me.input(pos)), others))
                return false;
        }
        return true;
    }

    /**
     * Checks that neither {@code hop} nor any hop reachable through its inputs has an ID
     * contained in {@code hashSet}.
     *
     * <p>Shared subexpressions are visited only once. The previous un-memoized traversal
     * re-walked every path and was exponential on DAGs with many shared nodes. The
     * result is unchanged.
     */
    private static boolean rCheckMultiAggregate(Hop hop, HashSet<Long> hashSet) {
        return rCheckMultiAggregate(hop, hashSet, new HashSet<>());
    }

    /** Memoized worker for {@link #rCheckMultiAggregate(Hop, HashSet)}. */
    private static boolean rCheckMultiAggregate(Hop hop, HashSet<Long> probe, HashSet<Long> visited) {
        // a previously completed visit found no conflict (otherwise we would have exited)
        if (!visited.add(hop.getHopID()))
            return true;
        if (probe.contains(hop.getHopID()))
            return false;
        for (Hop in : hop.getInput()) {
            if (!rCheckMultiAggregate(in, probe, visited))
                return false;
        }
        return true;
    }

    /**
     * Recursively collects the IDs of all multi-aggregate roots (see isMultiAggregateRoot)
     * reachable from the given hop.
     *
     * <p>Uses the hops' visit status for memoization and marks every processed hop as visited.
     *
     * @param hop     root of the traversal
     * @param hashSet output set receiving the hop IDs of aggregate roots
     */
    private static void rCollectFullAggregates(Hop hop, HashSet<Long> hashSet) {
        if (!hop.isVisited()) {
            if (isMultiAggregateRoot(hop)) {
                hashSet.add(hop.getHopID());
            }
            for (Hop input : hop.getInput()) {
                rCollectFullAggregates(input, hashSet);
            }
            hop.setVisited();
        }
    }

    /**
     * Recursively gathers aggregate information below the given hop.
     *
     * <p>Every multi-aggregate root found is registered as an input aggregate. While inside a
     * fused operator ({@code templateType != null}), inputs that are not plan references of the
     * hop's best memo entry are matrix inputs read by the fused operator. They are recorded as
     * fused inputs, and the traversal continues below them outside the fused operator.
     *
     * @param cPlanMemoTable memo table providing the best plans
     * @param hop            current hop
     * @param aggregateInfo  accumulator for the collected information
     * @param templateType   template of the enclosing fused operator, or null if none
     */
    private static void rExtractAggregateInfo(CPlanMemoTable cPlanMemoTable, Hop hop, AggregateInfo aggregateInfo, TemplateBase.TemplateType templateType) {
        final long hopId = hop.getHopID();
        if (isMultiAggregateRoot(hop)) {
            aggregateInfo.addInputAggregate(hopId);
        }
        // plan references only matter while we are inside a fused operator
        final CPlanMemoTable.MemoTableEntry best = (templateType == null) ? null : cPlanMemoTable.getBest(hopId);
        final List<Hop> inputs = hop.getInput();
        for (int pos = 0; pos < inputs.size(); pos++) {
            final Hop input = inputs.get(pos);
            final boolean fusedRef = best != null && best.isPlanRef(pos);
            if (fusedRef) {
                rExtractAggregateInfo(cPlanMemoTable, input, aggregateInfo, templateType);
                continue;
            }
            if (templateType != null && input.getDataType().isMatrix()) {
                aggregateInfo.addFusedInput(input.getHopID());
            }
            rExtractAggregateInfo(cPlanMemoTable, input, aggregateInfo, null);
        }
    }

    /**
     * Collects the IDs of row-template hops in the partition that must not be replaced by
     * cell-template plans, starting from all partition roots.
     *
     * @param cPlanMemoTable memo table providing the plans
     * @param planPartition  partition to analyze
     * @return set of irreplaceable row-op hop IDs
     */
    private static HashSet<Long> collectIrreplaceableRowOps(CPlanMemoTable cPlanMemoTable, PlanPartition planPartition) {
        final HashSet<Long> irreplaceable = new HashSet<>();
        // memoization keys of (hop ID, flag/type) combinations already processed
        final HashSet<Pair<Long, Integer>> processed = new HashSet<>();
        for (Long rootId : planPartition.getRoots()) {
            final Hop root = cPlanMemoTable.getHopRefs().get(rootId);
            rCollectDependentRowOps(root, cPlanMemoTable, planPartition, irreplaceable, processed, null, false);
        }
        return irreplaceable;
    }

    /**
     * Recursively marks row-template hops that are irreplaceable by cell-template plans.
     *
     * <p>A hop within the partition is marked if:
     * <ul>
     *   <li>it is a row op reached with {@code z} set;</li>
     *   <li>it is a row aggregate (see isRowAggOp) or a row-capable materialization point; or</li>
     *   <li>its best plan is a row plan that fuses (directly or implicitly) an already marked
     *       input.</li>
     * </ul>
     *
     * @param hop            current hop
     * @param cPlanMemoTable memo table providing the plans
     * @param planPartition  partition restricting the traversal
     * @param hashSet        output set of irreplaceable hop IDs
     * @param hashSet2       memoization of processed (hop ID, flag/type) keys
     * @param templateType   template type of the fused parent, or null
     * @param z              whether a row aggregate was found above along fused edges
     */
    private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable cPlanMemoTable, PlanPartition planPartition, HashSet<Long> hashSet, HashSet<Pair<Long, Integer>> hashSet2, TemplateBase.TemplateType templateType, boolean z) {
        final long hopId = hop.getHopID();
        // memoization key combines the incoming flag and template type
        final int flagAndType = (z ? Short.MAX_VALUE : 0) + (templateType == null ? 0 : templateType.ordinal() + 1);
        final Pair<Long, Integer> visitKey = Pair.of(Long.valueOf(hopId), Integer.valueOf(flagAndType));
        if (hashSet2.contains(visitKey) || !planPartition.getPartition().contains(Long.valueOf(hopId))) {
            return;
        }

        final CPlanMemoTable.MemoTableEntry best = (templateType != null)
            ? cPlanMemoTable.getBest(hopId, templateType)
            : cPlanMemoTable.getBest(hopId);
        final boolean isRowOp = best != null
            && best.type == TemplateBase.TemplateType.ROW
            && templateType == TemplateBase.TemplateType.ROW;
        final boolean isMatPoint = planPartition.getMatPointsExt().length > 0
            && cPlanMemoTable.contains(hopId, TemplateBase.TemplateType.ROW)
            && !cPlanMemoTable.hasOnlyExactMatches(hopId, TemplateBase.TemplateType.ROW, TemplateBase.TemplateType.CELL);

        boolean foundRowOp = z;
        if (isRowOp && foundRowOp) {
            hashSet.add(hopId);
        }
        if (isRowAggOp(hop, isRowOp) || isMatPoint) {
            hashSet.add(hopId);
            foundRowOp = true;
        }

        // descend; the flag propagates only along fused (plan-ref or implicitly fused) edges
        final List<Hop> inputs = hop.getInput();
        final TemplateBase.TemplateType childType = (best == null) ? null : best.type;
        for (int pos = 0; pos < inputs.size(); pos++) {
            final boolean childFlag = foundRowOp && best != null
                && (best.isPlanRef(pos) || isImplicitlyFused(hop, pos, best.type));
            rCollectDependentRowOps(inputs.get(pos), cPlanMemoTable, planPartition, hashSet, hashSet2, childType, childFlag);
        }

        // a row plan fusing an irreplaceable input is itself irreplaceable
        if (!hashSet.contains(hopId) && best != null && best.type == TemplateBase.TemplateType.ROW) {
            for (int pos = 0; pos < inputs.size(); pos++) {
                final boolean fused = best.isPlanRef(pos) || isImplicitlyFused(hop, pos, best.type);
                if (fused && hashSet.contains(inputs.get(pos).getHopID())) {
                    hashSet.add(hopId);
                }
            }
        }
        hashSet2.add(visitKey);
    }

    /**
     * Determines whether the given hop requires row-wise processing (a "row aggregate").
     *
     * @param hop the hop to check
     * @param z   whether the hop is part of a row operation
     * @return true for cbind, matrix multiplications (except known-size vector outputs outside
     *         row ops), transposes of non-vectors (except of seq), and unary aggregates in row ops
     */
    private static boolean isRowAggOp(Hop hop, boolean z) {
        if (HopRewriteUtils.isBinary(hop, Types.OpOp2.CBIND) || HopRewriteUtils.isNary(hop, Types.OpOpN.CBIND)) {
            return true;
        }
        // matrix multiply, unless (outside a row op) its output is known to be a vector
        if (hop instanceof AggBinaryOp
            && (z || !hop.dimsKnown() || (hop.getDim1() != 1 && hop.getDim2() != 1))) {
            return true;
        }
        // transpose of a non-vector, unless its input is a seq() data generator
        if (HopRewriteUtils.isTransposeOperation(hop)
            && hop.getDim1() != 1
            && hop.getDim2() != 1
            && !HopRewriteUtils.isDataGenOp(hop.getInput().get(0), Types.OpOpDG.SEQ)) {
            return true;
        }
        return hop instanceof AggUnaryOp && z;
    }

    /**
     * Determines whether a row-template plan of the given hop may be converted to a cell plan.
     *
     * @param hop the hop to check
     * @return false for binary cbind and for matrix multiplications whose output is not a vector
     */
    private static boolean isValidRow2CellOp(Hop hop) {
        if (HopRewriteUtils.isBinary(hop, Types.OpOp2.CBIND)) {
            return false;
        }
        if (!(hop instanceof AggBinaryOp)) {
            return true;
        }
        return hop.getDim1() == 1 || hop.getDim2() == 1;
    }

    /**
     * Prunes invalid and special-case plans of a partition from the memo table in three steps.
     *
     * <ol>
     *   <li>Spark only: remove row-template plans (and all references to them) whose rows do not
     *       fit into a single block, if the hop is executed in Spark or its memory estimate
     *       exceeds the local budget.</li>
     *   <li>Remove row plans of replaceable hops (not collected by collectIrreplaceableRowOps)
     *       whose row entries are exactly matched by cell entries.</li>
     *   <li>For hops with exactly two outer-product plans, remove the one dominated according to
     *       {@link TemplateOuterProduct#dropAlternativePlan}, and drop the exclusion of its
     *       referenced input.</li>
     * </ol>
     *
     * @param cPlanMemoTable memo table to prune (modified in place)
     * @param planPartition  partition whose hops are considered
     */
    private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable cPlanMemoTable, PlanPartition planPartition) {
        // (1) block size constraints: the Spark row template requires rows to fit into one block
        if (OptimizerUtils.isSparkExecutionMode()) {
            for (Long hopId : planPartition.getPartition()) {
                if (!cPlanMemoTable.contains(hopId.longValue(), TemplateBase.TemplateType.ROW)) {
                    continue;
                }
                final Hop hop = cPlanMemoTable.getHopRefs().get(hopId);
                final boolean isSpark = DMLScript.getGlobalExecMode() == Types.ExecMode.SPARK
                    || OptimizerUtils.getTotalMemEstimate(hop.getInput().toArray(new Hop[0]), hop, true) > OptimizerUtils.getLocalMemBudget();
                // Valid when the row length (columns, or rows for a transpose) fits into a block.
                // (Previously inverted with '>', which flagged exactly the valid hops as violations.)
                boolean validNcol = hop.getDataType().isScalar()
                    || (HopRewriteUtils.isTransposeOperation(hop)
                        ? hop.getDim1() <= hop.getBlocksize()
                        : hop.getDim2() <= hop.getBlocksize());
                for (Hop input : hop.getInput()) {
                    validNcol &= input.getDataType().isScalar()
                        || input.getDim2() <= input.getBlocksize()
                        || (hop instanceof AggBinaryOp
                            && input.getDim1() <= input.getBlocksize()
                            && HopRewriteUtils.isTransposeOperation(input));
                }
                if (isSpark && !validNcol) {
                    final List<CPlanMemoTable.MemoTableEntry> rowEntries = cPlanMemoTable.get(hopId.longValue(), TemplateBase.TemplateType.ROW);
                    cPlanMemoTable.remove(hop, TemplateBase.TemplateType.ROW);
                    cPlanMemoTable.removeAllRefTo(hopId.longValue(), TemplateBase.TemplateType.ROW);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed row memo table entries w/ violated blocksize constraint (" + hopId + "): " + Arrays.toString(rowEntries.toArray(new CPlanMemoTable.MemoTableEntry[0])));
                    }
                }
            }
        }

        // (2) replaceable row plans without aggregation, exactly covered by cell plans
        final HashSet<Long> irreplaceable = collectIrreplaceableRowOps(cPlanMemoTable, planPartition);
        for (Long hopId : planPartition.getPartition()) {
            if (irreplaceable.contains(hopId)) {
                continue;
            }
            final CPlanMemoTable.MemoTableEntry best = cPlanMemoTable.getBest(hopId.longValue(), TemplateBase.TemplateType.ROW);
            if (best == null || best.type != TemplateBase.TemplateType.ROW
                || !cPlanMemoTable.hasOnlyExactMatches(hopId.longValue(), TemplateBase.TemplateType.ROW, TemplateBase.TemplateType.CELL)) {
                continue;
            }
            final List<CPlanMemoTable.MemoTableEntry> rowEntries = cPlanMemoTable.get(hopId.longValue(), TemplateBase.TemplateType.ROW);
            cPlanMemoTable.remove(cPlanMemoTable.getHopRefs().get(hopId), new HashSet<CPlanMemoTable.MemoTableEntry>(rowEntries));
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(rowEntries.toArray(new CPlanMemoTable.MemoTableEntry[0])));
            }
        }

        // (3) dominated outer-product plans
        for (Long hopId : planPartition.getPartition()) {
            if (cPlanMemoTable.countEntries(hopId.longValue(), TemplateBase.TemplateType.OUTER) != 2) {
                continue;
            }
            final List<CPlanMemoTable.MemoTableEntry> outerEntries = cPlanMemoTable.get(hopId.longValue(), TemplateBase.TemplateType.OUTER);
            final CPlanMemoTable.MemoTableEntry dominated = TemplateOuterProduct.dropAlternativePlan(cPlanMemoTable, outerEntries.get(0), outerEntries.get(1));
            if (dominated == null) {
                continue;
            }
            cPlanMemoTable.remove(cPlanMemoTable.getHopRefs().get(hopId), Collections.singleton(dominated));
            cPlanMemoTable.getPlansExcludeListed().remove(Long.valueOf(dominated.input(dominated.getPlanRefIndex())));
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed dominated outer product memo table entry: " + dominated);
            }
        }
    }

    /**
     * Recursively removes memo table entries of partition hops that reference a materialization
     * point under the given plan assignment. Outer-product entries are always retained.
     *
     * @param cPlanMemoTable      memo table to prune (modified in place)
     * @param hop                 current hop
     * @param hashSet             IDs of hops already processed (memoization)
     * @param planPartition       partition restricting the pruning
     * @param interestingPointArr interesting points of the partition
     * @param zArr                plan assignment (materialization decisions per interesting point)
     */
    private static void rPruneSuboptimalPlans(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<Long> hashSet, PlanPartition planPartition, InterestingPoint[] interestingPointArr, boolean[] zArr) {
        final long hopId = hop.getHopID();
        if (hashSet.contains(Long.valueOf(hopId))) {
            return;
        }
        final boolean prunable = planPartition.getPartition().contains(Long.valueOf(hopId)) && cPlanMemoTable.contains(hopId);
        if (prunable) {
            for (Iterator<CPlanMemoTable.MemoTableEntry> iter = cPlanMemoTable.get(hopId).iterator(); iter.hasNext(); ) {
                final CPlanMemoTable.MemoTableEntry entry = iter.next();
                final boolean refsMatPoint = !hasNoRefToMatPoint(hopId, entry, interestingPointArr, zArr);
                if (refsMatPoint && entry.type != TemplateBase.TemplateType.OUTER) {
                    iter.remove();
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed memo table entry: " + entry);
                    }
                }
            }
        }
        for (Hop input : hop.getInput()) {
            rPruneSuboptimalPlans(cPlanMemoTable, input, hashSet, planPartition, interestingPointArr, zArr);
        }
        hashSet.add(Long.valueOf(hopId));
    }

    /**
     * Recursively (bottom-up) fixes row-template plans that became invalid after earlier pruning.
     *
     * <p>A row entry is affected if it is a "leaf" (no plan references and no matrix input of the
     * hop), or an "inner" entry whose hop cannot open a row template and none of whose plan
     * references has a row plan. Affected entries are converted to cell plans if the hop permits
     * it (isValidRow2CellOp), and are otherwise removed.
     *
     * @param cPlanMemoTable memo table to fix (modified in place)
     * @param hop            current hop
     * @param hashSet        IDs of hops already processed (memoization)
     * @param planPartition  partition restricting the pruning
     * @param zArr           plan assignment (unused here, forwarded for symmetry)
     */
    private static void rPruneInvalidPlans(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<Long> hashSet, PlanPartition planPartition, boolean[] zArr) {
        final long hopId = hop.getHopID();
        if (hashSet.contains(Long.valueOf(hopId))) {
            return;
        }
        // inputs first, so that their (possibly converted) plans are visible below
        for (Hop input : hop.getInput()) {
            rPruneInvalidPlans(cPlanMemoTable, input, hashSet, planPartition, zArr);
        }
        final boolean hasRowPlans = planPartition.getPartition().contains(Long.valueOf(hopId))
            && cPlanMemoTable.contains(hopId, TemplateBase.TemplateType.ROW);
        if (hasRowPlans) {
            // NOTE(review): removal goes through the iterator of get(hopId, ROW); if that list is a
            // filtered copy rather than a view of the memo table, the removal does not reach the
            // memo table -- verify against CPlanMemoTable.get
            for (Iterator<CPlanMemoTable.MemoTableEntry> iter = cPlanMemoTable.get(hopId, TemplateBase.TemplateType.ROW).iterator(); iter.hasNext(); ) {
                final CPlanMemoTable.MemoTableEntry entry = iter.next();
                final boolean applyLeaf = !entry.hasPlanRef() && !TemplateUtils.hasMatrixInput(hop);
                boolean applyInner = !applyLeaf && !ROW_TPL.open(hop);
                for (int pos = 0; applyInner && pos < 3; pos++) {
                    if (entry.isPlanRef(pos)) {
                        applyInner = !cPlanMemoTable.contains(entry.input(pos), TemplateBase.TemplateType.ROW);
                    }
                }
                if (!applyLeaf && !applyInner) {
                    continue;
                }
                final String kind = applyLeaf ? "leaf" : "inner";
                if (isValidRow2CellOp(hop)) {
                    entry.type = TemplateBase.TemplateType.CELL;
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Converted " + kind + " memo table entry from row to cell: " + entry);
                    }
                } else {
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed " + kind + " memo table entry row (unsupported cell): " + entry);
                    }
                    iter.remove();
                }
            }
        }
        hashSet.add(Long.valueOf(hopId));
    }

    /**
     * Computes the costs of the given plan assignment for a partition by summing the costs of all
     * partition roots, with shared visit memoization.
     *
     * <p>As soon as the accumulated costs reach the bound {@code d}, the method returns positive
     * infinity. The exception is a single-root partition, whose accumulated costs are returned
     * as is.
     *
     * @param cPlanMemoTable      memo table providing the plans
     * @param planPartition       partition to cost
     * @param interestingPointArr interesting points of the partition
     * @param zArr                plan assignment (materialization decisions)
     * @param hashMap             precomputed compute costs per hop ID
     * @param d                   cost bound for early abort
     * @return the plan costs, or {@link Double#POSITIVE_INFINITY} on early abort
     */
    private double getPlanCost(CPlanMemoTable cPlanMemoTable, PlanPartition planPartition, InterestingPoint[] interestingPointArr, boolean[] zArr, HashMap<Long, Double> hashMap, double d) {
        final HashSet<PlanSelection.VisitMarkCost> visited = new HashSet<>();
        double total = 0.0d;
        int remaining = planPartition.getRoots().size();
        for (Long rootId : planPartition.getRoots()) {
            final Hop root = cPlanMemoTable.getHopRefs().get(rootId);
            total += rGetPlanCosts(cPlanMemoTable, root, visited, planPartition, interestingPointArr, zArr, hashMap, null, null, d - total);
            // note: 'remaining' is only decremented once the bound has been reached
            if (total >= d && --remaining > 0) {
                return Double.POSITIVE_INFINITY;
            }
        }
        return total;
    }

    /**
     * Recursively estimates the costs of the sub-DAG rooted at the given hop under the current
     * plan assignment.
     *
     * <p>If no template is open ({@code templateType == null}), the best valid memo entry without
     * references to materialization points opens a new fused operator with a fresh cost vector.
     * Otherwise, the best entry of the open template (or a cell entry) continues the fused
     * operator and accumulates into the given cost vector. Each opened operator costs
     * {@code WRITE + max(READ, COMPUTE)}, with a broadcast read correction when the inputs exceed
     * the local memory budget. Outer-product operators get a sparsity correction. Other operators
     * get a write correction when CP evictions are expected.
     *
     * <p>Visits are memoized per (hop ID, cost vector ID). Multi-aggregate and unfused visits
     * share vector ID -1. Multi-aggregate costs are attributed to the first aggregate root only;
     * the other roots contribute zero.
     *
     * @param cPlanMemoTable      memo table providing the plans
     * @param hop                 current hop
     * @param hashSet             visit marks for memoization
     * @param planPartition       partition being costed
     * @param interestingPointArr interesting points of the partition
     * @param zArr                plan assignment (materialization decisions)
     * @param hashMap             precomputed compute costs per hop ID
     * @param costVector          cost vector of the enclosing fused operator, or null
     * @param templateType        template of the enclosing fused operator, or null
     * @param d                   remaining cost bound for early abort
     * @return the costs (zero if already visited), or positive infinity on early abort
     * @throws RuntimeException if the cost estimate is negative, NaN, or infinite
     */
    private double rGetPlanCosts(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<PlanSelection.VisitMarkCost> hashSet, PlanPartition planPartition, InterestingPoint[] interestingPointArr, boolean[] zArr, HashMap<Long, Double> hashMap, CostVector costVector, TemplateBase.TemplateType templateType, double d) {
        final long hopID = hop.getHopID();
        // memoization per hop and cost vector, to account for redundant computation without
        // double counting materialized intermediates
        if (!hashSet.add(new PlanSelection.VisitMarkCost(hopID, (costVector == null || templateType == TemplateBase.TemplateType.MAGG) ? -1L : costVector.ID))) {
            return 0.0d; // already accounted for
        }

        // select the best plan, either opening a new template or continuing the current one
        CPlanMemoTable.MemoTableEntry best = null;
        final boolean opened = templateType == null;
        if (cPlanMemoTable.contains(hopID)) {
            if (templateType == null) {
                for (CPlanMemoTable.MemoTableEntry me : cPlanMemoTable.get(hopID)) {
                    best = (me.isValid() && hasNoRefToMatPoint(hopID, me, interestingPointArr, zArr)
                        && PlanSelection.BasicPlanComparator.icompare(me, best) < 0) ? me : best;
                }
            } else {
                for (CPlanMemoTable.MemoTableEntry me : cPlanMemoTable.get(hopID)) {
                    best = ((me.type == templateType || me.type == TemplateBase.TemplateType.CELL)
                        && hasNoRefToMatPoint(hopID, me, interestingPointArr, zArr)
                        && PlanSelection.TypedPlanComparator.icompare(me, best, templateType) < 0) ? me : best;
                }
            }
        }

        // an opened operator gets a fresh cost vector initialized with its output size
        final CostVector costVect = !opened ? costVector : new CostVector(getSize(hop));
        double costs = 0.0d;

        // multi-aggregates: account shared costs once, at the first aggregate root
        if (opened && best != null && best.type == TemplateBase.TemplateType.MAGG) {
            if (best.input1 != hopID) {
                return 0.0d; // costs are attributed to the first root
            }
            for (int i = 1; i < 3; i++) {
                if (best.isPlanRef(i)) {
                    costs += rGetPlanCosts(cPlanMemoTable, cPlanMemoTable.getHopRefs().get(Long.valueOf(best.input(i))), hashSet, planPartition, interestingPointArr, zArr, hashMap, costVect, TemplateBase.TemplateType.MAGG, d - costs);
                    if (costs >= d) {
                        return Double.POSITIVE_INFINITY;
                    }
                }
            }
        }

        if (hashMap.containsKey(Long.valueOf(hopID))) {
            costVect.computeCosts += hashMap.get(Long.valueOf(hopID)).doubleValue();
        }

        // process inputs: fused references, implicitly fused transposes, or regular inputs
        for (int i = 0; i < hop.getInput().size(); i++) {
            final Hop input = hop.getInput().get(i);
            if (best != null && best.isPlanRef(i)) {
                costs += rGetPlanCosts(cPlanMemoTable, input, hashSet, planPartition, interestingPointArr, zArr, hashMap, costVect, best.type, d - costs);
            } else if (best == null || !isImplicitlyFused(hop, i, best.type)) {
                if (planPartition.getPartition().contains(Long.valueOf(input.getHopID()))) {
                    costs += rGetPlanCosts(cPlanMemoTable, input, hashSet, planPartition, interestingPointArr, zArr, hashMap, null, null, d - costs);
                }
                if (costVect != null && input.getDataType().isMatrix()) {
                    costVect.addInputSize(input.getHopID(), getSize(input));
                }
            } else {
                // implicitly fused transpose: read its input directly
                costVect.addInputSize(input.getInput().get(0).getHopID(), getSize(input));
            }
            if (costs >= d) {
                return Double.POSITIVE_INFINITY;
            }
        }

        if (opened) {
            // WRITE + max(READ, COMPUTE)
            final double memInputs = sumInputMemoryEstimates(cPlanMemoTable, costVect);
            double tmpCosts = ((costVect.outSize * 8.0d) / WRITE_BANDWIDTH_MEM)
                + Math.max(memInputs / READ_BANDWIDTH_MEM, costVect.computeCosts / COMPUTE_BANDWIDTH);
            // read correction for distributed computation
            if (memInputs > OptimizerUtils.getLocalMemBudget()) {
                tmpCosts += (costVect.getSideInputSize() * 8.0d) / READ_BANDWIDTH_BROADCAST;
            }
            final Hop driver = cPlanMemoTable.getHopRefs().get(Long.valueOf(costVect.getMaxInputSizeHopID()));
            if (best != null && best.type == TemplateBase.TemplateType.OUTER) {
                // sparsity correction for outer-product templates
                tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
            } else if (memInputs <= OptimizerUtils.getLocalMemBudget()
                && sumTmpInputOutputSize(cPlanMemoTable, costVect) * 8.0d > LazyWriteBuffer.getWriteBufferLimit()) {
                // write correction for expected evictions in CP
                tmpCosts += (costVect.outSize * 8.0d) / WRITE_BANDWIDTH_IO;
            }
            costs += tmpCosts;
            if (LOG.isTraceEnabled()) {
                LOG.trace("Cost vector (" + (best != null ? best.type.name() : "HOP") + " " + hopID + "): " + costVect + " -> " + tmpCosts);
            }
        } else if (planPartition.getExtConsumed().contains(Long.valueOf(hop.getHopID()))) {
            // externally consumed intermediate inside a fused operator is also computed separately
            costs += rGetPlanCosts(cPlanMemoTable, hop, hashSet, planPartition, interestingPointArr, zArr, hashMap, null, null, d - costs);
            if (costs >= d) {
                return Double.POSITIVE_INFINITY;
            }
        }

        // sanity check: non-negative, finite costs
        if (costs < 0.0d || Double.isNaN(costs) || Double.isInfinite(costs)) {
            throw new RuntimeException("Wrong cost estimate: " + costs);
        }
        return costs;
    }

    /**
     * Estimates the compute costs of the given hop and stores them in {@code hashMap} under the
     * hop's ID.
     *
     * <p>Each operation gets a relative per-cell weight (default 1). Aggregations and matrix
     * multiplications additionally scale the weight by the size of the aggregated dimension. The
     * result is multiplied by the hop's output size ({@code getSize(hop)}). Operations without a
     * cost entry log a warning and keep the current weight.
     *
     * @param hop     hop to estimate
     * @param hashMap map from hop ID to estimated compute costs (updated in place)
     */
    private static void getComputeCosts(Hop hop, HashMap<Long, Double> hashMap) {
        double d = 1.0d;
        if (hop instanceof UnaryOp) {
            switch (((UnaryOp) hop).getOp()) {
                case ABS:
                case ROUND:
                case CEIL:
                case FLOOR:
                case SIGN:
                    d = 1.0d;
                    break;
                case SPROP:
                case SQRT:
                    d = 2.0d;
                    break;
                case EXP:
                    d = 18.0d;
                    break;
                case SIGMOID:
                    d = 21.0d;
                    break;
                case LOG:
                case LOG_NZ:
                    d = 32.0d;
                    break;
                case NCOL:
                case NROW:
                case PRINT:
                case ASSERT:
                case CAST_AS_BOOLEAN:
                case CAST_AS_DOUBLE:
                case CAST_AS_INT:
                case CAST_AS_MATRIX:
                case CAST_AS_SCALAR:
                    d = 1.0d;
                    break;
                case SIN:
                    d = 18.0d;
                    break;
                case COS:
                    d = 22.0d;
                    break;
                case TAN:
                    d = 42.0d;
                    break;
                case ASIN:
                    d = 93.0d;
                    break;
                case ACOS:
                    d = 103.0d;
                    break;
                case ATAN:
                    d = 40.0d;
                    break;
                case SINH:
                    d = 93.0d;
                    break;
                case COSH:
                    d = 103.0d;
                    break;
                case TANH:
                    d = 40.0d;
                    break;
                case CUMSUM:
                case CUMMIN:
                case CUMMAX:
                case CUMPROD:
                    d = 1.0d;
                    break;
                case CUMSUMPROD:
                    d = 2.0d;
                    break;
                default:
                    LOG.warn("Cost model not implemented yet for: " + ((UnaryOp) hop).getOp());
                    break;
            }
        } else if (hop instanceof BinaryOp) {
            switch (((BinaryOp) hop).getOp()) {
                case MULT:
                case PLUS:
                case MINUS:
                case MIN:
                case MAX:
                case AND:
                case OR:
                case EQUAL:
                case NOTEQUAL:
                case LESS:
                case LESSEQUAL:
                case GREATER:
                case GREATEREQUAL:
                case CBIND:
                case RBIND:
                    d = 1.0d;
                    break;
                case INTDIV:
                    d = 6.0d;
                    break;
                case MODULUS:
                    d = 8.0d;
                    break;
                case DIV:
                    d = 22.0d;
                    break;
                case LOG:
                case LOG_NZ:
                    d = 32.0d;
                    break;
                case POW:
                    d = HopRewriteUtils.isLiteralOfValue(hop.getInput().get(1), 2.0d) ? 1 : 16;
                    break;
                case MINUS_NZ:
                case MINUS1_MULT:
                    d = 2.0d;
                    break;
                case MOMENT:
                    // cost depends on the moment order (second input); non-literal orders default to 2
                    switch ((int) (hop.getInput().get(1) instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) hop.getInput().get(1)) : 2L)) {
                        case 0:
                            d = 1.0d;
                            break;
                        case 1:
                            d = 8.0d;
                            break;
                        case 2:
                            d = 16.0d;
                            break;
                        case 3:
                            d = 31.0d;
                            break;
                        case 4:
                            d = 51.0d;
                            break;
                        case 5:
                            d = 16.0d;
                            break;
                    }
                    break; // must not fall through into COV, which would overwrite the moment cost
                case COV:
                    d = 23.0d;
                    break;
                default:
                    LOG.warn("Cost model not implemented yet for: " + ((BinaryOp) hop).getOp());
                    break;
            }
        } else if (hop instanceof TernaryOp) {
            switch (((TernaryOp) hop).getOp()) {
                case IFELSE:
                case PLUS_MULT:
                case MINUS_MULT:
                    d = 2.0d;
                    break;
                case CTABLE:
                    d = 3.0d;
                    break;
                case MOMENT:
                    // cost depends on the moment order (second input); non-literal orders default to 2
                    switch ((int) (hop.getInput().get(1) instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) hop.getInput().get(1)) : 2L)) {
                        case 0:
                            d = 2.0d;
                            break;
                        case 1:
                            d = 9.0d;
                            break;
                        case 2:
                            d = 17.0d;
                            break;
                        case 3:
                            d = 32.0d;
                            break;
                        case 4:
                            d = 52.0d;
                            break;
                        case 5:
                            d = 17.0d;
                            break;
                    }
                    break; // must not fall through into COV, which would overwrite the moment cost
                case COV:
                    d = 23.0d;
                    break;
                default:
                    LOG.warn("Cost model not implemented yet for: " + ((TernaryOp) hop).getOp());
                    break;
            }
        } else if (hop instanceof NaryOp) {
            d = HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) ? hop.getInput().size() : 1.0d;
        } else if (hop instanceof ParameterizedBuiltinOp || hop instanceof IndexingOp || hop instanceof ReorgOp) {
            d = 1.0d;
        } else if (hop instanceof DnnOp) {
            switch (((DnnOp) hop).getOp()) {
                case BIASADD:
                case BIASMULT:
                    d = 2.0d;
                    break;
                default:
                    // warn only for operations without a cost entry
                    LOG.warn("Cost model not implemented yet for: " + ((DnnOp) hop).getOp());
                    break;
            }
        } else if (hop instanceof AggBinaryOp) {
            // two operations (multiply, add) per element of the common dimension,
            // scaled by the sparsity of the left input if known
            d = 2 * hop.getInput().get(0).getDim2();
            if (hop.getInput().get(0).dimsKnown(true)) {
                d *= hop.getInput().get(0).getSparsity();
            }
        } else if (hop instanceof AggUnaryOp) {
            switch (((AggUnaryOp) hop).getOp()) {
                case SUM:
                    d = 4.0d;
                    break;
                case SUM_SQ:
                    d = 5.0d;
                    break;
                case MIN:
                case MAX:
                    d = 1.0d;
                    break;
                default:
                    LOG.warn("Cost model not implemented yet for: " + ((AggUnaryOp) hop).getOp());
                    break;
            }
            // scale by the number of aggregated cells per output cell
            switch (((AggUnaryOp) hop).getDirection()) {
                case Col:
                    d *= Math.max(hop.getInput().get(0).getDim1(), 1L);
                    break;
                case Row:
                    d *= Math.max(hop.getInput().get(0).getDim2(), 1L);
                    break;
                case RowCol:
                    d *= getSize(hop.getInput().get(0));
                    break;
            }
        }
        // scale by the output size to reflect mixes of row and cell operations in one operator
        hashMap.put(Long.valueOf(hop.getHopID()), Double.valueOf(d * getSize(hop)));
    }

    /**
     * Checks whether the given memo entry of hop {@code j} does not reference a materialization
     * point under the given plan assignment.
     *
     * @param j                   hop ID of the entry's hop
     * @param memoTableEntry      memo table entry to check
     * @param interestingPointArr interesting points of the partition
     * @param zArr                plan assignment (materialization decisions)
     * @return the negation of {@code InterestingPoint.isMatPoint(...)}
     */
    private static boolean hasNoRefToMatPoint(long j, CPlanMemoTable.MemoTableEntry memoTableEntry, InterestingPoint[] interestingPointArr, boolean[] zArr) {
        final boolean refsMatPoint = InterestingPoint.isMatPoint(interestingPointArr, j, memoTableEntry, zArr);
        return !refsMatPoint;
    }

    /**
     * Determines whether input {@code i} of the given hop is implicitly fused.
     *
     * <p>This applies to the transposed left input (position 0) of a matrix multiplication in a
     * row template.
     *
     * @param hop          consuming hop
     * @param i            input position
     * @param templateType template type of the consuming plan
     * @return true if the input is implicitly fused
     */
    private static boolean isImplicitlyFused(Hop hop, int i, TemplateBase.TemplateType templateType) {
        if (templateType != TemplateBase.TemplateType.ROW || i != 0) {
            return false;
        }
        return HopRewriteUtils.isMatrixMultiply(hop)
            && HopRewriteUtils.isTransposeOperation(hop.getInput().get(i));
    }

    /**
     * Decides whether the plan cache should be consulted for a partition with the given
     * interesting points.
     *
     * @param interestingPointArr interesting points of the partition
     * @return true if there are at least 10 interesting points
     */
    private static boolean probePlanCache(InterestingPoint[] interestingPointArr) {
        // NOTE(review): presumably only partitions with many interesting points are expensive
        // enough to enumerate that a cache lookup pays off -- confirm the threshold
        final int minInterestingPoints = 10;
        return minInterestingPoints <= interestingPointArr.length;
    }

    /**
     * Looks up a cached plan assignment for the given partition signature.
     *
     * <p>Access to the cache is synchronized on the cache object. If statistics are enabled, the
     * lookup is counted, and a non-null result is also counted as a hit.
     *
     * @param partitionSignature signature of the partition
     * @return the cached plan assignment, or null if absent
     */
    private static boolean[] getPlan(PartitionSignature partitionSignature) {
        final boolean[] cached;
        synchronized (_planCache) {
            cached = _planCache.get(partitionSignature);
        }
        if (!DMLScript.STATISTICS) {
            return cached;
        }
        if (cached != null) {
            Statistics.incrementCodegenPlanCacheHits();
        }
        Statistics.incrementCodegenPlanCacheTotal();
        return cached;
    }

    /**
     * Stores a plan assignment for the given partition signature in the plan cache.
     *
     * <p>Access is synchronized on the cache object. When the cache holds 1024 or more entries,
     * the first entry in the map's iteration order is evicted before the insertion.
     *
     * @param partitionSignature signature of the partition
     * @param zArr               plan assignment to cache
     */
    private static void putPlan(PartitionSignature partitionSignature, boolean[] zArr) {
        synchronized (_planCache) {
            final boolean full = _planCache.size() >= 1024;
            if (full) {
                final Iterator<PartitionSignature> keys = _planCache.keySet().iterator();
                keys.next();
                keys.remove();
            }
            _planCache.put(partitionSignature, zArr);
        }
    }
}
