package org.apache.sysds.runtime.compress.estim;

import java.util.Arrays;
import java.util.Random;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorSample.class */
/**
 * Compressed size estimator that extracts statistics from a row sample of the input and scales them up to the
 * full number of rows.
 *
 * <p>
 * The sample is materialized once in the constructor. Two of the three sampling paths return the sample
 * transposed, which is why the orientation of the sample ({@code _transposed}) is tracked separately from the
 * orientation of the input ({@code _cs.transposed}).
 */
public class CompressedSizeEstimatorSample extends CompressedSizeEstimator {
    /** The sampled rows of the input; its orientation is described by {@link #_transposed}. */
    private final MatrixBlock _sample;
    /** Parallelization degree forwarded to the sample sort and to the in-place transpose. */
    private final int _k;
    /** Number of rows requested for the sample. */
    private final int _sampleSize;
    /** Orientation of {@link #_sample}; starts as {@code _cs.transposed} and is set to true by some sampling paths. */
    private boolean _transposed;

    /**
     * Creates the estimator and immediately draws the sample.
     *
     * @param data       the matrix to estimate compressed sizes for
     * @param cs         the compression settings (orientation, seed, estimation type, valid compressions)
     * @param sampleSize the number of rows to sample
     * @param k          the parallelization degree
     */
    public CompressedSizeEstimatorSample(MatrixBlock data, CompressionSettings cs, int sampleSize, int k) {
        super(data, cs);
        this._k = k;
        this._sampleSize = sampleSize;
        // Must be assigned before sampling: the sampling paths overwrite it when they emit a transposed sample.
        this._transposed = this._cs.transposed;
        final Timing time = LOG.isDebugEnabled() ? new Timing(true) : null;
        this._sample = sampleData(sampleSize);
        if (time != null) {
            LOG.debug("Sampling time: " + time.stop());
        }
    }

    /**
     * Estimates the compression information of a column group from the sample.
     *
     * @param colIndexes         the columns of the group
     * @param estimate           not used by this implementation
     * @param nrUniqueUpperBound upper bound on the number of distinct tuples in the group
     * @return the estimated information; an empty-group info if the columns are known to be empty
     */
    @Override
    public CompressedSizeInfoColGroup getColGroupInfo(int[] colIndexes, int estimate, int nrUniqueUpperBound) {
        if (isKnownEmpty(colIndexes)) {
            return new CompressedSizeInfoColGroup(colIndexes, getNumRows());
        }
        final IEncode map = IEncode.createFromMatrixBlock(this._sample, this._transposed, colIndexes);
        return extractInfo(map, colIndexes, nrUniqueUpperBound);
    }

    /** Returns true if the input as a whole, or the single requested column, is known to contain no values. */
    private boolean isKnownEmpty(int[] colIndexes) {
        if (this._data.isEmpty()) {
            return true;
        }
        if (colIndexes.length != 1) {
            return false;
        }
        final int col = colIndexes[0];
        if (this.nnzCols != null && this.nnzCols[col] == 0) {
            return true;
        }
        return this._cs.transposed && this._data.isInSparseFormat() && this._data.getSparseBlock().isEmpty(col);
    }

    /**
     * Estimates the compression information of a delta encoded column group. This reads the input matrix directly
     * rather than the sample.
     *
     * @param colIndexes         the columns of the group
     * @param estimate           not used by this implementation
     * @param nrUniqueUpperBound upper bound on the number of distinct tuples in the group
     * @return the estimated information
     */
    @Override
    public CompressedSizeInfoColGroup getDeltaColGroupInfo(int[] colIndexes, int estimate, int nrUniqueUpperBound) {
        // The input is read here, so its own orientation applies; _transposed describes the sample only and is
        // flipped by the sampling paths even when the input is not transposed.
        final IEncode map = IEncode.createFromMatrixBlockDelta(this._data, this._cs.transposed, colIndexes,
            this._sampleSize);
        return extractInfo(map, colIndexes, nrUniqueUpperBound);
    }

    /**
     * Upper bound on the number of distinct tuples of a column group.
     *
     * @param columns the columns of the group
     * @return the number of rows, further limited by the number of non zeros if the group spans all columns
     */
    @Override
    protected int worstCaseUpperBound(int[] columns) {
        final int nRows = getNumRows();
        if (getNumColumns() == columns.length) {
            // Clamp in long before narrowing so that more than Integer.MAX_VALUE non zeros cannot wrap around.
            return (int) Math.min((long) nRows, this._data.getNonZeros());
        }
        return nRows;
    }

    /**
     * Combines the encodings of two column groups and estimates the information of the joined group.
     *
     * @param combinedColumns the columns of the joined group
     * @param g1              information of the first group
     * @param g2              information of the second group
     * @param maxDistinct     upper bound on the number of distinct tuples in the joined group
     * @return the estimated information of the joined group
     */
    @Override
    protected CompressedSizeInfoColGroup combine(int[] combinedColumns, CompressedSizeInfoColGroup g1,
        CompressedSizeInfoColGroup g2, int maxDistinct) {
        final IEncode map = g1.getMap().combine(g2.getMap());
        return extractInfo(map, combinedColumns, maxDistinct);
    }

    /** Extracts the facts of an encoding, scales them to the full number of rows and wraps them in a group info. */
    private CompressedSizeInfoColGroup extractInfo(IEncode map, int[] colIndexes, int maxDistinct) {
        final double sparsity = this._data.getSparsity();
        final EstimationFactors sampleFacts = map.extractFacts(colIndexes, this._sampleSize, sparsity, sparsity);
        final EstimationFactors scaled = scaleFactors(sampleFacts, colIndexes, maxDistinct, map.isDense());
        return new CompressedSizeInfoColGroup(colIndexes, scaled, this._cs.validCompressions, map);
    }

    /**
     * Scales the facts observed on the sample to the full number of rows.
     *
     * @param sampleFacts the facts extracted from the encoding
     * @param colIndexes  the columns of the group
     * @param maxDistinct upper bound on the number of distinct tuples
     * @param dense       whether the encoding was dense
     * @return the scaled facts
     * @throws DMLCompressionException if the scaled values are rejected by the EstimationFactors constructor
     */
    private EstimationFactors scaleFactors(EstimationFactors sampleFacts, int[] colIndexes, int maxDistinct,
        boolean dense) {
        final int numRows = getNumRows();
        final int nCol = colIndexes.length;
        // Floating point ratio; an integer division would silently truncate the scaling factor.
        final double scalingFactor = (double) numRows / this._sampleSize;
        final long nnz = calculateNNZ(colIndexes, scalingFactor);
        final int nnzClamped = (int) Math.min(nnz, (long) Integer.MAX_VALUE);
        final int numOffs = calculateOffs(sampleFacts, numRows, scalingFactor, colIndexes, nnzClamped);
        final int estDistinct = distinctCountScale(sampleFacts, numOffs, numRows, maxDistinct, dense, nCol);
        final int maxLargestInstanceCount = (numRows - estDistinct) + 1;
        final int scaledLargestInstanceCount;
        if (sampleFacts.largestOff < 0) {
            // estDistinct is 0 only when there are no offsets and no frequencies; avoid dividing by zero then.
            scaledLargestInstanceCount = estDistinct == 0 ? 0 : numOffs / estDistinct;
        } else {
            scaledLargestInstanceCount = (int) Math.floor(sampleFacts.largestOff * scalingFactor);
        }
        final int mostFrequentOffsetCount = Math.max(Math.min(maxLargestInstanceCount, scaledLargestInstanceCount),
            numRows - numOffs);
        final double overallSparsity = calculateSparsity(colIndexes, nnz, scalingFactor,
            sampleFacts.overAllSparsity);
        try {
            return new EstimationFactors(nCol, estDistinct, numOffs, mostFrequentOffsetCount,
                sampleFacts.frequencies, sampleFacts.numSingle, numRows, sampleFacts.lossy,
                sampleFacts.zeroIsMostFrequent, overallSparsity, Math.min(overallSparsity * 1.3d, 1.0d));
        } catch (Exception e) {
            // The per column counts are optional; guard the lookup so the diagnostics never mask the real failure.
            final String firstColNnz = (this.nnzCols != null && nCol > 0)
                ? Integer.toString(this.nnzCols[colIndexes[0]]) : "unknown";
            throw new DMLCompressionException("Invalid construction of estimation factors with observed values:\n"
                + Arrays.toString(colIndexes) + " " + nnz + " " + firstColNnz + "  " + numOffs + "  " + estDistinct
                + "  " + maxLargestInstanceCount + " " + scaledLargestInstanceCount + " " + mostFrequentOffsetCount
                + " " + overallSparsity + "\n" + sampleFacts, e);
        }
    }

    /**
     * Estimates the number of distinct tuples in the full data from the sampled frequencies.
     *
     * @return the estimate, limited to [1, min(maxDistinct, numOffs)], or numOffs if no frequencies are available
     */
    private int distinctCountScale(EstimationFactors sampleFacts, int numOffs, int numRows, int maxDistinct,
        boolean dense, int nCol) {
        final int[] freq = sampleFacts.frequencies;
        if (freq == null || freq.length == 0) {
            return numOffs;
        }
        int est = SampleEstimatorFactory.distinctCount(freq, dense ? numRows : numOffs, sampleFacts.numOffs,
            this._cs.estimationType);
        // Inflate large estimates by 50 percent.
        if (est > 10000) {
            est = (int) (est + (est * 0.5d));
        }
        // Inflate the estimate further for groups of more than 4 columns; computed in double to avoid int overflow.
        if (nCol > 4) {
            est = (int) (est + (((double) est * nCol) / 10.0d));
        }
        return Math.max(Math.min(est, Math.min(maxDistinct, numOffs)), 1);
    }

    /**
     * Estimates the number of rows containing a value (offsets) for the column group in the full data.
     *
     * @param nnz the number of non zeros of the group, already limited to the int range
     */
    private int calculateOffs(EstimationFactors sampleFacts, int numRows, double scalingFactor, int[] colIndexes,
        int nnz) {
        if (getNumColumns() == 1) {
            return nnz;
        }
        // Scale the sampled rows without offsets up to the full row count and subtract them.
        final int scaledOffs = numRows
            - ((int) Math.floor((sampleFacts.numRows - sampleFacts.numOffs) * scalingFactor));
        if (this.nnzCols == null) {
            return scaledOffs;
        }
        if (colIndexes.length == 1) {
            return this.nnzCols[colIndexes[0]];
        }
        return Math.min(nnz, scaledOffs);
    }

    /**
     * Estimates the sparsity of the column group.
     *
     * @param nnz            the (estimated) number of non zeros in the group
     * @param scalingFactor  currently unused
     * @param sampleSparsity the sparsity observed on the sample, used as the last fallback
     */
    private double calculateSparsity(int[] colIndexes, long nnz, double scalingFactor, double sampleSparsity) {
        if (colIndexes.length == getNumColumns()) {
            return this._data.getSparsity();
        }
        final boolean nnzFromCounts = this.nnzCols != null
            || (this._cs.transposed && this._data.isInSparseFormat())
            || (this._transposed && this._sample.isInSparseFormat());
        if (nnzFromCounts) {
            // Divide in double: an integer division truncates to 0 or 1 and the int product can overflow.
            return (double) nnz / ((double) getNumRows() * colIndexes.length);
        }
        return this._sample.isEmpty() ? this._data.getSparsity() : sampleSparsity;
    }

    /**
     * Counts or estimates the number of non zero values in the given columns. Sources are tried in this order:
     * the input's total count, the rows of a transposed sparse input, the per column counts, and the sample.
     *
     * @param colIndexes    the columns of the group
     * @param scalingFactor the ratio between the number of rows in the input and in the sample
     * @return the number of non zeros
     */
    private long calculateNNZ(int[] colIndexes, double scalingFactor) {
        if (colIndexes.length == getNumColumns()) {
            return this._data.getNonZeros();
        }
        if (this._cs.transposed && this._data.isInSparseFormat()) {
            // Rows of the transposed input are columns, so index by column id and skip rows without values.
            final SparseBlock sb = this._data.getSparseBlock();
            long count = 0;
            for (int col : colIndexes) {
                if (!sb.isEmpty(col)) {
                    count += sb.get(col).size();
                }
            }
            return count;
        }
        if (this.nnzCols != null) {
            long count = 0;
            for (int col : colIndexes) {
                count += this.nnzCols[col];
            }
            return count;
        }
        if (this._sample.isEmpty()) {
            return 0L;
        }
        if (!this._transposed || !this._sample.isInSparseFormat()) {
            // NOTE(review): this is the unscaled count of the entire sample, not of the requested columns.
            return this._sample.getNonZeros();
        }
        // Rows of the transposed sample are columns as well; scale each observed row size to the full data.
        final SparseBlock sb = this._sample.getSparseBlock();
        long count = 0;
        for (int col : colIndexes) {
            if (!sb.isEmpty(col)) {
                count = (long) (count + (sb.get(col).size() * scalingFactor));
            }
        }
        if (count == 0) {
            // Nothing observed in the sample: fall back to one value per column.
            count += colIndexes.length;
        }
        return count;
    }

    /**
     * Draws a sorted sample of row indexes via reservoir sampling.
     *
     * @param range      the exclusive upper bound of the indexes to sample from
     * @param sampleSize the number of indexes to draw; NOTE(review): values above range yield indexes outside
     *                   the range, and 0 with a positive range throws from Random.nextInt
     * @param seed       the seed of the random generator
     * @param k          the parallelization degree; a parallel sort is used if larger than 1
     * @return the sampled indexes in ascending order
     */
    public static int[] getSortedSample(int range, int sampleSize, long seed, int k) {
        final int[] sample = new int[sampleSize];
        final Random rng = new Random(seed);
        // Fill the reservoir with the first indexes, then replace entries at random while scanning the rest.
        for (int i = 0; i < sampleSize; i++) {
            sample[i] = i;
        }
        for (int i = sampleSize; i < range; i++) {
            if (rng.nextInt(i) < sampleSize) {
                sample[rng.nextInt(sampleSize)] = i;
            }
        }
        // No shuffle: the result is sorted below and the generator is local, so shuffling cannot affect the output.
        if (k > 1) {
            Arrays.parallelSort(sample);
        } else {
            Arrays.sort(sample);
        }
        return sample;
    }

    /** Draws the row sample and picks the extraction path matching the layout of the input. */
    private MatrixBlock sampleData(int sampleSize) {
        final int[] sampleRows = getSortedSample(getNumRows(), sampleSize, this._cs.seed, this._k);
        if (this._cs.transposed) {
            return defaultSlowSamplingPath(sampleRows);
        }
        if (this._data.isInSparseFormat()) {
            return sparseNotTransposedSamplePath(sampleRows);
        }
        return denseSamplePath(sampleRows);
    }

    /**
     * Samples a sparse, not transposed input by collecting the selected sparse rows and transposing the result.
     * Marks the sample as transposed.
     */
    private MatrixBlock sparseNotTransposedSamplePath(int[] sampleRows) {
        final MatrixBlock res = new MatrixBlock(sampleRows.length, this._data.getNumColumns(), true);
        final SparseRow[] rows = new SparseRow[sampleRows.length];
        final SparseBlock in = this._data.getSparseBlock();
        for (int i = 0; i < sampleRows.length; i++) {
            rows[i] = in.get(sampleRows[i]);
        }
        res.setSparseBlock(new SparseBlockMCSR(rows, false));
        res.recomputeNonZeros();
        this._transposed = true;
        return LibMatrixReorg.transposeInPlace(res, this._k);
    }

    /**
     * Generic sampling path: builds a selection vector with a 1 at every sampled index and lets
     * removeEmptyOperations extract the selected rows (or columns, if the input is transposed).
     */
    private MatrixBlock defaultSlowSamplingPath(int[] sampleRows) {
        final int selectLength = this._cs.transposed ? this._data.getNumColumns() : this._data.getNumRows();
        final MatrixBlock select = new MatrixBlock(selectLength, 1, false);
        for (int row : sampleRows) {
            select.appendValue(row, 0, 1.0d);
        }
        return this._data.removeEmptyOperations(new MatrixBlock(), !this._cs.transposed, true, select);
    }

    /**
     * Samples a dense input directly into a transposed sample, using a sparse or dense output depending on the
     * expected sparsity. Marks the sample as transposed.
     *
     * @throws NotImplementedException if the allocated sparse block is not MCSR
     */
    private MatrixBlock denseSamplePath(int[] sampleRows) {
        final int sampleSize = sampleRows.length;
        // Floating point ratio; an integer division truncates and fails with a division by zero if the sample is
        // larger than the sampled dimension.
        final double sampleRatio = this._cs.transposed ? (double) this._data.getNumColumns() / sampleSize
            : (double) this._data.getNumRows() / sampleSize;
        final long sampleNNZ = (long) Math.ceil(this._data.getNonZeros() / sampleRatio);
        final int resRows = this._cs.transposed ? this._data.getNumRows() : this._data.getNumColumns();
        final boolean sparseOut = 0.4d > ((double) sampleNNZ) / ((double) (((long) sampleSize) * ((long) resRows)));
        final MatrixBlock res = new MatrixBlock(resRows, sampleSize, sparseOut);
        res.allocateBlock();
        final DenseBlock in = this._data.getDenseBlock();
        if (res.isInSparseFormat()) {
            final SparseBlock sb = res.getSparseBlock();
            if (!(sb instanceof SparseBlockMCSR)) {
                throw new NotImplementedException(
                    "Not Implemented support for dense sample into sparse: " + sb.getClass().getSimpleName());
            }
            final SparseBlockMCSR out = (SparseBlockMCSR) sb;
            final int rowCapacity = (int) Math.max(4.0d, Math.ceil((double) sampleNNZ / sampleSize));
            for (int r = 0; r < resRows; r++) {
                out.allocate(r, rowCapacity);
            }
            for (int c = 0; c < sampleSize; c++) {
                final int row = sampleRows[c];
                final double[] inValues = in.values(row);
                final int inOff = in.pos(row);
                for (int r = 0; r < resRows; r++) {
                    out.get(r).append(c, inValues[inOff + r]);
                }
            }
        } else {
            final DenseBlock out = res.getDenseBlock();
            for (int c = 0; c < sampleSize; c++) {
                final int row = sampleRows[c];
                final double[] inValues = in.values(row);
                final int inOff = in.pos(row);
                for (int r = 0; r < resRows; r++) {
                    // Ask the block for the row offset instead of assuming a single contiguous array.
                    out.values(r)[out.pos(r) + c] = inValues[inOff + r];
                }
            }
        }
        // NOTE(review): this stores the estimated count, not the number of non zeros actually copied.
        res.setNonZeros(sampleNNZ);
        this._transposed = true;
        return res;
    }

    @Override
    public String toString() {
        return super.toString() + " sampleSize: " + this._sampleSize + " transposed: " + this._transposed;
    }
}
