package org.apache.sysds.hops.estim;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Sparsity estimator that derives output sparsity from a random sample of
 * indexes (columns and/or rows) of the inputs instead of inspecting every
 * output cell.
 *
 * <p>The number of sampled indexes is {@code max(1, range * frac)} truncated
 * to an int, where {@code range} is the sampled dimension. Recursive estimates
 * over an {@link MMNode} are not supported and are delegated to
 * {@link EstimatorBasicAvg}.
 */
public class EstimatorSample extends SparsityEstimator {
    /** Default fraction of indexes to sample. */
    private static final double SAMPLE_FRACTION = 0.1d;

    /** Fraction of indexes to sample; validated to lie in (0, 1]. */
    private final double _frac;

    /**
     * Selects the matrix-multiply estimate: {@code false} uses the largest
     * sampled outer-product nnz, {@code true} extrapolates over all inner
     * indexes assuming independence.
     */
    private final boolean _extended;

    /** Creates an estimator with the default sample fraction and the basic MM estimate. */
    public EstimatorSample() {
        this(SAMPLE_FRACTION, false);
    }

    /**
     * Creates an estimator with the given sample fraction and the basic MM estimate.
     *
     * @param d sample fraction, must lie in (0, 1]
     * @throws DMLRuntimeException if {@code d} is outside (0, 1] or NaN
     */
    public EstimatorSample(double d) {
        this(d, false);
    }

    /**
     * Creates an estimator with the given sample fraction and MM estimate mode.
     *
     * @param d sample fraction, must lie in (0, 1]
     * @param z {@code true} to use the extended (extrapolating) MM estimate
     * @throws DMLRuntimeException if {@code d} is outside (0, 1] or NaN
     */
    public EstimatorSample(double d, boolean z) {
        // Positive range test so that NaN (for which every comparison is false) is rejected too.
        if (!(d > 0 && d <= 1)) {
            throw new DMLRuntimeException("Invalid sample fraction: " + d);
        }
        this._frac = d;
        this._extended = z;
    }

    @Override // org.apache.sysds.hops.estim.SparsityEstimator
    public DataCharacteristics estim(MMNode mMNode) {
        LOG.warn("Recursive estimates not supported by EstimatorSample, falling back to EstimatorBasicAvg.");
        return new EstimatorBasicAvg().estim(mMNode);
    }

    @Override // org.apache.sysds.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2) {
        return estim(matrixBlock, matrixBlock2, SparsityEstimator.OpCode.MM);
    }

    /**
     * Estimates the output sparsity of {@code opCode} applied to the given inputs.
     *
     * <p>MM, MULT and PLUS are estimated by sampling. The remaining supported
     * ops are computed from the inputs' metadata via {@code estimExactMetaData}.
     *
     * @param matrixBlock  first input
     * @param matrixBlock2 second input; {@code null} for unary ops
     * @param opCode       the operation to estimate
     * @return estimated output sparsity
     * @throws NotImplementedException for unsupported ops
     */
    @Override // org.apache.sysds.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, SparsityEstimator.OpCode opCode) {
        switch (opCode) {
            case MM:
                return estimMatMult(matrixBlock, matrixBlock2);
            case MULT:
                return estimElementwise(matrixBlock, matrixBlock2, false);
            case PLUS:
                return estimElementwise(matrixBlock, matrixBlock2, true);
            case RBIND:
            case CBIND:
            case EQZERO:
            case NEQZERO:
            case TRANS:
            case DIAG:
            case RESHAPE: {
                // Unary ops arrive through estim(MatrixBlock, OpCode) with a null second input.
                // NOTE(review): assumes estimExactMetaData ignores the second argument for unary ops — TODO confirm.
                DataCharacteristics dc2 = (matrixBlock2 != null) ? matrixBlock2.getDataCharacteristics() : null;
                return OptimizerUtils.getSparsity(
                    estimExactMetaData(matrixBlock.getDataCharacteristics(), dc2, opCode));
            }
            default:
                throw new NotImplementedException();
        }
    }

    @Override // org.apache.sysds.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, SparsityEstimator.OpCode opCode) {
        return estim(matrixBlock, null, opCode);
    }

    /**
     * Draws a sorted sample of indexes from {@code [0, range)} via
     * {@link UtilFunctions#getSortedSampleIndexes}. The sample size is
     * {@code max(range * _frac, 1)} truncated to an int.
     */
    private int[] sampleIndexes(int range) {
        return UtilFunctions.getSortedSampleIndexes(range, (int) Math.max(range * _frac, 1.0d));
    }

    /**
     * Estimates the sparsity of {@code m1 %*% m2} from sampled inner indexes k.
     * Each sample contributes {@code nnz(m1[:,k]) * nnz(m2[k,:])}, which is the
     * nnz of the k-th outer product.
     */
    private double estimMatMult(MatrixBlock m1, MatrixBlock m2) {
        int n = m1.getNumColumns();
        int[] ix = sampleIndexes(n);
        int s = ix.length;
        int[] cnnz = computeColumnNnz(m1, ix);
        int[] rnnz = computeRowNnz(m2, ix);

        if (!_extended) {
            // Every outer product's nnz bounds the output nnz from below (ignoring
            // numerical cancellation), so use the largest sampled one.
            long maxNnz = 0;
            for (int i = 0; i < s; i++) {
                maxNnz = Math.max(maxNnz, (long) cnnz[i] * rnnz[i]);
            }
            return OptimizerUtils.getSparsity(m1.getNumRows(), m2.getNumColumns(), maxNnz);
        }

        // Widen before multiplying: rows * cols can exceed Integer.MAX_VALUE.
        double cells = (double) m1.getNumRows() * m2.getNumColumns();
        double sumSp = 0;
        double probAllZero = 1;
        for (int i = 0; i < s; i++) {
            double sp = ((long) cnnz[i] * rnnz[i]) / cells;
            sumSp += sp;
            probAllZero *= 1 - sp;
        }
        // The n - s unsampled outer products are assumed to have the average sampled sparsity.
        double avgSp = sumSp / s;
        return 1 - Math.pow(1 - avgSp, n - s) * probAllZero;
    }

    /**
     * Estimates the sparsity of element-wise multiply ({@code plus == false})
     * or add ({@code plus == true}). Per-index estimates are averaged over
     * sampled columns when {@code m1} is wider than tall, and over sampled
     * rows otherwise.
     */
    private double estimElementwise(MatrixBlock m1, MatrixBlock m2, boolean plus) {
        boolean byColumns = m1.getNumColumns() > m1.getNumRows();
        int[] ix = sampleIndexes(Math.max(m1.getNumColumns(), m1.getNumRows()));
        int[] nnz1 = byColumns ? computeColumnNnz(m1, ix) : computeRowNnz(m1, ix);
        int[] nnz2 = byColumns ? computeColumnNnz(m2, ix) : computeRowNnz(m2, ix);
        // Length of a sampled column (= #rows) or row (= #cols); a double avoids integer division.
        double len = byColumns ? m1.getNumRows() : m1.getNumColumns();

        double sum = 0;
        for (int i = 0; i < ix.length; i++) {
            double p1 = nnz1[i] / len;
            double p2 = nnz2[i] / len;
            // Assuming independence: P(a*b != 0) = p1*p2 and P(a+b != 0) = p1 + p2 - p1*p2.
            sum += plus ? (p1 + p2 - p1 * p2) : (p1 * p2);
        }
        return sum / ix.length;
    }

    /**
     * Counts the non-zeros of every column of {@code matrixBlock}. Only the
     * counts for the requested column indexes are returned, in the order given.
     */
    private static int[] computeColumnNnz(MatrixBlock matrixBlock, int[] iArr) {
        int rlen = matrixBlock.getNumRows();
        int clen = matrixBlock.getNumColumns();
        int[] cnnz = new int[clen];
        if (matrixBlock.isInSparseFormat()) {
            SparseBlock sb = matrixBlock.getSparseBlock();
            // A null block is treated as all-zero (presumably an unallocated, empty MatrixBlock).
            if (sb != null) {
                for (int i = 0; i < rlen; i++) {
                    if (!sb.isEmpty(i)) {
                        LibMatrixAgg.countAgg(sb.values(i), cnnz, sb.indexes(i), sb.pos(i), sb.size(i));
                    }
                }
            }
        } else {
            DenseBlock db = matrixBlock.getDenseBlock();
            // A null block is treated as all-zero (presumably an unallocated, empty MatrixBlock).
            if (db != null) {
                for (int i = 0; i < rlen; i++) {
                    double[] vals = db.values(i);
                    int pos = db.pos(i);
                    for (int j = 0; j < clen; j++) {
                        if (vals[pos + j] != 0) {
                            cnnz[j]++;
                        }
                    }
                }
            }
        }
        int[] out = new int[iArr.length];
        for (int k = 0; k < iArr.length; k++) {
            out[k] = cnnz[iArr[k]];
        }
        return out;
    }

    /** Returns the non-zero count of each requested row of {@code matrixBlock}, in the order given. */
    private static int[] computeRowNnz(MatrixBlock matrixBlock, int[] iArr) {
        int[] rnnz = new int[iArr.length];
        for (int k = 0; k < iArr.length; k++) {
            // A single row holds at most getNumColumns() (an int) non-zeros, so the cast is safe.
            rnnz[k] = (int) matrixBlock.recomputeNonZeros(iArr[k], iArr[k]);
        }
        return rnnz;
    }
}
