package org.apache.sysds.runtime.compress.cocode;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;

/* loaded from: input_file:org/apache/sysds/runtime/compress/cocode/CoCodeCostMatrixMult.class */
/**
 * Greedy column co-coder.
 *
 * <p>All column groups are placed in a priority queue ordered by {@link CostOfJoin}. The cheapest group is
 * repeatedly joined with the next-cheapest one if the estimated cost of the joined group is lower than the sum
 * of the two separate costs; otherwise the cheapest group is finalized on its own.
 */
public class CoCodeCostMatrixMult extends AColumnCoCoder {

    /** Tuple sparsity at or below which the per-tuple cost term is scaled by the tuple sparsity. */
    private static final double SPARSITY_SCALING_THRESHOLD = 0.4;

    /** Sampling ratio below which a second join pass with a doubled sample is performed. */
    private static final double SECOND_PASS_SAMPLING_RATIO = 0.1;

    /**
     * Cost of keeping one (possibly joined) column group, used to order groups during the greedy join.
     *
     * <p>The cost is {@code nRows + nVals * nCols}, where the second term is additionally multiplied by the
     * group's tuple sparsity if the group has at most one column or a tuple sparsity of at most 0.4. A
     * {@code null} group (for example a join whose analysis yielded no result) gets {@code +Infinity}, so it
     * is never preferred over a real group.
     */
    public class CostOfJoin implements Comparable<CostOfJoin> {
        protected final CompressedSizeInfoColGroup elm;
        protected final double cost;

        protected CostOfJoin(CompressedSizeInfoColGroup elm) {
            this.elm = elm;
            if (elm == null) {
                this.cost = Double.POSITIVE_INFINITY;
            }
            else {
                final int nCols = elm.getColumns().length;
                final double nRows = _est.getNumRows();
                final int nVals = elm.getNumVals();
                final double tupleSparsity = elm.getTupleSparsity();
                // Multiply in double: nVals * nCols in int arithmetic can exceed Integer.MAX_VALUE for wide
                // groups with many distinct tuples, which would wrap to a negative (too cheap) cost.
                final double tupleCost = (double) nVals * nCols;
                final double postScalingCost = (nCols <= 1 || tupleSparsity <= SPARSITY_SCALING_THRESHOLD)
                    ? tupleCost * tupleSparsity
                    : tupleCost;
                this.cost = nRows + postScalingCost;
            }
        }

        /** Orders by ascending cost; {@link Double#compare} gives a total order even for NaN. */
        @Override
        public int compareTo(CostOfJoin other) {
            return Double.compare(this.cost, other.cost);
        }

        @Override
        public String toString() {
            if (elm == null) {
                return cost + " - null group";
            }
            return cost + " - " + elm.getBestCompressionType() + " nrVals: " + elm.getNumVals() + " "
                + Arrays.toString(elm.getColumns());
        }
    }

    public CoCodeCostMatrixMult(CompressedSizeEstimator compressedSizeEstimator, CompressionSettings compressionSettings) {
        super(compressedSizeEstimator, compressionSettings);
    }

    /**
     * Co-codes the columns described by {@code colInfos} using the greedy cost-based join.
     *
     * <p>If the configured sampling ratio is below 0.1 and the estimator is a
     * {@link CompressedSizeEstimatorSample}, the sample is re-drawn with twice as many rows, the joined groups
     * are re-estimated on the new sample, and the greedy join is run a second time on those estimates.
     *
     * @param colInfos the current column group information; updated in place with the joined groups
     * @param k        forwarded to {@code computeCompressedSizeInfos}; presumably the degree of parallelism
     *                 (TODO confirm)
     * @return {@code colInfos}, after {@code setInfo} has been called with the joined groups
     */
    @Override
    public CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
        List<CompressedSizeInfoColGroup> joined = join(colInfos.getInfo());
        if (_cs.samplingRatio < SECOND_PASS_SAMPLING_RATIO && _est instanceof CompressedSizeEstimatorSample) {
            LOG.debug("Performing second join with double sample rate");
            final CompressedSizeEstimatorSample sampleEstimator = (CompressedSizeEstimatorSample) _est;
            sampleEstimator.sampleData(sampleEstimator.getSample().getNumRows() * 2);
            final List<int[]> columnGroups = new ArrayList<>(joined.size());
            for (CompressedSizeInfoColGroup g : joined) {
                columnGroups.add(g.getColumns());
            }
            joined = join(sampleEstimator.computeCompressedSizeInfos(columnGroups, k));
        }
        colInfos.setInfo(joined);
        return colInfos;
    }

    /**
     * Greedily joins column groups, always trying the two cheapest groups first.
     *
     * <p>{@code null} entries in the input are skipped. An empty input yields an empty result.
     *
     * @param currentGroups the groups to join
     * @return the resulting groups, in the order they were finalized
     */
    private List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> currentGroups) {
        // PriorityQueue rejects an initial capacity below 1, so guard against an empty input list.
        final PriorityQueue<CostOfJoin> queue = new PriorityQueue<>(Math.max(1, currentGroups.size()));
        final List<CompressedSizeInfoColGroup> result = new ArrayList<>();
        for (CompressedSizeInfoColGroup g : currentGroups) {
            if (g != null) {
                queue.add(new CostOfJoin(g));
            }
        }

        // Invariant: 'left' is non-null whenever the queue is non-empty at the loop check.
        CostOfJoin left = queue.poll();
        while (!queue.isEmpty()) {
            final CostOfJoin right = queue.peek();
            final double costSeparate = left.cost + right.cost;
            final CostOfJoin joinedCost = new CostOfJoin(joinWithAnalysis(left.elm, right.elm));
            if (joinedCost.cost < costSeparate) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("\nl:      " + left + "\nr:      " + right + "\njoined: " + joinedCost);
                }
                // Replace 'right' by the joined group; 'left' is already out of the queue.
                queue.poll();
                queue.add(joinedCost);
            }
            else {
                // Joining does not pay off: finalize 'left' as its own group.
                result.add(left.elm);
            }
            left = queue.poll();
        }
        if (left != null) {
            result.add(left.elm);
        }
        return result;
    }
}
