package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;

import org.antlr.v4.runtime.atn.PredictionContext;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.utils.Hash;

/**
 * Approximate distinct-value counting based on the K-Minimum-Values (KMV) sketch.
 *
 * <p>For a slice (the whole block, one row or one column) with bound {@code D = nonZeros + 1} on its number of
 * distinct values (the extra one covers the value zero), every cell value is hashed into {@code [1, M)} with
 * {@code M = min(D * D, Integer.MAX_VALUE)}, and only the {@code k = min(D, 64)} smallest distinct hashes are kept.
 * If fewer than {@code k} hashes were retained, their count is the answer; otherwise the estimate is
 * {@code (k - 1) / U_k}, capped at {@code D}, where {@code U_k} is the k-th smallest hash divided by {@code M}.
 *
 * <p>{@link #create(MatrixBlock)} stores the retained hashes in the sketch's value block and, for every sketched
 * row/column, the metadata {@code [number of retained hashes, k, D]} in a three-column correction block.
 */
public class KMVSketch extends CountDistinctSketch {
    private static final Log LOG = LogFactory.getLog(KMVSketch.class.getName());

    /** Upper bound on k, the number of smallest hashes retained per sketch. */
    private static final int MAX_K = 64;

    /** Largest D whose square still fits in an int: 46340^2 <= Integer.MAX_VALUE < 46341^2. */
    private static final long MAX_UNCAPPED_D = 46340L;

    public KMVSketch(Operator operator) {
        super(operator);
    }

    /**
     * Estimates the number of distinct values of {@code matrixBlock} in the operator's direction.
     *
     * @param matrixBlock input block
     * @return a 1x1 block for row-col direction, an n x 1 block (one value per row) for row direction, otherwise a
     *         1 x n block (one value per column)
     * @throws DMLRuntimeException if the whole-block (row-col) estimate is not positive
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public MatrixBlock getValue(MatrixBlock matrixBlock) {
        if (this.op.getDirection().isRowCol()) {
            return getValueRowCol(matrixBlock);
        }
        // Any direction other than row-col and row is handled column-wise.
        return getValuePerSlice(matrixBlock, this.op.getDirection().isRow());
    }

    /** Estimates the distinct count of the whole block and returns it as a 1x1 block. */
    private MatrixBlock getValueRowCol(MatrixBlock in) {
        long d = in.getNonZeros() + 1;
        int m = hashSpaceSize(d);
        int k = sketchSize(d);
        SmallestPriorityQueue spq = getKSmallestHashes(in, k, m);
        if (LOG.isDebugEnabled()) {
            LOG.debug("M not forced to int size: " + d * d);
            LOG.debug("M: " + m);
            LOG.debug("k: " + k);
            LOG.debug("kth smallest hash:" + spq.peek());
            LOG.debug("spq: " + spq);
        }
        long estimate = estimateDistinct(spq, k, m, d);
        if (estimate <= 0) {
            throw new DMLRuntimeException("Impossible estimate of distinct values");
        }
        return new MatrixBlock(estimate);
    }

    /**
     * Estimates the distinct count of every row ({@code byRow}) or every column independently.
     *
     * <p>Each slice uses its own bound {@code D = (non-zeros in the slice) + 1}, the same bound
     * {@link #create(MatrixBlock)} uses, and its cell values are hashed before entering the KMV queue.
     *
     * @param in    input block
     * @param byRow true for one estimate per row (n x 1 result), false for one per column (1 x n result)
     * @return the per-slice estimates
     */
    private MatrixBlock getValuePerSlice(MatrixBlock in, boolean byRow) {
        final int numSlices = byRow ? in.getNumRows() : in.getNumColumns();
        final int sliceLength = byRow ? in.getNumColumns() : in.getNumRows();
        final Hash.HashType hashType = this.op.getHashType();

        MatrixBlock out = byRow
            ? new MatrixBlock(numSlices, 1, false, numSlices)
            : new MatrixBlock(1, numSlices, false, numSlices);
        out.allocateBlock();

        // Reused buffer: the slice is read once to count its non-zeros (which fixes M and k) and then hashed.
        double[] cells = new double[sliceLength];
        for (int s = 0; s < numSlices; s++) {
            long nonZeros = 0;
            for (int j = 0; j < sliceLength; j++) {
                cells[j] = byRow ? in.getValue(s, j) : in.getValue(j, s);
                if (cells[j] != 0) {
                    nonZeros++;
                }
            }
            long d = nonZeros + 1;
            int m = hashSpaceSize(d);
            int k = sketchSize(d);
            SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
            addHashes(cells, 0, sliceLength, hashType, spq, m);
            long estimate = estimateDistinct(spq, k, m, d);
            if (byRow) {
                out.setValue(s, 0, estimate);
            } else {
                out.setValue(0, s, estimate);
            }
        }
        return out;
    }

    /**
     * Size M of the hash space for a bound of {@code d} distinct values: {@code d * d}, capped at
     * {@code Integer.MAX_VALUE}. The cap is checked before squaring so {@code d * d} cannot overflow.
     */
    private static int hashSpaceSize(long d) {
        if (Math.abs(d) > MAX_UNCAPPED_D) {
            return Integer.MAX_VALUE;
        }
        return (int) (d * d);
    }

    /** Number k of smallest hashes to retain for a bound of {@code d} distinct values: {@code min(d, 64)}. */
    private static int sketchSize(long d) {
        return d > MAX_K ? MAX_K : (int) d;
    }

    /**
     * Maps {@code value} to a hash in {@code [1, m)} (or exactly 1 when {@code m <= 2}).
     *
     * <p>The absolute value is taken after the remainder so that a hash of {@code Integer.MIN_VALUE} cannot produce a
     * negative result, and the modulus is clamped to at least 1 so {@code m == 1} does not divide by zero.
     */
    private static double hashToRange(double value, Hash.HashType hashType, int m) {
        int modulus = Math.max(m - 1, 1);
        return Math.abs(Hash.hash(value, hashType) % modulus) + 1;
    }

    /** Hashes {@code values[offset, offset + length)} into {@code spq}. */
    private static void addHashes(double[] values, int offset, int length, Hash.HashType hashType,
        SmallestPriorityQueue spq, int m) {
        for (int p = offset; p < offset + length; p++) {
            spq.add(hashToRange(values[p], hashType, m));
        }
    }

    /** Collects the {@code k} smallest distinct hashes (hash space {@code m}) of all cells of {@code in}. */
    private SmallestPriorityQueue getKSmallestHashes(MatrixBlock in, int k, int m) {
        SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
        collectHashes(in, this.op.getHashType(), spq, m);
        return spq;
    }

    /**
     * Adds the hash of every cell value of {@code in} to {@code spq}, including the value zero when present.
     *
     * <p>Only the logical cells are visited (dense rows of {@code numColumns} values, sparse rows of {@code size(r)}
     * values), so spare capacity in the backing arrays is never hashed.
     *
     * @throws NotImplementedException if {@code in} is a non-empty compressed block
     */
    private void collectHashes(MatrixBlock in, Hash.HashType hashType, SmallestPriorityQueue spq, int m) {
        if (in.isEmpty()) {
            // Only zeros: the single distinct value is 0.
            spq.add(hashToRange(0, hashType, m));
            return;
        }
        if (in instanceof CompressedMatrixBlock) {
            throw new NotImplementedException("Cannot approximate distinct count for compressed matrices");
        }
        final int numRows = in.getNumRows();
        SparseBlock sparseBlock = in.getSparseBlock();
        if (sparseBlock == null) {
            DenseBlock denseBlock = in.getDenseBlock();
            final int numCols = in.getNumColumns();
            for (int r = 0; r < numRows; r++) {
                addHashes(denseBlock.values(r), denseBlock.pos(r), numCols, hashType, spq, m);
            }
            return;
        }
        for (int r = 0; r < numRows; r++) {
            if (!sparseBlock.isEmpty(r)) {
                addHashes(sparseBlock.values(r), sparseBlock.pos(r), sparseBlock.size(r), hashType, spq, m);
            }
        }
        // Sparse storage omits zeros; count the value 0 once if the block contains any.
        if (in.getNonZeros() < in.getLength()) {
            spq.add(hashToRange(0, hashType, m));
        }
    }

    /**
     * Turns a filled KMV queue into a distinct-count estimate.
     *
     * <p>With fewer than {@code k} retained hashes, or with {@code k < 2} (where {@code (k - 1) / U_k} is always 0),
     * the number of retained hashes is returned. Otherwise the largest retained hash is polled from {@code spq} (the
     * queue is modified) and {@code round(min((k - 1) / U_k, d))} is returned.
     */
    private long estimateDistinct(SmallestPriorityQueue spq, int k, int m, long d) {
        if (spq.size() < k || k < 2) {
            return spq.size();
        }
        double uK = spq.poll() / m;
        double estimate = (k - 1) / uK;
        if (LOG.isDebugEnabled()) {
            LOG.debug("U_k : " + uK);
            LOG.debug("Estimate: " + estimate);
            LOG.debug("Ceil worst case: " + d);
        }
        return Math.round(Math.min(estimate, d));
    }

    /**
     * Evaluates a sketch created by {@link #create(MatrixBlock)} or {@link #union(CorrMatrixBlock, CorrMatrixBlock)}.
     *
     * @return an n x 1 block for row direction, a 1 x n block for column direction, otherwise a 1x1 block
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public MatrixBlock getValueFromSketch(CorrMatrixBlock corrMatrixBlock) {
        MatrixBlock sketchValues = corrMatrixBlock.getValue();
        final int numResults;
        final MatrixBlock result;
        if (this.op.getDirection().isRow()) {
            numResults = sketchValues.getNumRows();
            result = new MatrixBlock(numResults, 1, false, numResults);
        } else if (this.op.getDirection().isCol()) {
            numResults = sketchValues.getNumColumns();
            result = new MatrixBlock(1, numResults, false, numResults);
        } else {
            numResults = 1;
            result = new MatrixBlock(1, 1, false, 1L);
        }
        result.allocateBlock();
        for (int idx = 0; idx < numResults; idx++) {
            getDistinctCountFromSketchByIndex(corrMatrixBlock, idx, result);
        }
        return result;
    }

    /**
     * Computes the estimate for sketch slice {@code i} and writes it into {@code matrixBlock} at {@code (i, 0)} for
     * row/row-col direction or {@code (0, i)} for column direction. The estimate is not rounded.
     *
     * @throws IllegalArgumentException if the operator is not of type KMV
     */
    private void getDistinctCountFromSketchByIndex(CorrMatrixBlock corrMatrixBlock, int i, MatrixBlock matrixBlock) {
        if (this.op.getOperatorType() != CountDistinctOperatorTypes.KMV) {
            throw new IllegalArgumentException(getClass().getSimpleName() + " cannot use " + this.op.getOperatorType());
        }
        MatrixBlock sketchValues = corrMatrixBlock.getValue();
        MatrixBlock metadata = corrMatrixBlock.getCorrection();
        boolean rowOriented = this.op.getDirection().isRow() || this.op.getDirection().isRowCol();

        // Hashes are stored largest-first, so position 0 of the slice holds the k-th smallest hash.
        double kthSmallest = rowOriented ? sketchValues.getValue(i, 0) : sketchValues.getValue(0, i);
        double size = metadata.getValue(i, 0);
        double k = metadata.getValue(i, 1);
        double d = metadata.getValue(i, 2);

        double estimate;
        if (size == 0) {
            // create() records size 0 for empty or single-cell slices, which hold exactly one distinct value.
            estimate = 1.0d;
        } else if (size >= k) {
            double m = Math.min(d * d, Integer.MAX_VALUE);
            estimate = Math.min((k - 1.0d) / (kthSmallest / m), d);
        } else {
            estimate = size;
        }

        if (rowOriented) {
            matrixBlock.setValue(i, 0, estimate);
        } else {
            matrixBlock.setValue(0, i, estimate);
        }
    }

    /**
     * Builds a KMV sketch of {@code matrixBlock} in the operator's direction.
     *
     * <p>The hashes are written into a copy of the input, so {@code matrixBlock} itself is not modified.
     *
     * @throws DMLRuntimeException for a direction that is neither row-col, row nor column
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public CorrMatrixBlock create(MatrixBlock matrixBlock) {
        if (this.op.getDirection().isRowCol()) {
            MatrixBlock sketchValues = new MatrixBlock(matrixBlock);
            MatrixBlock metadata = new MatrixBlock(1, 3, false);
            createSketchByIndex(matrixBlock, metadata, 0, sketchValues);
            return new CorrMatrixBlock(sketchValues, metadata);
        }
        final int numSlices;
        if (this.op.getDirection().isRow()) {
            numSlices = matrixBlock.getNumRows();
        } else if (this.op.getDirection().isCol()) {
            numSlices = matrixBlock.getNumColumns();
        } else {
            throw new DMLRuntimeException(String.format("Unexpected direction: %s", this.op.getDirection()));
        }
        MatrixBlock sketchValues = new MatrixBlock(matrixBlock);
        MatrixBlock metadata = new MatrixBlock(numSlices, 3, false);
        for (int idx = 0; idx < numSlices; idx++) {
            createSketchByIndex(matrixBlock, metadata, idx, sketchValues);
        }
        return new CorrMatrixBlock(sketchValues, metadata);
    }

    /** Returns row {@code i} (row direction), column {@code i} (column direction) or the whole block otherwise. */
    private MatrixBlock sliceMatrixBlockByIndexDirection(MatrixBlock matrixBlock, int i) {
        if (this.op.getDirection().isRow()) {
            return matrixBlock.slice2(i, i);
        }
        if (this.op.getDirection().isCol()) {
            return matrixBlock.slice2(0, matrixBlock.getNumRows() - 1, i, i);
        }
        return matrixBlock;
    }

    /**
     * Sketches slice {@code idx} of {@code in}.
     *
     * <p>The retained hashes are written largest-first into {@code sketchValues}: along row {@code idx} for
     * row/row-col direction, or along column {@code idx} for column direction. Row {@code idx} of {@code metadata} is
     * set to {@code [number of hashes, k, D]}. For row-col direction, {@code sketchValues} is first reset to
     * 1 x k. Empty or single-cell slices store no hashes and record 0 as their hash count.
     */
    private void createSketchByIndex(MatrixBlock in, MatrixBlock metadata, int idx, MatrixBlock sketchValues) {
        MatrixBlock slice = sliceMatrixBlockByIndexDirection(in, idx);
        long d = slice.getNonZeros() + 1;
        int m = hashSpaceSize(d);
        int k = sketchSize(d);
        if (this.op.getDirection().isRowCol()) {
            sketchValues.reset(1, k);
        }
        if (slice.getLength() == 1 || slice.isEmpty()) {
            metadata.setValue(idx, 0, 0);
            metadata.setValue(idx, 1, k);
            metadata.setValue(idx, 2, d);
            return;
        }
        SmallestPriorityQueue spq = getKSmallestHashes(slice, k, m);
        int size = spq.size();
        assert size > 0;

        boolean columnOriented = this.op.getDirection().isCol();
        for (int pos = 0; !spq.isEmpty(); pos++) {
            double hash = spq.poll();
            if (columnOriented) {
                sketchValues.setValue(pos, idx, hash);
            } else {
                sketchValues.setValue(idx, pos, hash);
            }
        }
        metadata.setValue(idx, 0, size);
        metadata.setValue(idx, 1, k);
        metadata.setValue(idx, 2, d);
    }

    /**
     * Merges two sketches slice by slice.
     *
     * <p>NOTE(review): the merged sketch reuses, and overwrites, the value block of whichever input has more
     * columns (row/row-col) or more rows (column direction).
     */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public CorrMatrixBlock union(CorrMatrixBlock corrMatrixBlock, CorrMatrixBlock corrMatrixBlock2) {
        MatrixBlock left = corrMatrixBlock.getValue();
        MatrixBlock right = corrMatrixBlock2.getValue();
        if (this.op.getDirection().isRow()) {
            MatrixBlock target = left.getNumColumns() > right.getNumColumns() ? left : right;
            CorrMatrixBlock merged = new CorrMatrixBlock(target, new MatrixBlock(left.getNumRows(), 3, false));
            for (int r = 0; r < left.getNumRows(); r++) {
                unionSketchByIndex(corrMatrixBlock, corrMatrixBlock2, r, merged);
            }
            return merged;
        }
        if (this.op.getDirection().isCol()) {
            MatrixBlock target = left.getNumRows() > right.getNumRows() ? left : right;
            CorrMatrixBlock merged = new CorrMatrixBlock(target, new MatrixBlock(left.getNumColumns(), 3, false));
            for (int c = 0; c < left.getNumColumns(); c++) {
                unionSketchByIndex(corrMatrixBlock, corrMatrixBlock2, c, merged);
            }
            return merged;
        }
        MatrixBlock target = left.getNumColumns() > right.getNumColumns() ? left : right;
        CorrMatrixBlock merged = new CorrMatrixBlock(target, new MatrixBlock(1, 3, false));
        unionSketchByIndex(corrMatrixBlock, corrMatrixBlock2, 0, merged);
        return merged;
    }

    /**
     * Merges slice {@code i} of two sketches into slice {@code i} of {@code corrMatrixBlock3}.
     *
     * <p>The merged metadata is {@code [max hash count, max k, D1 + D2 - 1]}.
     *
     * @throws DMLRuntimeException if the sketches' rows (row direction) or columns (column direction) differ
     */
    public void unionSketchByIndex(CorrMatrixBlock corrMatrixBlock, CorrMatrixBlock corrMatrixBlock2, int i,
        CorrMatrixBlock corrMatrixBlock3) {
        MatrixBlock leftMeta = corrMatrixBlock.getCorrection();
        MatrixBlock rightMeta = corrMatrixBlock2.getCorrection();
        validateSketchMetadata(leftMeta);
        validateSketchMetadata(rightMeta);

        MatrixBlock left = corrMatrixBlock.getValue();
        MatrixBlock right = corrMatrixBlock2.getValue();
        boolean rowsMisaligned = this.op.getDirection().isRow() && left.getNumRows() != right.getNumRows();
        boolean colsMisaligned = this.op.getDirection().isCol() && left.getNumColumns() != right.getNumColumns();
        if (rowsMisaligned || colsMisaligned) {
            throw new DMLRuntimeException("Cannot take the union of sketches: rows/columns are not aligned");
        }

        MatrixBlock outValues = corrMatrixBlock3.getValue();
        MatrixBlock outMeta = corrMatrixBlock3.getCorrection();
        double leftSize = leftMeta.getValue(i, 0);
        double leftK = leftMeta.getValue(i, 1);
        double leftD = leftMeta.getValue(i, 2);
        double rightSize = rightMeta.getValue(i, 0);
        double rightK = rightMeta.getValue(i, 1);
        double rightD = rightMeta.getValue(i, 2);

        double mergedSize = Math.max(leftSize, rightSize);
        double mergedK = Math.max(leftK, rightK);
        // Each D carries the +1 for zero; count it only once.
        double mergedD = (leftD + rightD) - 1.0d;

        // NOTE(review): the capacity is the larger stored hash count, not the larger k, so two small sketches with
        // disjoint values under-count after merging. The inputs' hashes were also taken modulo different M values.
        // Both need confirmation before the merged estimate can be trusted.
        SmallestPriorityQueue spq = new SmallestPriorityQueue((int) mergedSize);
        boolean rowOriented = this.op.getDirection().isRow() || this.op.getDirection().isRowCol();
        for (int p = 0; p < leftSize; p++) {
            spq.add(rowOriented ? left.getValue(i, p) : left.getValue(p, i));
        }
        for (int p = 0; p < rightSize; p++) {
            spq.add(rowOriented ? right.getValue(i, p) : right.getValue(p, i));
        }

        for (int pos = 0; !spq.isEmpty(); pos++) {
            double hash = spq.poll();
            if (rowOriented) {
                outValues.setValue(i, pos, hash);
            } else {
                outValues.setValue(pos, i, hash);
            }
        }
        outMeta.setValue(i, 0, mergedSize);
        outMeta.setValue(i, 1, mergedK);
        outMeta.setValue(i, 2, mergedD);
    }

    /** Not supported for KMV sketches. */
    @Override // org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch
    public CorrMatrixBlock intersection(CorrMatrixBlock corrMatrixBlock, CorrMatrixBlock corrMatrixBlock2) {
        throw new NotImplementedException(String.format("%s intersection has not been implemented yet", KMVSketch.class.getSimpleName()));
    }
}
