package org.apache.sysds.runtime.compress.cocode;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;

/* loaded from: input_file:org/apache/sysds/runtime/compress/cocode/CoCodeCostTSMM.class */
/**
 * Column co-coder that greedily merges column groups as long as an estimated cost of a
 * transposed-self matrix multiplication (TSMM) over the resulting groups decreases.
 *
 * <p>Candidate groups are visited in ascending order of {@code getNumVals()}; the two smallest
 * are tentatively joined (via {@code joinWithAnalysis}) and the join is kept only if it lowers
 * the total estimated cost.
 */
public class CoCodeCostTSMM extends AColumnCoCoder {

    /** Below this sampling ratio, a sample-based estimator triggers a second, re-sampled join. */
    private static final double RESAMPLE_RATIO_THRESHOLD = 0.1d;

    /** Tuple-sparsity threshold above which a multi-column group is costed as dense. */
    private static final double DENSE_TUPLE_SPARSITY_THRESHOLD = 0.4d;

    /** Heuristic weight applied to the post-scaling term of a left-transposed multiply. */
    private static final double POST_SCALE_WEIGHT = 5.0d;

    /**
     * Creates the co-coder.
     *
     * @param compressedSizeEstimator size estimator, forwarded to {@link AColumnCoCoder}
     * @param compressionSettings     compression settings, forwarded to {@link AColumnCoCoder}
     */
    public CoCodeCostTSMM(CompressedSizeEstimator compressedSizeEstimator, CompressionSettings compressionSettings) {
        super(compressedSizeEstimator, compressionSettings);
    }

    /**
     * Greedily joins the column groups in {@code compressedSizeInfo} and stores the result back
     * into it.
     *
     * <p>If the sampling ratio is below {@value #RESAMPLE_RATIO_THRESHOLD} and the estimator is a
     * {@link CompressedSizeEstimatorSample}, the sample is doubled in size, the groups found by the
     * first pass are re-estimated, and the join is run a second time.
     *
     * @param compressedSizeInfo current per-group size information; mutated via {@code setInfo}
     * @param k                  forwarded to {@code computeCompressedSizeInfos} (presumably the
     *                           degree of parallelism — TODO confirm)
     * @return the same {@code compressedSizeInfo} instance, holding the joined groups
     */
    @Override // org.apache.sysds.runtime.compress.cocode.AColumnCoCoder
    public CompressedSizeInfo coCodeColumns(CompressedSizeInfo compressedSizeInfo, int k) {
        List<CompressedSizeInfoColGroup> joined = join(compressedSizeInfo.getInfo());
        if (_cs.samplingRatio < RESAMPLE_RATIO_THRESHOLD && _est instanceof CompressedSizeEstimatorSample) {
            LOG.debug("Performing second join with double sample rate");
            final CompressedSizeEstimatorSample sampleEstimator = (CompressedSizeEstimatorSample) _est;
            sampleEstimator.sampleData(sampleEstimator.getSample().getNumRows() * 2);
            // Re-estimate the groups chosen by the first pass on the larger sample.
            final List<int[]> columnGroups = new ArrayList<>(joined.size());
            for (CompressedSizeInfoColGroup g : joined) {
                columnGroups.add(g.getColumns());
            }
            joined = join(sampleEstimator.computeCompressedSizeInfos(columnGroups, k));
        }
        compressedSizeInfo.setInfo(joined);
        return compressedSizeInfo;
    }

    /**
     * Greedy pairwise merge: repeatedly joins the two groups with the fewest distinct values while
     * doing so lowers the estimated total cost.
     *
     * @param groups candidate groups (may be empty)
     * @return the final groups; a new mutable list
     */
    private List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> groups) {
        // PriorityQueue rejects an initial capacity below 1, so guard against an empty input.
        final Queue<CompressedSizeInfoColGroup> queue = new PriorityQueue<>(Math.max(1, groups.size()),
            Comparator.comparingInt(CompressedSizeInfoColGroup::getNumVals));
        final List<CompressedSizeInfoColGroup> finished = new ArrayList<>();
        queue.addAll(groups);

        double currentCost = getCost(queue, finished);
        while (queue.size() > 1) {
            final CompressedSizeInfoColGroup left = queue.poll();
            final CompressedSizeInfoColGroup right = queue.poll();
            final CompressedSizeInfoColGroup combined = joinWithAnalysis(left, right);
            final double newCost = getCost(queue, finished, combined);
            if (newCost < currentCost) {
                currentCost = newCost;
                queue.add(combined);
            } else {
                // Joining does not pay off: 'left' is finalized, 'right' remains a candidate.
                // The set of groups is unchanged, so currentCost stays valid.
                finished.add(left);
                queue.add(right);
            }
        }
        // At most one group is left in the queue here.
        finished.addAll(queue);
        return finished;
    }

    /** Estimated total cost of all groups in {@code queue} and {@code finished}. */
    private double getCost(Queue<CompressedSizeInfoColGroup> queue, List<CompressedSizeInfoColGroup> finished) {
        return getCost(queue.toArray(new CompressedSizeInfoColGroup[0]), finished);
    }

    /**
     * Estimated total cost of all groups in {@code queue} and {@code finished} plus the additional
     * group {@code extra}: its self-TSMM cost and its cross cost against every other group.
     */
    private double getCost(Queue<CompressedSizeInfoColGroup> queue, List<CompressedSizeInfoColGroup> finished,
        CompressedSizeInfoColGroup extra) {
        final CompressedSizeInfoColGroup[] queued = queue.toArray(new CompressedSizeInfoColGroup[0]);
        double cost = getCost(queued, finished) + getCostOfSelfTSMM(extra);
        for (CompressedSizeInfoColGroup g : queued) {
            cost += getCostOfLeftTransposedMM(g, extra);
        }
        for (CompressedSizeInfoColGroup g : finished) {
            cost += getCostOfLeftTransposedMM(g, extra);
        }
        return cost;
    }

    /**
     * Estimated total cost over the union of {@code queued} and {@code finished}: the self-TSMM
     * cost of each group plus the cross cost of every unordered pair of groups.
     */
    private double getCost(CompressedSizeInfoColGroup[] queued, List<CompressedSizeInfoColGroup> finished) {
        double cost = 0.0d;
        for (int i = 0; i < queued.length; i++) {
            cost += getCostOfSelfTSMM(queued[i]);
            for (int j = i + 1; j < queued.length; j++) {
                cost += getCostOfLeftTransposedMM(queued[i], queued[j]);
            }
            for (CompressedSizeInfoColGroup g : finished) {
                cost += getCostOfLeftTransposedMM(queued[i], g);
            }
        }
        for (int i = 0; i < finished.size(); i++) {
            cost += getCostOfSelfTSMM(finished.get(i));
            for (int j = i + 1; j < finished.size(); j++) {
                cost += getCostOfLeftTransposedMM(finished.get(i), finished.get(j));
            }
        }
        return cost;
    }

    /**
     * Cost of a group multiplied with itself: {@code numVals * nCols * (nCols + 1) / 2}, i.e. one
     * upper-triangular block per distinct value. Computed in double to avoid int overflow.
     */
    private static double getCostOfSelfTSMM(CompressedSizeInfoColGroup g) {
        final int nCols = g.getColumns().length;
        return (double) g.getNumVals() * ((double) nCols * (nCols + 1)) / 2.0d;
    }

    /**
     * Cost of multiplying two different groups against each other. It takes the cheaper of the two
     * orientations, where each orientation costs {@code numRows} plus a weighted post-scaling term
     * of the group on that side.
     */
    private double getCostOfLeftTransposedMM(CompressedSizeInfoColGroup left, CompressedSizeInfoColGroup right) {
        final double nRows = _est.getNumRows();
        final double costLeft = nRows + postScaleCost(left) * POST_SCALE_WEIGHT;
        final double costRight = nRows + postScaleCost(right) * POST_SCALE_WEIGHT;
        return Math.min(costLeft, costRight);
    }

    /**
     * Post-scaling term of one group. It is {@code numVals * nCols}, scaled by the tuple sparsity
     * unless the group has several columns and a tuple sparsity above
     * {@value #DENSE_TUPLE_SPARSITY_THRESHOLD}.
     */
    private static double postScaleCost(CompressedSizeInfoColGroup g) {
        final int nCols = g.getColumns().length;
        final double tupleSparsity = g.getTupleSparsity();
        final double dense = (double) g.getNumVals() * nCols;
        return (nCols <= 1 || tupleSparsity <= DENSE_TUPLE_SPARSITY_THRESHOLD) ? dense * tupleSparsity : dense;
    }
}
