package org.apache.sysds.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.opt.PlanSelection;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
import org.apache.sysds.hops.codegen.template.TemplateRow;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;

/* loaded from: input_file:org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBased.class */
public class PlanSelectionFuseCostBased extends PlanSelection {
    // Cost-model throughput constants; rGetPlanCosts divides (cells * 8) by them,
    // so they are presumably bytes per second: 32 * 1024^3 for reads and
    // 2 * 1024^3 for writes.
    private static final double READ_BANDWIDTH = 32d * 1024 * 1024 * 1024;
    private static final double WRITE_BANDWIDTH = 2d * 1024 * 1024 * 1024;

    private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBased.class.getName());

    // Compute throughput: the 2 * 1024^3 base rate scaled by the local degree of parallelism.
    private static final double COMPUTE_BANDWIDTH = WRITE_BANDWIDTH * InfrastructureAnalyzer.getLocalParallelism();

    // Source of unique CostVector IDs (part of the memoization key in rGetPlanCosts).
    private static final IDSequence COST_ID = new IDSequence();

    // Row template used to test whether a hop can open a row template (see rPruneInvalidPlans).
    private static final TemplateRow ROW_TPL = new TemplateRow();

    /**
     * Bookkeeping for a group of full aggregations that are candidates for one
     * multi-aggregate (MAGG) fused operator.
     */
    public static class AggregateInfo {
        /** IDs of multi-aggregate roots found among the (transitive) inputs of this group. */
        public final HashSet<Long> _inputAggs = new HashSet<>();
        /** IDs of inputs read by the fused operators of this group. */
        public final HashSet<Long> _fusedInputs = new HashSet<>();
        /** The aggregation hops of this group, keyed by hop ID. */
        public final HashMap<Long, Hop> _aggregates = new HashMap<>();

        /** Creates a group consisting of the single aggregation {@code hop}. */
        public AggregateInfo(Hop hop) {
            _aggregates.put(hop.getHopID(), hop);
        }

        /** Records hop {@code j} as a multi-aggregate root consumed by this group. */
        public void addInputAggregate(long j) {
            _inputAggs.add(j);
        }

        /** Records hop {@code j} as an input read by this group's fused operators. */
        public void addFusedInput(long j) {
            _fusedInputs.add(j);
        }

        /**
         * Checks whether this group can be merged with {@code aggregateInfo}: the
         * union has at most three aggregates, neither group consumes an aggregate
         * of the other, both share at least one fused input, and their
         * representative main inputs have equal size.
         */
        public boolean isMergable(AggregateInfo aggregateInfo) {
            int ownCount = _aggregates.size();
            boolean mergable = ownCount < 3 && ownCount + aggregateInfo._aggregates.size() <= 3;
            // independence: no aggregate of one group may be an input of the other
            for (Long otherId : aggregateInfo._aggregates.keySet())
                mergable &= !_inputAggs.contains(otherId);
            for (Long ownId : _aggregates.keySet())
                mergable &= !aggregateInfo._inputAggs.contains(ownId);
            // partially shared reads
            mergable &= CollectionUtils.containsAny(_fusedInputs, aggregateInfo._fusedInputs);

            Hop mine = _aggregates.values().iterator().next();
            Hop theirs = aggregateInfo._aggregates.values().iterator().next();
            if (!mergable)
                return false;
            // consistent sizes; for matrix multiplies compare the second input
            Hop mineIn = mine.getInput().get(HopRewriteUtils.isMatrixMultiply(mine) ? 1 : 0);
            Hop theirsIn = theirs.getInput().get(HopRewriteUtils.isMatrixMultiply(theirs) ? 1 : 0);
            return HopRewriteUtils.isEqualSize(mineIn, theirsIn);
        }

        /** Adds all aggregates, input aggregates and fused inputs of {@code aggregateInfo} to this group. */
        public AggregateInfo merge(AggregateInfo aggregateInfo) {
            _aggregates.putAll(aggregateInfo._aggregates);
            _inputAggs.addAll(aggregateInfo._inputAggs);
            _fusedInputs.addAll(aggregateInfo._fusedInputs);
            return this;
        }

        @Override
        public String toString() {
            String aggIds = Arrays.toString(_aggregates.keySet().toArray(new Long[0]));
            String inputAggIds = Arrays.toString(_inputAggs.toArray(new Long[0]));
            String fusedIds = Arrays.toString(_fusedInputs.toArray(new Long[0]));
            return "[" + aggIds + ": {" + inputAggIds + "},{" + fusedIds + "}]";
        }
    }

    /**
     * Cost summary of one fused operator during plan costing: the number of
     * output cells, the accumulated per-cell compute costs, and the sizes (in
     * cells) of its distinct inputs keyed by input hop ID.
     */
    public static class CostVector {
        /** Number of output cells (rGetPlanCosts passes max(rows,1) * max(cols,1)). */
        public final double outSize;
        /** Sum of the per-cell compute costs of all operators accounted to this vector. */
        public double computeCosts = 0;
        /** Input sizes in cells by input hop ID; re-adding an ID overwrites its size. */
        public final HashMap<Long, Double> inSizes = new HashMap<>();
        /** Unique ID, used to memoize costing per (hop, cost vector) pair. */
        public final long ID = PlanSelectionFuseCostBased.COST_ID.getNextID();

        /** @param d number of output cells */
        public CostVector(double d) {
            this.outSize = d;
        }

        /**
         * Records the size of input {@code j}; a repeated input is counted once.
         *
         * @param j input hop ID
         * @param d input size in cells
         */
        public void addInputSize(long j, double d) {
            inSizes.put(j, d);
        }

        /** @return the sum of all recorded input sizes (0 if there are none) */
        public double getSumInputSizes() {
            return inSizes.values().stream().mapToDouble(Double::doubleValue).sum();
        }

        /** @return the largest recorded input size, or 0 if there are no inputs */
        public double getMaxInputSize() {
            return inSizes.values().stream().mapToDouble(Double::doubleValue).max().orElse(0);
        }

        @Override
        public String toString() {
            return "[" + this.outSize + ", " + this.computeCosts + ", {" + Arrays.toString(this.inSizes.keySet().toArray(new Long[0])) + ", " + Arrays.toString(this.inSizes.values().toArray(new Double[0])) + "}]";
        }
    }

    /**
     * Selects fusion plans for the DAG rooted at {@code arrayList}.
     *
     * <p>For every plan partition, multi-aggregate plans are added within the
     * partition and cost-based selection is run. Afterwards, multi-aggregate
     * plans across partitions are added, and the distinct best plans are
     * installed into the memo table.
     *
     * @param cPlanMemoTable memo table of candidate plans (modified in place)
     * @param arrayList      DAG roots
     */
    @Override // org.apache.sysds.hops.codegen.opt.PlanSelection
    public void selectPlans(CPlanMemoTable cPlanMemoTable, ArrayList<Hop> arrayList) {
        int totalMatPoints = 0;
        for (PlanPartition part : PlanAnalyzer.analyzePlanPartitions(cPlanMemoTable, arrayList, false)) {
            // composite templates within the partition, then plan enumeration/selection
            createAndAddMultiAggPlans(cPlanMemoTable, part.getPartition(), part.getRoots());
            selectPlans(cPlanMemoTable, part.getPartition(), part.getRoots(), part.getMatPoints());
            totalMatPoints += part.getMatPoints().size();
        }

        // composite templates across partitions
        createAndAddMultiAggPlans(cPlanMemoTable, arrayList);

        getBestPlans().forEach((hopId, plans) -> cPlanMemoTable.setDistinct(hopId, plans));

        if (DMLScript.STATISTICS)
            Statistics.incrementCodegenEnumAll(UtilFunctions.pow(2, totalMatPoints));
    }

    /**
     * Adds multi-aggregate (MAGG) plans within a single plan partition.
     *
     * <p>Candidates are partition roots that are multi-aggregate roots and are
     * not an input of any hop that has memo entries (this avoids circular
     * dependencies). In the iteration order of the root set, candidates are cut
     * into chunks of at most three. Every chunk of two or three candidates
     * becomes one MAGG entry, which is added to each member if
     * {@code isValidMultiAggregate} accepts it.
     *
     * @param cPlanMemoTable memo table to extend
     * @param hashSet        hop IDs of the partition (not used by this method)
     * @param hashSet2       hop IDs of the partition roots
     */
    private static void createAndAddMultiAggPlans(CPlanMemoTable cPlanMemoTable, HashSet<Long> hashSet, HashSet<Long> hashSet2) {
        // inputs of all hops with non-empty plan lists
        HashSet<Long> referenced = new HashSet<>();
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> plans : cPlanMemoTable.getPlans().entrySet()) {
            if (plans.getValue().isEmpty())
                continue;
            for (Hop input : cPlanMemoTable.getHopRefs().get(plans.getKey()).getInput())
                referenced.add(input.getHopID());
        }

        ArrayList<Long> fullAggs = new ArrayList<>();
        for (Long rootId : hashSet2) {
            Hop root = cPlanMemoTable.getHopRefs().get(rootId);
            if (!referenced.contains(rootId) && isMultiAggregateRoot(root))
                fullAggs.add(rootId);
        }
        if (LOG.isTraceEnabled())
            LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));

        for (int from = 0; from < fullAggs.size(); from += 3) {
            int to = Math.min(from + 3, fullAggs.size());
            int count = to - from;
            if (count < 2)
                continue; // a single remaining aggregate cannot form a multi-aggregate
            long first = fullAggs.get(from).longValue();
            long second = fullAggs.get(from + 1).longValue();
            long third = count == 3 ? fullAggs.get(from + 2).longValue() : -1L;
            CPlanMemoTable.MemoTableEntry magg = new CPlanMemoTable.MemoTableEntry(TemplateBase.TemplateType.MAGG, first, second, third, count);
            if (!isValidMultiAggregate(cPlanMemoTable, magg)) {
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed invalid multiagg plan: " + magg);
                continue;
            }
            for (int k = from; k < to; k++) {
                Long aggId = fullAggs.get(k);
                cPlanMemoTable.add(cPlanMemoTable.getHopRefs().get(aggId), magg);
                if (LOG.isTraceEnabled())
                    LOG.trace("Added multiagg plan: " + aggId + " " + magg);
            }
        }
    }

    /**
     * Adds multi-aggregate (MAGG) plans across plan partitions.
     *
     * <p>The method proceeds in four steps:
     * <ol>
     * <li>Collect all multi-aggregate roots reachable from {@code arrayList} that
     * do not yet have a MAGG plan.</li>
     * <li>Extract their input-aggregate and fused-input information.</li>
     * <li>Greedily merge compatible groups (see {@link AggregateInfo#isMergable})
     * until no merge applies.</li>
     * <li>For every group of two or three aggregates, add one MAGG entry to the
     * memo table and to the best plans.</li>
     * </ol>
     *
     * @param cPlanMemoTable memo table to extend
     * @param arrayList      DAG roots; hop visit status is reset before and after collection
     */
    private void createAndAddMultiAggPlans(CPlanMemoTable cPlanMemoTable, ArrayList<Hop> arrayList) {
        HashSet<Long> fullAggs = new HashSet<>();
        Hop.resetVisitStatus(arrayList);
        for (Hop root : arrayList)
            rCollectFullAggregates(root, fullAggs);
        Hop.resetVisitStatus(arrayList);

        // skip aggregations that already received a multi-aggregate plan
        fullAggs.removeIf(id -> cPlanMemoTable.contains(id.longValue(), TemplateBase.TemplateType.MAGG));
        if (fullAggs.size() <= 1)
            return;
        if (LOG.isTraceEnabled())
            LOG.trace("Found across-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));

        List<AggregateInfo> aggInfos = new ArrayList<>();
        for (Long hopId : fullAggs) {
            Hop aggHop = cPlanMemoTable.getHopRefs().get(hopId);
            AggregateInfo info = new AggregateInfo(aggHop);
            boolean isMatMult = HopRewriteUtils.isMatrixMultiply(aggHop);
            for (int i = 0; i < aggHop.getInput().size(); i++) {
                // for a matrix multiply, the first input is a transpose (see
                // isMultiAggregateRoot); analyze the transpose's input instead
                Hop input = (isMatMult && i == 0) ? aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
                rExtractAggregateInfo(cPlanMemoTable, input, info, TemplateBase.TemplateType.CELL);
            }
            if (info._fusedInputs.isEmpty()) {
                if (isMatMult) {
                    info.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
                    info.addFusedInput(aggHop.getInput().get(1).getHopID());
                } else {
                    info.addFusedInput(aggHop.getInput().get(0).getHopID());
                }
            }
            aggInfos.add(info);
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
            for (AggregateInfo info : aggInfos)
                LOG.trace(info);
        }

        // Stable in-place sort by number of dependent aggregates. The list must stay
        // mutable because merged entries are removed below; Collectors.toList()
        // gives no such guarantee.
        aggInfos.sort(Comparator.comparingInt((AggregateInfo a) -> a._inputAggs.size()));

        // greedy grouping until a full pass performs no merge
        boolean mergedAny;
        do {
            mergedAny = false;
            for (int i = 0; i < aggInfos.size(); i++) {
                AggregateInfo current = aggInfos.get(i);
                int j = i + 1;
                while (j < aggInfos.size()) {
                    AggregateInfo that = aggInfos.get(j);
                    if (current.isMergable(that)) {
                        current.merge(that);
                        aggInfos.remove(j); // the next candidate shifts into slot j
                        mergedAny = true;
                    } else {
                        j++;
                    }
                }
            }
        } while (mergedAny);

        if (LOG.isTraceEnabled()) {
            LOG.trace("Merged across-partition ua(RC) aggregation info: ");
            for (AggregateInfo info : aggInfos)
                LOG.trace(info);
        }

        for (AggregateInfo info : aggInfos) {
            if (info._aggregates.size() <= 1)
                continue;
            Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
            CPlanMemoTable.MemoTableEntry magg = new CPlanMemoTable.MemoTableEntry(TemplateBase.TemplateType.MAGG,
                aggs[0].longValue(), aggs[1].longValue(), aggs.length > 2 ? aggs[2].longValue() : -1L, aggs.length);
            for (Long aggId : aggs) {
                cPlanMemoTable.add(cPlanMemoTable.getHopRefs().get(aggId), magg);
                addBestPlan(aggId.longValue(), magg);
                if (LOG.isTraceEnabled())
                    LOG.trace("Added multiagg* plan: " + aggId + " " + magg);
            }
        }
    }

    /**
     * Checks whether {@code hop} qualifies as the root of a multi-aggregate.
     * It qualifies if it is either:
     * <ul>
     * <li>a full (row-and-column) sum, sum-of-squares, min or max aggregation, or</li>
     * <li>a 1x1 {@code AggBinaryOp} whose first input is a transpose.</li>
     * </ul>
     */
    private static boolean isMultiAggregateRoot(Hop hop) {
        boolean fullUnaryAgg = HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.SUM, Types.AggOp.SUM_SQ, Types.AggOp.MIN, Types.AggOp.MAX)
            && ((AggUnaryOp) hop).getDirection() == Types.Direction.RowCol;
        if (fullUnaryAgg)
            return true;
        boolean scalarAggBinary = hop instanceof AggBinaryOp && hop.getDim1() == 1 && hop.getDim2() == 1;
        return scalarAggBinary && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0));
    }

    /**
     * Validates a multi-aggregate memo entry. The entry is valid only if both
     * of the following hold:
     * <ul>
     * <li>all referenced aggregates have first inputs of equal size as the
     * first aggregate's;</li>
     * <li>no referenced aggregate can reach any of the other entry inputs
     * through its input DAG.</li>
     * </ul>
     *
     * @param cPlanMemoTable memo table providing hop references
     * @param memoTableEntry candidate MAGG entry
     * @return true if the entry passes both checks
     */
    private static boolean isValidMultiAggregate(CPlanMemoTable cPlanMemoTable, CPlanMemoTable.MemoTableEntry memoTableEntry) {
        // consistent input sizes (otherwise potential for incorrect results)
        Hop reference = cPlanMemoTable.getHopRefs().get(memoTableEntry.input1).getInput().get(0);
        boolean valid = true;
        int pos = 1;
        while (valid && pos < 3) {
            if (memoTableEntry.isPlanRef(pos)) {
                Hop other = cPlanMemoTable.getHopRefs().get(memoTableEntry.input(pos)).getInput().get(0);
                valid = HopRewriteUtils.isEqualSize(reference, other);
            }
            pos++;
        }

        // independence: no aggregate may (transitively) consume another one
        pos = 0;
        while (valid && pos < 3) {
            if (memoTableEntry.isPlanRef(pos)) {
                HashSet<Long> others = new HashSet<>();
                for (int k = 0; k < 3; k++) {
                    if (k != pos)
                        others.add(memoTableEntry.input(k));
                }
                valid = rCheckMultiAggregate(cPlanMemoTable.getHopRefs().get(memoTableEntry.input(pos)), others);
            }
            pos++;
        }
        return valid;
    }

    /**
     * Checks that no hop reachable from {@code hop} (including {@code hop}
     * itself) has an ID contained in {@code hashSet}.
     *
     * <p>A visited set ensures that shared sub-DAGs are traversed only once, and
     * the search stops at the first hit. Without memoization, a shared input
     * would be revisited once per path, which is exponential in the worst case.
     *
     * @param hop     root of the DAG to probe
     * @param hashSet hop IDs that must not be reachable
     * @return true if none of the probed IDs is reachable from {@code hop}
     */
    private static boolean rCheckMultiAggregate(Hop hop, HashSet<Long> hashSet) {
        return rCheckMultiAggregate(hop, hashSet, new HashSet<>());
    }

    /** Depth-first helper of {@link #rCheckMultiAggregate(Hop, HashSet)} that memoizes visited hop IDs. */
    private static boolean rCheckMultiAggregate(Hop hop, HashSet<Long> probe, HashSet<Long> visited) {
        long id = hop.getHopID();
        if (probe.contains(id))
            return false;
        // assumes an acyclic hop graph: a previously visited hop was fully checked
        // without a hit (a hit would have aborted the whole traversal)
        if (!visited.add(id))
            return true;
        for (Hop input : hop.getInput()) {
            if (!rCheckMultiAggregate(input, probe, visited))
                return false;
        }
        return true;
    }

    /**
     * Adds to {@code hashSet} the IDs of all multi-aggregate roots reachable from
     * {@code hop}. The traversal uses and sets the hop visit status, so callers
     * reset it before and after (as createAndAddMultiAggPlans does).
     */
    private static void rCollectFullAggregates(Hop hop, HashSet<Long> hashSet) {
        if (!hop.isVisited()) {
            if (isMultiAggregateRoot(hop))
                hashSet.add(hop.getHopID());
            for (Hop input : hop.getInput())
                rCollectFullAggregates(input, hashSet);
            hop.setVisited();
        }
    }

    /**
     * Recursively fills {@code aggregateInfo} starting at {@code hop}.
     *
     * <p>Every reachable multi-aggregate root is recorded as an input aggregate.
     * While inside a fused operator ({@code templateType != null}), inputs that
     * the hop's best plan references keep the template type. All other inputs
     * are visited with a {@code null} type, i.e. outside the fused operator;
     * matrix inputs among them are recorded as fused inputs.
     */
    private static void rExtractAggregateInfo(CPlanMemoTable cPlanMemoTable, Hop hop, AggregateInfo aggregateInfo, TemplateBase.TemplateType templateType) {
        if (isMultiAggregateRoot(hop))
            aggregateInfo.addInputAggregate(hop.getHopID());

        boolean insideFusedOp = templateType != null;
        CPlanMemoTable.MemoTableEntry best = insideFusedOp ? cPlanMemoTable.getBest(hop.getHopID()) : null;
        List<Hop> inputs = hop.getInput();
        for (int pos = 0; pos < inputs.size(); pos++) {
            Hop input = inputs.get(pos);
            if (best != null && best.isPlanRef(pos)) {
                rExtractAggregateInfo(cPlanMemoTable, input, aggregateInfo, templateType);
                continue;
            }
            if (insideFusedOp && input.getDataType().isMatrix())
                aggregateInfo.addFusedInput(input.getHopID());
            rExtractAggregateInfo(cPlanMemoTable, input, aggregateInfo, null);
        }
    }

    /**
     * Runs cost-based plan selection for one plan partition.
     *
     * <p>The method proceeds in three steps:
     * <ol>
     * <li>Prune two kinds of plans: row-template plans of roots that contain no
     * aggregation below the root (when a cell alternative exists), and
     * dominated outer-product plans.</li>
     * <li>If there are materialization points, exhaustively cost all 2^|M|
     * assignments. Then prune the memo table with respect to the cheapest
     * assignment.</li>
     * <li>Apply fuse-all selection to every root.</li>
     * </ol>
     *
     * @param cPlanMemoTable memo table of candidate plans (modified in place)
     * @param hashSet        hop IDs of the partition
     * @param hashSet2       hop IDs of the partition roots
     * @param arrayList      hop IDs of the materialization points (may be null or empty)
     */
    private void selectPlans(CPlanMemoTable cPlanMemoTable, HashSet<Long> hashSet, HashSet<Long> hashSet2, ArrayList<Long> arrayList) {
        for (Long rootId : hashSet2) {
            long id = rootId.longValue();
            Hop root = cPlanMemoTable.getHopRefs().get(rootId);
            boolean rowWithoutAgg = cPlanMemoTable.getBest(id, TemplateBase.TemplateType.ROW).type == TemplateBase.TemplateType.ROW
                && cPlanMemoTable.contains(id, TemplateBase.TemplateType.CELL)
                && isRowTemplateWithoutAgg(cPlanMemoTable, root, new HashSet<>());
            if (!rowWithoutAgg)
                continue;
            List<CPlanMemoTable.MemoTableEntry> rowPlans = cPlanMemoTable.get(id, TemplateBase.TemplateType.ROW);
            cPlanMemoTable.remove(root, new HashSet<>(rowPlans));
            if (LOG.isTraceEnabled())
                LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(rowPlans.toArray(new CPlanMemoTable.MemoTableEntry[0])));
        }

        for (Long nodeId : hashSet) {
            long id = nodeId.longValue();
            if (cPlanMemoTable.countEntries(id, TemplateBase.TemplateType.OUTER) != 2)
                continue;
            List<CPlanMemoTable.MemoTableEntry> outerPlans = cPlanMemoTable.get(id, TemplateBase.TemplateType.OUTER);
            CPlanMemoTable.MemoTableEntry dominated = TemplateOuterProduct.dropAlternativePlan(cPlanMemoTable, outerPlans.get(0), outerPlans.get(1));
            if (dominated == null)
                continue;
            cPlanMemoTable.remove(cPlanMemoTable.getHopRefs().get(nodeId), Collections.singleton(dominated));
            cPlanMemoTable.getPlansExcludeListed().remove(Long.valueOf(dominated.input(dominated.getPlanRefIndex())));
            if (LOG.isTraceEnabled())
                LOG.trace("Removed dominated outer product memo table entry: " + dominated);
        }

        if (arrayList != null && !arrayList.isEmpty()) {
            // per-cell compute costs of all hops reachable from the roots, computed once
            HashMap<Long, Double> computeCosts = new HashMap<>();
            for (Long rootId : hashSet2)
                rGetComputeCosts(cPlanMemoTable.getHopRefs().get(rootId), hashSet, computeCosts);

            int numMatPoints = arrayList.size();
            int numPlans = (int) Math.pow(2.0d, numMatPoints);
            boolean[] bestPlan = null;
            double bestCost = Double.MAX_VALUE;
            for (int pos = 0; pos < numPlans; pos++) {
                boolean[] plan = createAssignment(numMatPoints, pos);
                double cost = getPlanCost(cPlanMemoTable, hashSet, hashSet2, arrayList, plan, computeCosts);
                if (LOG.isTraceEnabled())
                    LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + cost);
                if (bestPlan == null || cost < bestCost) {
                    bestCost = cost;
                    bestPlan = plan;
                    if (LOG.isTraceEnabled())
                        LOG.trace("Enum: Found new best plan.");
                }
            }
            if (DMLScript.STATISTICS) {
                Statistics.incrementCodegenEnumAllP(numPlans);
                Statistics.incrementCodegenEnumEval(numPlans);
            }

            HashSet<Long> visitedSuboptimal = new HashSet<>();
            for (Long rootId : hashSet2)
                rPruneSuboptimalPlans(cPlanMemoTable, cPlanMemoTable.getHopRefs().get(rootId), visitedSuboptimal, hashSet, arrayList, bestPlan);
            HashSet<Long> visitedInvalid = new HashSet<>();
            for (Long rootId : hashSet2)
                rPruneInvalidPlans(cPlanMemoTable, cPlanMemoTable.getHopRefs().get(rootId), visitedInvalid, hashSet, arrayList, bestPlan);
        }

        // fuse-all selection with partition awareness (with or without enumeration)
        for (Long rootId : hashSet2)
            rSelectPlansFuseAll(cPlanMemoTable, cPlanMemoTable.getHopRefs().get(rootId), null, hashSet);
    }

    /**
     * Checks whether the best row plan of {@code hop} contains no aggregation in
     * any of its plan-referenced inputs. The root itself is not checked.
     *
     * @param hashSet visited hop IDs, shared across the recursion
     */
    private static boolean isRowTemplateWithoutAgg(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<Long> hashSet) {
        CPlanMemoTable.MemoTableEntry rowPlan = cPlanMemoTable.getBest(hop.getHopID(), TemplateBase.TemplateType.ROW);
        boolean noAgg = true;
        int pos = 0;
        while (pos < 3) {
            // non-short-circuit AND: every referenced input is still traversed
            if (rowPlan.isPlanRef(pos))
                noAgg = noAgg & rIsRowTemplateWithoutAgg(cPlanMemoTable, hop.getInput().get(pos), hashSet);
            pos++;
        }
        return noAgg;
    }

    /**
     * Recursive part of {@code isRowTemplateWithoutAgg}. It returns false if
     * {@code hop} is an {@code AggUnaryOp}/{@code AggBinaryOp}, or if any input
     * referenced by its best row plan contains one. Hops already in
     * {@code hashSet} return true; their result was already folded into the
     * caller's answer on first visit.
     */
    private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<Long> hashSet) {
        long id = hop.getHopID();
        if (hashSet.contains(id))
            return true;

        boolean isAgg = hop instanceof AggUnaryOp || hop instanceof AggBinaryOp;
        CPlanMemoTable.MemoTableEntry rowPlan = cPlanMemoTable.getBest(id, TemplateBase.TemplateType.ROW);
        boolean noAggBelow = true;
        for (int pos = 0; pos < 3; pos++) {
            if (rowPlan.isPlanRef(pos))
                noAggBelow &= rIsRowTemplateWithoutAgg(cPlanMemoTable, hop.getInput().get(pos), hashSet);
        }
        hashSet.add(id);
        return noAggBelow && !isAgg;
    }

    /**
     * Removes, for every partition hop reachable from {@code hop}, the memo
     * entries that reference a materialization point under assignment
     * {@code zArr}. OUTER entries are always kept.
     *
     * @param hashSet  visited hop IDs (memoization for this single pass)
     * @param hashSet2 hop IDs of the partition
     */
    private static void rPruneSuboptimalPlans(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<Long> hashSet, HashSet<Long> hashSet2, ArrayList<Long> arrayList, boolean[] zArr) {
        long id = hop.getHopID();
        if (hashSet.contains(id))
            return;

        if (hashSet2.contains(id) && cPlanMemoTable.contains(id)) {
            cPlanMemoTable.get(id).removeIf(plan -> {
                boolean keep = hasNoRefToMaterialization(plan, arrayList, zArr) || plan.type == TemplateBase.TemplateType.OUTER;
                if (!keep && LOG.isTraceEnabled())
                    LOG.trace("Removed memo table entry: " + plan);
                return !keep;
            });
        }

        for (Hop input : hop.getInput())
            rPruneSuboptimalPlans(cPlanMemoTable, input, hashSet, hashSet2, arrayList, zArr);
        hashSet.add(id);
    }

    /**
     * Converts invalid ROW memo entries of partition hops to CELL, in post-order.
     * Two kinds of entries are converted:
     * <ul>
     * <li>leaf entries (no plan references) of hops without matrix inputs;</li>
     * <li>inner entries of hops that cannot open a row template and have no
     * referenced input with a ROW entry.</li>
     * </ul>
     * Because inputs are processed first, their conversions are already visible
     * when a hop's references are checked.
     *
     * @param hashSet  visited hop IDs (memoization for this single pass)
     * @param hashSet2 hop IDs of the partition
     */
    private static void rPruneInvalidPlans(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<Long> hashSet, HashSet<Long> hashSet2, ArrayList<Long> arrayList, boolean[] zArr) {
        long id = hop.getHopID();
        if (hashSet.contains(id))
            return;
        for (Hop input : hop.getInput())
            rPruneInvalidPlans(cPlanMemoTable, input, hashSet, hashSet2, arrayList, zArr);

        boolean candidate = hashSet2.contains(id) && cPlanMemoTable.contains(id, TemplateBase.TemplateType.ROW);
        if (candidate) {
            for (CPlanMemoTable.MemoTableEntry plan : cPlanMemoTable.get(id)) {
                if (plan.type != TemplateBase.TemplateType.ROW)
                    continue;
                if (!plan.hasPlanRef() && !TemplateUtils.hasMatrixInput(hop)) {
                    plan.type = TemplateBase.TemplateType.CELL;
                    if (LOG.isTraceEnabled())
                        LOG.trace("Converted leaf memo table entry from row to cell: " + plan);
                }
                if (plan.hasPlanRef() && !ROW_TPL.open(hop)) {
                    boolean hasRowInput = false;
                    for (int pos = 0; pos < 3; pos++) {
                        if (plan.isPlanRef(pos))
                            hasRowInput |= cPlanMemoTable.contains(plan.input(pos), TemplateBase.TemplateType.ROW);
                    }
                    if (!hasRowInput) {
                        plan.type = TemplateBase.TemplateType.CELL;
                        if (LOG.isTraceEnabled())
                            LOG.trace("Converted inner memo table entry from row to cell: " + plan);
                    }
                }
            }
        }
        hashSet.add(id);
    }

    /**
     * Decodes an enumeration position into a materialization-point assignment.
     *
     * <p>Entry {@code k} corresponds to bit {@code (i - 1 - k)} of {@code i2},
     * so the most significant bit comes first. An entry is {@code true} iff its
     * bit is zero. Integer bit tests replace the former floating-point
     * {@code Math.pow}/modulus arithmetic.
     *
     * @param i  number of materialization points (length of the result)
     * @param i2 enumeration position, expected in {@code [0, 2^i)}
     * @return the decoded assignment
     */
    private static boolean[] createAssignment(int i, int i2) {
        boolean[] assignment = new boolean[i];
        for (int k = 0; k < i; k++) {
            int bit = i - k - 1;
            // A non-negative int has only zero bits at positions >= 31. Check this
            // explicitly, because Java masks shift distances to their low 5 bits.
            assignment[k] = bit >= Integer.SIZE - 1 || ((i2 >>> bit) & 1) == 0;
        }
        return assignment;
    }

    /**
     * Computes the total cost of a partition under one materialization assignment.
     * This is the sum of rGetPlanCosts over all roots, which share one memo of
     * visited (hop ID, cost vector ID) pairs.
     *
     * @param hashSet  hop IDs of the partition
     * @param hashSet2 hop IDs of the partition roots
     * @param zArr     materialization assignment for {@code arrayList}
     * @param hashMap  per-cell compute costs by hop ID
     */
    private static double getPlanCost(CPlanMemoTable cPlanMemoTable, HashSet<Long> hashSet, HashSet<Long> hashSet2, ArrayList<Long> arrayList, boolean[] zArr, HashMap<Long, Double> hashMap) {
        HashSet<Pair<Long, Long>> visited = new HashSet<>();
        double total = 0.0d;
        // plain loop on purpose: same left-to-right summation order as before
        for (Long rootId : hashSet2) {
            Hop root = cPlanMemoTable.getHopRefs().get(rootId);
            total += rGetPlanCosts(cPlanMemoTable, root, visited, hashSet, arrayList, zArr, hashMap, null, null);
        }
        return total;
    }

    /**
     * Recursively estimates the cost of the plan rooted at {@code hop} under the
     * materialization assignment {@code zArr}.
     *
     * <p>Opening and costing a fused operator works as follows:
     * <ul>
     * <li>If no fused operator is open ({@code templateType == null}) and the
     * hop has memo entries, a new operator is opened with its own
     * {@link CostVector}.</li>
     * <li>Inputs referenced by the chosen plan are costed into the same
     * vector.</li>
     * <li>Implicitly fused inputs record the size of their own input.</li>
     * <li>All other inputs are costed independently; if they are matrices, they
     * are also recorded as reads of the current vector.</li>
     * </ul>
     * For an opened operator of a partition hop, the cost is:
     * <pre>
     * write(out) + max(compute, read(in))
     * </pre>
     * where sizes count 8 bytes per cell.
     *
     * @param hashSet      visited (hop ID, cost vector ID) pairs; each pair is costed once
     * @param hashSet2     hop IDs of the partition; costs accrue only at these hops
     * @param arrayList    materialization points
     * @param zArr         materialization assignment
     * @param hashMap      per-cell compute costs by hop ID
     * @param costVector   cost vector of the currently open fused operator, or null
     * @param templateType template type of the currently open fused operator, or null
     * @return the cost estimate (non-negative and finite)
     * @throws RuntimeException if the estimate is negative, NaN or infinite
     */
    private static double rGetPlanCosts(CPlanMemoTable cPlanMemoTable, Hop hop, HashSet<Pair<Long, Long>> hashSet, HashSet<Long> hashSet2, ArrayList<Long> arrayList, boolean[] zArr, HashMap<Long, Double> hashMap, CostVector costVector, TemplateBase.TemplateType templateType) {
        // memoize per hop and cost vector: avoids double counting within a fused
        // operator while still accounting for redundant computation across operators
        Pair<Long, Long> tag = Pair.of(hop.getHopID(), costVector == null ? 0L : costVector.ID);
        if (!hashSet.add(tag))
            return 0;

        CPlanMemoTable.MemoTableEntry best = null;
        boolean opened = false;
        if (cPlanMemoTable.contains(hop.getHopID())) {
            if (templateType == null) {
                // open a new fused operator with the best valid plan under this assignment
                best = cPlanMemoTable.get(hop.getHopID()).stream()
                    .filter(p -> p.isValid())
                    .filter(p -> hasNoRefToMaterialization(p, arrayList, zArr))
                    .min(new PlanSelection.BasicPlanComparator()).orElse(null);
                opened = true;
            } else {
                // continue the open operator: prefer its template type over CELL, then more plan refs
                best = cPlanMemoTable.get(hop.getHopID()).stream()
                    .filter(p -> p.type == templateType || p.type == TemplateBase.TemplateType.CELL)
                    .filter(p -> hasNoRefToMaterialization(p, arrayList, zArr))
                    .min(Comparator.comparingInt((CPlanMemoTable.MemoTableEntry p) ->
                        (7 - (p.type == templateType ? 4 : 0)) - p.countPlanRefs()))
                    .orElse(null);
            }
        }

        CostVector costs = opened
            ? new CostVector(Math.max(hop.getDim1(), 1L) * Math.max(hop.getDim2(), 1L))
            : costVector;
        if (hashSet2.contains(hop.getHopID()))
            costs.computeCosts += hashMap.get(hop.getHopID()).doubleValue();

        double d = 0;
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop input = hop.getInput().get(i);
            if (best != null && best.isPlanRef(i)) {
                d += rGetPlanCosts(cPlanMemoTable, input, hashSet, hashSet2, arrayList, zArr, hashMap, costs, best.type);
            } else if (best != null && isImplicitlyFused(hop, i, best.type)) {
                costs.addInputSize(input.getInput().get(0).getHopID(), Math.max(input.getDim1(), 1L) * Math.max(input.getDim2(), 1L));
            } else {
                // separately costed input, read by the current operator
                d += rGetPlanCosts(cPlanMemoTable, input, hashSet, hashSet2, arrayList, zArr, hashMap, null, null);
                if (costs != null && input.getDataType().isMatrix())
                    costs.addInputSize(input.getHopID(), Math.max(input.getDim1(), 1L) * Math.max(input.getDim2(), 1L));
            }
        }

        if (hashSet2.contains(hop.getHopID())) {
            if (opened) {
                if (LOG.isTraceEnabled())
                    LOG.trace("Cost vector for fused operator (hop " + hop.getHopID() + "): " + costs);
                double writeCost = (costs.outSize * 8.0d) / WRITE_BANDWIDTH;
                double computeCost = (costs.computeCosts * costs.getMaxInputSize()) / COMPUTE_BANDWIDTH;
                double readCost = (costs.getSumInputSizes() * 8.0d) / READ_BANDWIDTH;
                d = d + writeCost + Math.max(computeCost, readCost);
            } else if (hasNonPartitionConsumer(hop, hashSet2)) {
                // consumed outside the partition: additionally cost it as its own operator
                d += rGetPlanCosts(cPlanMemoTable, hop, hashSet, hashSet2, arrayList, zArr, hashMap, null, null);
            }
        }

        if (d < 0 || Double.isNaN(d) || Double.isInfinite(d))
            throw new RuntimeException("Wrong cost estimate: " + d);
        return d;
    }

    /**
     * Recursively assigns a compute cost weight to {@code hop} and all of its transitive inputs,
     * memoizing the result per hop ID in {@code hashMap}. Hops whose ID is already present in
     * {@code hashMap} are not revisited, so shared sub-DAGs are costed only once. Operations
     * without a dedicated cost entry are weighted 1 (and a warning is logged for unknown
     * unary/binary/ternary/aggregate-unary opcodes).
     *
     * @param hop     root of the sub-DAG to cost
     * @param hashSet not read here; only forwarded to the recursive calls
     * @param hashMap memo table (output) mapping hop ID to its compute cost weight
     */
    private static void rGetComputeCosts(Hop hop, HashSet<Long> hashSet, HashMap<Long, Double> hashMap) {
        if (hashMap.containsKey(hop.getHopID())) {
            return;
        }
        for (Hop input : hop.getInput()) {
            rGetComputeCosts(input, hashSet, hashMap);
        }
        double cost;
        if (hop instanceof UnaryOp) {
            cost = getUnaryOpComputeCost((UnaryOp) hop);
        } else if (hop instanceof BinaryOp) {
            cost = getBinaryOpComputeCost((BinaryOp) hop);
        } else if (hop instanceof TernaryOp) {
            cost = getTernaryOpComputeCost((TernaryOp) hop);
        } else if (hop instanceof ParameterizedBuiltinOp || hop instanceof IndexingOp || hop instanceof ReorgOp) {
            cost = 1.0d;
        } else if (hop instanceof AggBinaryOp) {
            cost = 2.0d;
        } else if (hop instanceof AggUnaryOp) {
            cost = getAggUnaryOpComputeCost((AggUnaryOp) hop);
        } else {
            cost = 1.0d;
        }
        hashMap.put(hop.getHopID(), cost);
    }

    /** Returns the moment order given by literal input 1 of a MOMENT hop, defaulting to 2 if it is not a literal. */
    private static int getMomentOrderOrDefault(Hop hop) {
        Hop order = hop.getInput().get(1);
        return (int) (order instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) order) : 2L);
    }

    /** Returns the compute cost weight of a unary operation; unknown opcodes are weighted 1 with a warning. */
    private static double getUnaryOpComputeCost(UnaryOp hop) {
        switch (hop.getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
                return 1.0d;
            case SPROP:
            case SQRT:
                return 2.0d;
            case EXP:
                return 18.0d;
            case SIGMOID:
                return 21.0d;
            case LOG:
            case LOG_NZ:
                return 32.0d;
            case NCOL:
            case NROW:
            case PRINT:
            case ASSERT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
                return 1.0d;
            case SIN:
                return 18.0d;
            case COS:
                return 22.0d;
            case TAN:
                return 42.0d;
            case ASIN:
                return 93.0d;
            case ACOS:
                return 103.0d;
            case ATAN:
                return 40.0d;
            case SINH:
                return 93.0d;
            case COSH:
                return 103.0d;
            case TANH:
                return 40.0d;
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                return 1.0d;
            case CUMSUMPROD:
                return 2.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /** Returns the compute cost weight of a binary operation; unknown opcodes are weighted 1 with a warning. */
    private static double getBinaryOpComputeCost(BinaryOp hop) {
        switch (hop.getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                return 1.0d;
            case INTDIV:
                return 6.0d;
            case MODULUS:
                return 8.0d;
            case DIV:
                return 22.0d;
            case LOG:
            case LOG_NZ:
                return 32.0d;
            case POW:
                // a literal exponent of 2 is weighted like a single cheap op, any other exponent is expensive
                return HopRewriteUtils.isLiteralOfValue(hop.getInput().get(1), 2.0d) ? 1.0d : 16.0d;
            case MINUS_NZ:
            case MINUS1_MULT:
                return 2.0d;
            case MOMENT:
                // previously fell through into COV (missing break), so these weights were always overwritten by 23
                switch (getMomentOrderOrDefault(hop)) {
                    case 0:
                        return 1.0d;
                    case 1:
                        return 8.0d;
                    case 2:
                        return 16.0d;
                    case 3:
                        return 31.0d;
                    case 4:
                        return 51.0d;
                    case 5:
                        return 16.0d;
                    default:
                        return 1.0d;
                }
            case COV:
                return 23.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /** Returns the compute cost weight of a ternary operation; unknown opcodes are weighted 1 with a warning. */
    private static double getTernaryOpComputeCost(TernaryOp hop) {
        switch (hop.getOp()) {
            case PLUS_MULT:
            case MINUS_MULT:
                return 2.0d;
            case CTABLE:
                return 3.0d;
            case MOMENT:
                // previously fell through into COV (missing break), so these weights were always overwritten by 23
                switch (getMomentOrderOrDefault(hop)) {
                    case 0:
                        return 2.0d;
                    case 1:
                        return 9.0d;
                    case 2:
                        return 17.0d;
                    case 3:
                        return 32.0d;
                    case 4:
                        return 52.0d;
                    case 5:
                        return 17.0d;
                    default:
                        return 1.0d;
                }
            case COV:
                return 23.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /** Returns the compute cost weight of a unary aggregate; unknown aggregates are weighted 1 with a warning. */
    private static double getAggUnaryOpComputeCost(AggUnaryOp hop) {
        switch (hop.getOp()) {
            case SUM:
                return 4.0d;
            case SUM_SQ:
                return 5.0d;
            case MIN:
            case MAX:
                return 1.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /**
     * Checks that none of the first three inputs of a memo table entry refers to a hop flagged in
     * {@code zArr}.
     *
     * <p>Note: the decompiler reported that this method's visibility was widened to public
     * (originally private).
     *
     * @param memoTableEntry memo table entry whose inputs 0..2 are inspected
     * @param arrayList      hop IDs; position {@code k} corresponds to flag {@code zArr[k]}
     * @param zArr           per-position flags; a {@code true} flag marks a materialization point
     * @return {@code true} if no inspected input appears in {@code arrayList} with a {@code true} flag
     */
    public static boolean hasNoRefToMaterialization(CPlanMemoTable.MemoTableEntry memoTableEntry, ArrayList<Long> arrayList, boolean[] zArr) {
        for (int i = 0; i < 3; i++) {
            // single linear lookup instead of contains() followed by indexOf()
            int pos = arrayList.indexOf(memoTableEntry.input(i));
            if (pos >= 0 && zArr[pos]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Reports whether {@code hop} has at least one parent whose hop ID is not contained in
     * {@code hashSet}.
     *
     * @param hop     hop whose parents are inspected
     * @param hashSet set of hop IDs forming the partition
     * @return {@code true} if some parent lies outside the given set
     */
    private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> hashSet) {
        for (Hop consumer : hop.getParent()) {
            if (!hashSet.contains(consumer.getHopID())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides whether input {@code i} of {@code hop} counts as implicitly fused. This holds only for
     * ROW templates, when {@code hop} is a matrix multiply and input 0 is a transpose.
     *
     * @param hop          consumer hop
     * @param i            index of the input being considered
     * @param templateType template type under consideration
     * @return {@code true} if the input is implicitly fused
     */
    private static boolean isImplicitlyFused(Hop hop, int i, TemplateBase.TemplateType templateType) {
        if (templateType != TemplateBase.TemplateType.ROW || i != 0) {
            return false;
        }
        boolean matMult = HopRewriteUtils.isMatrixMultiply(hop);
        return matMult && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0));
    }
}
