package org.apache.sysds.runtime.compress.estim;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Random;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.bitmap.ABitmap;
import org.apache.sysds.runtime.compress.bitmap.BitmapEncoder;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorSample.class */
public class CompressedSizeEstimatorSample extends CompressedSizeEstimator {
    private final MatrixBlock _sample;
    private final HashMap<Integer, Double> _solveCache;
    private final int _k;
    private final int _sampleSize;
    private boolean _transposed;

    /* JADX INFO: Access modifiers changed from: protected */
    public CompressedSizeEstimatorSample(MatrixBlock matrixBlock, CompressionSettings compressionSettings, int i, int i2) {
        super(matrixBlock, compressionSettings);
        this._k = i2;
        this._sampleSize = i;
        this._transposed = this._cs.transposed;
        if (LOG.isDebugEnabled()) {
            Timing timing = new Timing(true);
            this._sample = sampleData(i);
            LOG.debug("Sampling time: " + timing.stop());
        } else {
            this._sample = sampleData(i);
        }
        this._solveCache = new HashMap<>();
    }

    /**
     * Gives access to the sample that was drawn when this estimator was constructed.
     *
     * @return the sampled matrix, or {@code null} when the drawn sample was empty
     */
    public MatrixBlock getSample() {
        return _sample;
    }

    /**
     * Reports the sample size this estimator was configured with.
     *
     * @return the number of rows requested for the sample
     */
    public final int getSampleSize() {
        return _sampleSize;
    }

    /**
     * Estimates the compressed size information of one column group from the sample.
     *
     * <p>A bitmap is extracted from the sample for the given columns, sample-level factors and a mapping are
     * derived from it, and the factors are then extrapolated by {@code estimateCompressionFactors}.
     *
     * @param iArr column indexes of the group
     * @param i    estimate forwarded to the bitmap extraction
     * @param i2   upper bound applied to the estimated number of distinct values
     * @return size information for the column group
     */
    @Override
    public CompressedSizeInfoColGroup estimateCompressedColGroupSize(int[] iArr, int i, int i2) {
        final ABitmap sampleBitmap = BitmapEncoder.extractBitmap(iArr, _sample, _transposed, i);
        final EstimationFactors sampleFactors =
            EstimationFactors.computeSizeEstimationFactors(sampleBitmap, _sampleSize, false, iArr);
        final AMapToData sampleMap = MapToFactory.create(_sampleSize, sampleBitmap);
        final EstimationFactors scaledFactors = estimateCompressionFactors(sampleFactors, sampleMap, iArr, i2);
        return new CompressedSizeInfoColGroup(iArr, scaledFactors, _cs.validCompressions, sampleMap);
    }

    /**
     * Returns an upper bound for the given column group: the number of rows, additionally capped by the number
     * of non-zero values of the input when the group spans every column.
     *
     * @param iArr column indexes of the group
     * @return the upper bound
     */
    @Override
    protected int worstCaseUpperBound(int[] iArr) {
        final int rowCount = getNumRows();
        if (getNumColumns() != iArr.length) {
            return rowCount;
        }
        // Take the minimum in long arithmetic: narrowing the non-zero count to int first would wrap for
        // counts above Integer.MAX_VALUE and could produce a negative bound. The minimum never exceeds
        // rowCount, so the final narrowing cannot overflow.
        return (int) Math.min((long) rowCount, _data.getNonZeros());
    }

    /**
     * Estimates the size information of the column group obtained by joining two already analysed groups.
     *
     * @param iArr                       column indexes of the joined group
     * @param compressedSizeInfoColGroup  size information of the first group
     * @param compressedSizeInfoColGroup2 size information of the second group
     * @param i                           upper bound applied to the estimated number of distinct values
     * @return size information of the joined group, or {@code null} when the product of the two groups'
     *         value counts exceeds {@code OptimizerUtils.MAX_NUMCELLS_CP_DENSE}
     */
    @Override
    protected CompressedSizeInfoColGroup estimateJoinCompressedSize(int[] iArr, CompressedSizeInfoColGroup compressedSizeInfoColGroup, CompressedSizeInfoColGroup compressedSizeInfoColGroup2, int i) {
        // Multiply in long arithmetic: an int product wraps around before the comparison, so the guard
        // would miss exactly the large joins it is meant to reject.
        final long combinations = (long) compressedSizeInfoColGroup.getNumVals() * compressedSizeInfoColGroup2.getNumVals();
        if (combinations > OptimizerUtils.MAX_NUMCELLS_CP_DENSE) {
            return null;
        }
        final AMapToData joinedMap = MapToFactory.join(compressedSizeInfoColGroup.getMap(), compressedSizeInfoColGroup2.getMap());
        final boolean rleAllowed = _cs.validCompressions.contains(AColGroup.CompressionType.RLE);
        final EstimationFactors joinedFactors =
            EstimationFactors.computeSizeEstimation(iArr, joinedMap, rleAllowed, joinedMap.size(), false);
        final EstimationFactors scaledFactors = estimateCompressionFactors(joinedFactors, joinedMap, iArr, i);
        return new CompressedSizeInfoColGroup(iArr, scaledFactors, _cs.validCompressions, joinedMap);
    }

    /**
     * Extrapolates factors measured on the sample to the full number of rows.
     *
     * <p>When either the factors or the map is missing, fixed fallback factors are returned: one set for an
     * empty input and one for a non-empty input.
     *
     * @param estimationFactors factors computed on the sample, may be {@code null}
     * @param aMapToData        mapping built from the sample, may be {@code null}
     * @param iArr              column indexes of the group
     * @param i                 upper bound applied to the estimated number of distinct values
     * @return factors scaled to the full input
     */
    private EstimationFactors estimateCompressionFactors(EstimationFactors estimationFactors, AMapToData aMapToData, int[] iArr, int i) {
        final int numRows = getNumRows();
        final int numCols = iArr.length;
        if (aMapToData == null || estimationFactors == null) {
            if (_data.isEmpty()) {
                return new EstimationFactors(numCols, 0, 0, numRows, null, 0, 0, numRows, false, true, DataExpression.DEFAULT_DELIM_FILL_VALUE, DataExpression.DEFAULT_DELIM_FILL_VALUE);
            }
            return new EstimationFactors(numCols, 1, 1, numRows - 1, null, 2, 1, numRows, false, true, 1.0d / numRows, 1.0d / numCols);
        }

        // Floating-point division is required here: with two int operands the ratio is truncated
        // (e.g. 150 rows / 100 samples would scale by 1 instead of 1.5).
        final double scalingFactor = (double) numRows / _sampleSize;

        final int sampleRowsWithoutOffset = estimationFactors.numRows - estimationFactors.numOffs;
        final int scaledOffs = calculateOffs(estimationFactors, _sampleSize, numRows, scalingFactor, sampleRowsWithoutOffset);
        final int distinctCount = getEstimatedDistinctCount(estimationFactors.frequencies, i, scaledOffs, estimationFactors.numOffs);
        final int scaledLargestOff = Math.min((numRows - distinctCount) + 1, (int) Math.floor(estimationFactors.largestOff * scalingFactor));
        final int runCount = getNumRuns(aMapToData, estimationFactors.numVals, _sampleSize, numRows);
        final double sparsity = calculateSparsity(iArr, scalingFactor, estimationFactors.overAllSparsity);

        return new EstimationFactors(numCols, distinctCount, scaledOffs, scaledLargestOff, estimationFactors.frequencies, runCount, estimationFactors.numSingle, numRows, estimationFactors.lossy, estimationFactors.zeroIsMostFrequent, sparsity, estimationFactors.tupleSparsity);
    }

    /**
     * Scales the number of offsets observed in the sample up to the full number of rows.
     *
     * <p>For a single-column input the non-zero count of the input is returned directly.
     *
     * @param estimationFactors factors computed on the sample
     * @param i                 sample size
     * @param i2                total number of rows
     * @param d                 scaling factor from sample to full input
     * @param i3                number of sample rows that have no offset
     * @return the scaled number of offsets
     */
    private int calculateOffs(EstimationFactors estimationFactors, int i, int i2, double d, int i3) {
        if (getNumColumns() == 1) {
            return (int) _data.getNonZeros();
        }
        // Both ratios must be computed in floating point. With int operands, i / i2 is 0 whenever the
        // sample is smaller than the input, and numSingle / i is 0 unless every sampled value is a
        // singleton, so the correction would collapse to 0 or 1.
        final double nonSingletonRatio = 1.0d - (double) estimationFactors.numSingle / i;
        final double sampleFraction = (double) i / i2;
        final double correction = Math.max(nonSingletonRatio, sampleFraction);
        return (int) Math.ceil(i2 - d * correction * i3);
    }

    /**
     * Determines the sparsity to report for a column group.
     *
     * <ul>
     * <li>If the group spans every column, the sparsity of the input is returned.</li>
     * <li>If the input is transposed and sparse, the non-zeros of the group's rows in the input are counted.</li>
     * <li>If the sample is transposed and sparse, the non-zeros of the group's rows in the sample are counted
     * and multiplied by the scaling factor.</li>
     * <li>Otherwise the value measured on the sample is returned unchanged.</li>
     * </ul>
     *
     * @param iArr column indexes of the group
     * @param d    scaling factor from sample to full input
     * @param d2   sparsity measured on the sample, used as fallback
     * @return the sparsity for the group
     */
    private double calculateSparsity(int[] iArr, double d, double d2) {
        if (iArr.length == getNumColumns()) {
            return _data.getSparsity();
        }
        // Cell count as double: rows * columns as an int product can overflow for large inputs.
        final double cellCount = (double) getNumRows() * iArr.length;
        if (_cs.transposed && _data.isInSparseFormat()) {
            return sumRowNonZeros(_data.getSparseBlock(), iArr) / cellCount;
        }
        // The sample is null when the drawn sample was empty; fall through to the sample value then.
        if (_transposed && _sample != null && _sample.isInSparseFormat()) {
            return sumRowNonZeros(_sample.getSparseBlock(), iArr) * d / cellCount;
        }
        return d2;
    }

    /**
     * Sums the sizes of the given rows of a sparse block.
     *
     * <p>In a transposed block each row holds one column of the logical matrix, so the rows are addressed by
     * the column ids themselves rather than by their position in the index array.
     *
     * @param block sparse block to read from
     * @param rows  row ids to visit
     * @return total size of the visited rows; rows handed back as {@code null} count as zero (defensive)
     */
    private static long sumRowNonZeros(SparseBlock block, int[] rows) {
        long total = 0;
        for (int row : rows) {
            final SparseRow sparseRow = block.get(row);
            if (sparseRow != null) {
                total += sparseRow.size();
            }
        }
        return total;
    }

    /**
     * Estimates the number of distinct values via the configured sample estimator and caps the result.
     *
     * @param iArr frequencies forwarded to the estimator
     * @param i    upper bound for the returned count
     * @param i2   first size argument forwarded to the estimator
     * @param i3   second size argument forwarded to the estimator
     * @return the estimated distinct count, never larger than {@code i}
     */
    private int getEstimatedDistinctCount(int[] iArr, int i, int i2, int i3) {
        final int estimate = SampleEstimatorFactory.distinctCount(iArr, i2, i3, _cs.estimationType, _solveCache);
        return Math.min(estimate, i);
    }

    /**
     * Returns the estimated number of runs, which is only computed when RLE is among the valid compressions
     * and the sample contained at least one value; otherwise 0 is returned.
     *
     * @param aMapToData mapping built from the sample
     * @param i          number of values in the sample
     * @param i2         sample size
     * @param i3         total number of rows
     * @return the estimated number of runs, or 0 when no estimate is needed
     */
    private int getNumRuns(AMapToData aMapToData, int i, int i2, int i3) {
        final boolean rleAllowed = _cs.validCompressions.contains(AColGroup.CompressionType.RLE);
        if (rleAllowed && i > 0) {
            return getNumRuns(aMapToData, i2, i3);
        }
        return 0;
    }

    /**
     * Estimates the number of runs over all values of a bitmap by extrapolating from the sampled rows.
     *
     * <p>For every value the ratio of its offsets to the sample size is computed. If that ratio, scaled by
     * {@code i2 / i}, is below 1, the scaled number of offsets is added directly. Otherwise the sampled row
     * ids are walked and per-gap contributions are accumulated.
     *
     * <p>NOTE(review): the parameter roles below are inferred from how the values are used in this method;
     * confirm against the original project source. No caller of this overload is visible in this file.
     *
     * @param aBitmap bitmap extracted from the sample
     * @param i       presumably the sample size (divisor of the offset count, bound when walking {@code iArr})
     * @param i2      presumably the total number of rows ({@code i2 - 1} is compared with the last row id)
     * @param iArr    presumably the sorted ids of the sampled rows
     * @return the rounded estimate, capped at {@code OptimizerUtils.MAX_NUMCELLS_CP_DENSE}
     */
    private static int getNumRuns(ABitmap aBitmap, int i, int i2, int[] iArr) {
        int i3;
        double d;
        boolean z;
        int numValues = aBitmap.getNumValues();
        // Running total of the estimated number of runs.
        double d2 = 0.0d;
        for (int i4 = 0; i4 < numValues; i4++) {
            int[] extractValues = aBitmap.getOffsetsList(i4).extractValues();
            int numOffsets = aBitmap.getNumOffsets(i4);
            // Floating-point ratio: with two int operands this would be 0 unless every sampled row
            // carried the value.
            double d3 = (double) numOffsets / i;
            if ((d3 * i2) / i < 1.0d) {
                // Product taken as double so that offsets * rows cannot overflow int.
                d2 += ((double) numOffsets * i2) / i;
            } else {
                double d4 = 1.0d;
                boolean z2 = false;
                // Contribution of the rows before the first sampled row, if any.
                if (iArr[0] == 0) {
                    i3 = 0;
                } else {
                    int i5 = iArr[0];
                    int i6 = (i5 - (-1)) - 1;
                    double d5 = d3 * i6;
                    d2 += ((i6 - d5) * d5) / i6;
                    i3 = i5;
                    d4 = (i6 - d5) / i6;
                }
                int i7 = 0;
                boolean z3 = false;
                boolean z4 = false;
                // Cursor into the offsets of the current value.
                int i8 = 0;
                // Cursor into the sampled row ids.
                int i9 = 1;
                while (i9 < i) {
                    if (i8 >= numOffsets || extractValues[i8] != i3) {
                        z3 = true;
                        z = false;
                    } else {
                        z4 = true;
                        i8++;
                        z = true;
                    }
                    // Consume sampled rows that directly follow each other.
                    while (true) {
                        if (i3 + 1 != iArr[i9]) {
                            break;
                        }
                        i3 = iArr[i9];
                        if (z3) {
                            if (i8 >= numOffsets || extractValues[i8] != i3) {
                                d2 += i7;
                                i7 = 0;
                                z = false;
                            } else {
                                i7 = 1;
                                i8++;
                                z = true;
                            }
                        } else if (i8 >= numOffsets || extractValues[i8] != i3) {
                            z3 = true;
                            z = false;
                        } else {
                            i8++;
                            z = true;
                        }
                        i9++;
                        if (i9 == i) {
                            z2 = true;
                            break;
                        }
                    }
                    if (z2) {
                        break;
                    }
                    // Contribution of the unsampled gap between the current and the next sampled row.
                    int i10 = iArr[i9];
                    int i11 = (i10 - i3) - 1;
                    double d6 = d3 * i11;
                    d2 += ((i11 - d6) * d6) / i11;
                    double d7 = (i11 - d6) / i11;
                    if (z3) {
                        if (z4) {
                            d2 += d4;
                        }
                        if (z) {
                            d2 += d7;
                        }
                    } else {
                        d2 += d4 * d7;
                    }
                    d4 = d7;
                    i3 = i10;
                    z4 = false;
                    z3 = false;
                    i7 = 0;
                    i9++;
                }
                // Contribution of the rows after the last sampled row, if any.
                if (i3 != i2 - 1) {
                    int i12 = (i2 - i3) - 1;
                    double d8 = d3 * i12;
                    d2 += ((i12 - d8) * d8) / i12;
                    d = (i12 - d8) / i12;
                } else {
                    d = 1.0d;
                }
                boolean z5 = i3 == extractValues[numOffsets - 1];
                if (z3) {
                    if (z4) {
                        d2 += d4;
                    }
                    if (z5) {
                        d2 += d;
                    }
                } else if (z5) {
                    d2 += d4 * d;
                }
            }
        }
        return (int) Math.min(Math.round(d2), OptimizerUtils.MAX_NUMCELLS_CP_DENSE);
    }

    /**
     * Placeholder for the run estimation on a mapping; this overload is not implemented.
     *
     * @param aMapToData mapping built from the sample (unused)
     * @param i          sample size (unused)
     * @param i2         total number of rows (unused)
     * @return never returns normally
     * @throws NotImplementedException always
     */
    private static int getNumRuns(AMapToData aMapToData, int i, int i2) {
        final String reason = "Not Supported ever since the ubm was replaced by the map";
        throw new NotImplementedException(reason);
    }

    /**
     * Draws {@code i2} row ids from {@code [0, i)} and returns them in ascending order.
     *
     * <p>The first {@code i2} ids seed the sample. Every later id replaces a random entry when a random draw
     * below the id falls under the sample size. When the sample covers more than one percent of the range the
     * entries are additionally shuffled before sorting.
     *
     * <p>NOTE(review): a sample size larger than the range is not checked here; in that case the seeded ids
     * {@code i .. i2-1} lie outside the range.
     *
     * @param i  size of the range to sample from
     * @param i2 number of ids to draw
     * @param j  seed of the random generator; {@code -1} seeds from {@code System.nanoTime()}
     * @param i3 degree of parallelism; values above 1 select a parallel sort
     * @return the sorted sample, empty when {@code i2} is 0
     */
    private static int[] getSortedSample(int i, int i2, long j, int i3) {
        final int[] sample = new int[i2];
        if (i2 == 0) {
            // Nothing to draw. Without this guard the replacement loop below would call nextInt(0),
            // which throws IllegalArgumentException for any non-empty range.
            return sample;
        }

        final Random rng = new Random(j == -1 ? System.nanoTime() : j);

        // Seed the sample with the first ids.
        for (int pos = 0; pos < i2; pos++) {
            sample[pos] = pos;
        }
        // Let later ids replace random entries.
        for (int candidate = i2; candidate < i; candidate++) {
            if (rng.nextInt(candidate) < i2) {
                sample[rng.nextInt(i2)] = candidate;
            }
        }

        // Shuffle when the sample is large relative to the range.
        if (i / 100 < i2) {
            for (int pos = 0; pos < i2 - 1; pos++) {
                final int other = rng.nextInt(i2 - pos) + pos;
                final int held = sample[pos];
                sample[pos] = sample[other];
                sample[other] = held;
            }
        }

        if (i3 > 1) {
            Arrays.parallelSort(sample);
        } else {
            Arrays.sort(sample);
        }
        return sample;
    }

    /**
     * Draws the sorted row sample and materialises it through one of three paths, chosen by the layout of the
     * input: the generic path for transposed input, the sparse path for sparse input and the dense path
     * otherwise.
     *
     * @param i number of rows to sample
     * @return the sampled matrix, or {@code null} when the sample is empty
     */
    private MatrixBlock sampleData(int i) {
        final Timing rowTimer = new Timing(true);
        final int[] sampledRows = getSortedSample(getNumRows(), i, _cs.seed, _k);
        LOG.debug("sampleRow:" + rowTimer.stop());

        final MatrixBlock sample;
        if (_cs.transposed) {
            sample = defaultSlowSamplingPath(sampledRows);
        } else if (_data.isInSparseFormat()) {
            sample = sparseNotTransposedSamplePath(sampledRows);
        } else {
            sample = denseSamplePath(sampledRows);
        }

        return sample.isEmpty() ? null : sample;
    }

    /**
     * Builds the sample for sparse, non-transposed input: the sampled sparse rows are collected into a new
     * sparse block without copying them, and the resulting matrix is transposed. Marks the sample as
     * transposed.
     *
     * @param iArr sorted ids of the sampled rows
     * @return the transposed sample
     */
    private MatrixBlock sparseNotTransposedSamplePath(int[] iArr) {
        final int sampleCount = iArr.length;
        final SparseBlock source = _data.getSparseBlock();

        final SparseRow[] pickedRows = new SparseRow[sampleCount];
        for (int pos = 0; pos < sampleCount; pos++) {
            pickedRows[pos] = source.get(iArr[pos]);
        }

        final MatrixBlock rowSample = new MatrixBlock(sampleCount, _data.getNumColumns(), true);
        rowSample.setSparseBlock(new SparseBlockMCSR(pickedRows, false));
        rowSample.recomputeNonZeros();

        _transposed = true;
        return LibMatrixReorg.transposeInPlace(rowSample, _k);
    }

    /**
     * Generic sampling path: builds a single-column selection vector holding 1.0 at every sampled id and lets
     * {@code removeEmptyOperations} on the input produce the sample from it.
     *
     * @param iArr sorted ids of the sampled rows
     * @return the result of {@code removeEmptyOperations} for the selection vector
     */
    private MatrixBlock defaultSlowSamplingPath(int[] iArr) {
        final boolean dataTransposed = _cs.transposed;
        final int selectorLength = dataTransposed ? _data.getNumColumns() : _data.getNumRows();

        final MatrixBlock selector = new MatrixBlock(selectorLength, 1, false);
        for (int pos = 0; pos < iArr.length; pos++) {
            selector.appendValue(iArr[pos], 0, 1.0d);
        }

        return _data.removeEmptyOperations(new MatrixBlock(), !dataTransposed, true, selector);
    }

    /**
     * Builds the sample for dense input by copying every sampled input row into one column of the result, so
     * the returned sample is transposed relative to the input. Marks the sample as transposed.
     *
     * <p>The result is allocated sparse when the estimated share of non-zero cells is below 0.4. The non-zero
     * count stored on the result is the estimate, not a recount.
     *
     * @param iArr sorted ids of the sampled rows
     * @return the transposed sample
     * @throws NotImplementedException when the sparse result block is not a {@code SparseBlockMCSR}
     */
    private MatrixBlock denseSamplePath(int[] iArr) {
        final int sampleSize = iArr.length;

        // Ratio in floating point: integer division would truncate it (skewing the estimate) and would
        // divide by zero whenever the sample is larger than the sampled dimension.
        final int sampledDim = _cs.transposed ? _data.getNumColumns() : _data.getNumRows();
        final double sampleRatio = (double) sampledDim / sampleSize;
        final long estimatedNonZeros = (long) Math.ceil(_data.getNonZeros() / sampleRatio);

        final int resRows = _cs.transposed ? _data.getNumRows() : _data.getNumColumns();
        final long sampleCells = ((long) sampleSize) * ((long) resRows);
        final boolean sparseSample = 0.4d > ((double) estimatedNonZeros) / ((double) sampleCells);

        final MatrixBlock sample = new MatrixBlock(resRows, sampleSize, sparseSample);
        sample.allocateBlock();
        final DenseBlock input = _data.getDenseBlock();

        if (sample.isInSparseFormat()) {
            final SparseBlock allocated = sample.getSparseBlock();
            if (!(allocated instanceof SparseBlockMCSR)) {
                throw new NotImplementedException("Not Implemented support for dense sample into sparse: " + allocated.getClass().getSimpleName());
            }
            final SparseBlockMCSR target = (SparseBlockMCSR) allocated;
            // Initial capacity per result row; the division is done in floating point so that the
            // ceil actually rounds up.
            // NOTE(review): the divisor is the sample size, as in the original; confirm that this is
            // the intended per-row estimate.
            final int rowCapacity = (int) Math.max(4.0d, Math.ceil((double) estimatedNonZeros / sampleSize));
            for (int row = 0; row < resRows; row++) {
                target.allocate(row, rowCapacity);
            }
            for (int col = 0; col < sampleSize; col++) {
                final int inputRow = iArr[col];
                final double[] inputValues = input.values(inputRow);
                final int inputPos = input.pos(inputRow);
                for (int row = 0; row < resRows; row++) {
                    target.get(row).append(col, inputValues[inputPos + row]);
                }
            }
        } else {
            final DenseBlock target = sample.getDenseBlock();
            for (int col = 0; col < sampleSize; col++) {
                final int inputRow = iArr[col];
                final double[] inputValues = input.values(inputRow);
                final int inputPos = input.pos(inputRow);
                for (int row = 0; row < resRows; row++) {
                    // Address the cell through pos(row), the same way the input is read above;
                    // row * sampleSize is only valid for a single contiguous array and can
                    // overflow int.
                    target.values(row)[target.pos(row) + col] = inputValues[inputPos + row];
                }
            }
        }

        sample.setNonZeros(estimatedNonZeros);
        _transposed = true;
        return sample;
    }

    /**
     * Extends the parent's description with the sample size and the layout flag of the sample.
     *
     * @return a human-readable description of this estimator
     */
    @Override
    public String toString() {
        final StringBuilder description = new StringBuilder(super.toString());
        description.append(" sampleSize: ").append(getSampleSize());
        description.append(" transposed: ").append(_transposed);
        return description.toString();
    }
}
