package org.apache.sysds.hops.estim;

import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

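/**
 * Sparsity estimator based on per-row and per-column non-zero count histograms
 * ({@link MatrixHistogram}) of the operands.
 * <p>
 * Source recovered from org/apache/sysds/hops/estim/EstimatorMatrixHistogram.class.
 */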
public class EstimatorMatrixHistogram extends SparsityEstimator {
    /** Default value for {@code _useExtended}; the no-arg constructor passes the same value ({@code true}). */
    private static final boolean DEFAULT_USE_EXTENDED = true;
    // NOTE(review): not referenced anywhere in this file; presumably a feature flag for an
    // advanced sketch propagation mode — confirm before removing.
    private static final boolean ADVANCED_SKETCH_PROP = false;
    /**
     * Whether extended statistics are used: the 1-entry histograms are built for inputs, and the
     * matrix-multiply estimate uses non-empty row/column counts and the "more than half full" lower bound.
     */
    private final boolean _useExtended;

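    /**
     * Row/column non-zero count histogram of a matrix, plus the summary statistics
     * (maxima, single-entry, non-empty and more-than-half-full counts) used by the estimator.
     * <p>
     * Source recovered from org/apache/sysds/hops/estim/EstimatorMatrixHistogram$MatrixHistogram.class.
     */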
    public static class MatrixHistogram {
        /** Number of non-zeros per row. */
        private final int[] rNnz;
        /** Per row: non-zeros located in columns with at most one non-zero (null if not computed). */
        private int[] rNnz1e;
        /** Number of non-zeros per column. */
        private final int[] cNnz;
        /** Per column: non-zeros located in rows with at most one non-zero (null if not computed). */
        private int[] cNnz1e;
        /** Maximum entry of rNnz. */
        private final int rMaxNnz;
        /** Maximum entry of cNnz. */
        private final int cMaxNnz;
        /** Number of rows with exactly one non-zero; -1 for histograms built from propagated counts. */
        private final int rN1;
        /** Number of columns with exactly one non-zero; -1 for histograms built from propagated counts. */
        private final int cN1;
        /** Number of non-empty rows. */
        private final int rNonEmpty;
        /** Number of non-empty columns. */
        private final int cNonEmpty;
        /** Number of rows whose count exceeds getCols()/2 (integer division); -1 if unknown. */
        private final int rNdiv2;
        /** Number of columns whose count exceeds getRows()/2 (integer division); -1 if unknown. */
        private final int cNdiv2;
        /** Intended to flag a square matrix with exactly one non-zero per row, located on the diagonal. */
        private boolean fullDiag;
        /** Optional reference to the underlying data block, set via setData. */
        private MatrixBlock _data;

        /**
         * Builds the row and column non-zero histograms of a matrix block.
         *
         * @param matrixBlock input matrix, in dense or sparse format
         * @param z           if {@code true}, also compute the 1-entry histograms {@code rNnz1e} and
         *                    {@code cNnz1e}. They are only built when the matrix is non-empty, not
         *                    fully dense, and has some row or column with more than one non-zero.
         */
        public MatrixHistogram(MatrixBlock matrixBlock, boolean z) {
            final int numRows = matrixBlock.getNumRows();
            final int numColumns = matrixBlock.getNumColumns();
            this.rNnz = new int[numRows];
            this.cNnz = new int[numColumns];
            // candidate full diagonal: square with exactly numRows non-zeros; the per-row scans
            // below confirm that every row holds a single entry located on the diagonal
            this.fullDiag = ((long) numRows) == matrixBlock.getNonZeros() && numRows == numColumns;
            if (matrixBlock.getLength() == matrixBlock.getNonZeros()) {
                // fully dense: every row and column is completely populated
                Arrays.fill(this.rNnz, numColumns);
                Arrays.fill(this.cNnz, numRows);
            } else if (!matrixBlock.isEmpty()) {
                if (matrixBlock.isInSparseFormat()) {
                    SparseBlock sblock = matrixBlock.getSparseBlock();
                    for (int i = 0; i < numRows; i++) {
                        if (sblock.isEmpty(i)) {
                            continue;
                        }
                        final int pos = sblock.pos(i);
                        final int size = sblock.size(i);
                        final int[] indexes = sblock.indexes(i);
                        this.rNnz[i] = size;
                        LibMatrixAgg.countAgg(sblock.values(i), this.cNnz, indexes, pos, size);
                        // the row must hold exactly one entry at column i; checking only the first
                        // index would accept rows that carry additional off-diagonal entries
                        this.fullDiag &= size == 1 && indexes[pos] == i;
                    }
                } else {
                    DenseBlock dblock = matrixBlock.getDenseBlock();
                    for (int i = 0; i < numRows; i++) {
                        this.rNnz[i] = dblock.countNonZeros(i);
                        LibMatrixAgg.countAgg(dblock.values(i), this.cNnz, dblock.pos(i), numColumns);
                        // NOTE(review): DEFAULT_DELIM_FILL_VALUE appears to be a decompiler-inlined
                        // zero constant — confirm it equals 0.0
                        this.fullDiag &= this.rNnz[i] == 1 && numColumns > i
                            && dblock.get(i, i) != DataExpression.DEFAULT_DELIM_FILL_VALUE;
                    }
                }
            }

            final int[] rStats = deriveSummaryStatistics(this.rNnz, getCols());
            final int[] cStats = deriveSummaryStatistics(this.cNnz, getRows());
            this.rMaxNnz = rStats[0];
            this.cMaxNnz = cStats[0];
            this.rN1 = rStats[1];
            this.cN1 = cStats[1];
            this.rNonEmpty = rStats[2];
            this.cNonEmpty = cStats[2];
            this.rNdiv2 = rStats[3];
            this.cNdiv2 = cStats[3];

            if (!z || matrixBlock.isEmpty()) {
                return;
            }
            // 1-entry histograms are only useful if some row/column holds more than one entry
            // and the matrix is not fully dense
            if ((this.rMaxNnz <= 1 && this.cMaxNnz <= 1) || matrixBlock.getLength() == matrixBlock.getNonZeros()) {
                return;
            }
            this.rNnz1e = new int[numRows];
            this.cNnz1e = new int[numColumns];
            if (matrixBlock.isInSparseFormat()) {
                SparseBlock sblock1e = matrixBlock.getSparseBlock();
                for (int i = 0; i < numRows; i++) {
                    if (sblock1e.isEmpty(i)) {
                        continue;
                    }
                    final int pos = sblock1e.pos(i);
                    final int size = sblock1e.size(i);
                    final int[] indexes = sblock1e.indexes(i);
                    for (int k = pos; k < pos + size; k++) {
                        if (this.cNnz[indexes[k]] <= 1) {
                            this.rNnz1e[i]++;
                        }
                    }
                    if (size == 1) {
                        this.cNnz1e[indexes[pos]]++;
                    }
                }
            } else {
                DenseBlock dblock1e = matrixBlock.getDenseBlock();
                for (int i = 0; i < numRows; i++) {
                    final double[] values = dblock1e.values(i);
                    final int pos = dblock1e.pos(i);
                    final boolean singleEntryRow = this.rNnz[i] <= 1;
                    for (int j = 0; j < numColumns; j++) {
                        if (values[pos + j] == DataExpression.DEFAULT_DELIM_FILL_VALUE) {
                            continue;
                        }
                        if (this.cNnz[j] <= 1) {
                            this.rNnz1e[i]++;
                        }
                        if (singleEntryRow) {
                            this.cNnz1e[j]++;
                        }
                    }
                }
            }
        }

        public MatrixHistogram(int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, int i, int i2) {
            this.rNnz1e = null;
            this.cNnz1e = null;
            this._data = null;
            this.rNnz = iArr;
            this.rNnz1e = iArr2;
            this.cNnz = iArr3;
            this.cNnz1e = iArr4;
            this.rMaxNnz = i;
            this.cMaxNnz = i2;
            this.cN1 = -1;
            this.rN1 = -1;
            this.cNdiv2 = -1;
            this.rNdiv2 = -1;
            this.rNonEmpty = (int) Arrays.stream(this.rNnz).filter(i3 -> {
                return i3 != 0;
            }).count();
            this.cNonEmpty = (int) Arrays.stream(this.cNnz).filter(i4 -> {
                return i4 != 0;
            }).count();
        }

        /** @return number of rows of the summarized matrix (length of the row-count vector) */
        public int getRows() {
            final int[] rowCounts = rNnz;
            return rowCounts.length;
        }

        /** @return number of columns of the summarized matrix (length of the column-count vector) */
        public int getCols() {
            final int[] colCounts = cNnz;
            return colCounts.length;
        }

        /** @return the internal per-row non-zero counts (not a copy) */
        public int[] getRowCounts() {
            final int[] rowCounts = rNnz;
            return rowCounts;
        }

        /** @return the internal per-column non-zero counts (not a copy) */
        public int[] getColCounts() {
            final int[] colCounts = cNnz;
            return colCounts;
        }

        /**
         * Returns the total non-zero count, summed over whichever of the row or column count
         * vectors is strictly shorter (the column vector on ties).
         *
         * @return total number of non-zeros
         */
        public long getNonZeros() {
            final int[] counts = (getRows() < getCols()) ? rNnz : cNnz;
            long total = 0L;
            for (int count : counts) {
                total += count;
            }
            return total;
        }

        /**
         * Attaches (or clears, when passed null) the data block this histogram describes.
         *
         * @param data the underlying matrix block, or null
         */
        public void setData(MatrixBlock data) {
            _data = data;
        }

        /**
         * Propagates input histograms through an operation to obtain the output histogram.
         *
         * @param h1    histogram of the left (or only) input
         * @param h2    histogram of the right input; may be null for unary operations
         * @param spOut estimated output sparsity (consumed by matrix multiplication)
         * @param op    operation code
         * @param misc  operation-specific arguments; for RESHAPE, the target rows and columns
         * @return the derived output histogram
         * @throws NotImplementedException for unsupported operations
         */
        public static MatrixHistogram deriveOutputHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut, SparsityEstimator.OpCode op, long[] misc) {
            switch (op) {
                // binary operations
                case MM:
                    return deriveMMHistogram(h1, h2, spOut);
                case MULT:
                    return deriveMultHistogram(h1, h2);
                case PLUS:
                    return derivePlusHistogram(h1, h2);
                case CBIND:
                    return deriveCbindHistogram(h1, h2);
                case RBIND:
                    return deriveRbindHistogram(h1, h2);
                // unary operations
                case EQZERO:
                    return deriveEq0Histogram(h1);
                case NEQZERO:
                    // x != 0 has the same non-zero structure as x
                    return h1;
                case DIAG:
                    return deriveDiagHistogram(h1);
                case TRANS:
                    return deriveTransHistogram(h1);
                case RESHAPE:
                    return deriveReshapeHistogram(h1, (int) misc[0], (int) misc[1]);
                default:
                    throw new NotImplementedException();
            }
        }

        /**
         * Derives the output dimensions and estimated non-zero count of an operation.
         *
         * @param h1    histogram of the left (or only) input
         * @param h2    histogram of the right input; may be null for unary operations
         * @param spOut estimated output sparsity
         * @param op    operation code
         * @param misc  operation-specific arguments; for RESHAPE, {@code misc[0]} and
         *              {@code misc[1]} are the target rows and columns
         * @return characteristics of the operation output
         * @throws NotImplementedException for unsupported operations
         */
        public static DataCharacteristics deriveOutputCharacteristics(MatrixHistogram h1, MatrixHistogram h2, double spOut, SparsityEstimator.OpCode op, long[] misc) {
            switch (op) {
                case MM:
                    return new MatrixCharacteristics(h1.getRows(), h2.getCols(), OptimizerUtils.getNnz(h1.getRows(), h2.getCols(), spOut));
                case MULT:
                case PLUS:
                case EQZERO:
                case NEQZERO:
                    return new MatrixCharacteristics(h1.getRows(), h1.getCols(), OptimizerUtils.getNnz(h1.getRows(), h1.getCols(), spOut));
                case DIAG: {
                    // column vector -> square diagonal matrix; matrix -> column vector
                    final int outCols = h1.getCols() == 1 ? h1.getRows() : 1;
                    return new MatrixCharacteristics(h1.getRows(), outCols, OptimizerUtils.getNnz(h1.getRows(), outCols, spOut));
                }
                case CBIND: {
                    final int outCols = h1.getCols() + h2.getCols();
                    return new MatrixCharacteristics(h1.getRows(), outCols, OptimizerUtils.getNnz(h1.getRows(), outCols, spOut));
                }
                case RBIND: {
                    // both inputs contribute their rows to the output
                    final int outRows = h1.getRows() + h2.getRows();
                    return new MatrixCharacteristics(outRows, h1.getCols(), OptimizerUtils.getNnz(outRows, h1.getCols(), spOut));
                }
                case TRANS:
                    // transpose preserves the exact non-zero count
                    return new MatrixCharacteristics(h1.getCols(), h1.getRows(), h1.getNonZeros());
                case RESHAPE:
                    return new MatrixCharacteristics((int) misc[0], (int) misc[1], OptimizerUtils.getNnz((int) misc[0], (int) misc[1], spOut));
                default:
                    throw new NotImplementedException();
            }
        }

        /**
         * Derives the histogram of the matrix product h1 %*% h2. The left row counts and the
         * right column counts are scaled to the estimated output non-zero count and rounded
         * probabilistically.
         *
         * @param h1    left input histogram
         * @param h2    right input histogram
         * @param spOut estimated output sparsity
         * @return output histogram
         */
        private static MatrixHistogram deriveMMHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) {
            // a full-diagonal operand leaves the other operand's structure unchanged
            if (h1.fullDiag) {
                return h2;
            }
            if (h2.fullDiag) {
                return h1;
            }
            final long nnz1 = h1.getNonZeros();
            final long nnz2 = h2.getNonZeros();
            final double nnzOut = spOut * h1.getRows() * h2.getCols();
            final Random rand = new Random();

            final int[] outRows = new int[h1.getRows()];
            int rMax = 0;
            for (int r = 0; r < outRows.length; r++) {
                outRows[r] = probRound(nnzOut / nnz1 * h1.rNnz[r], rand);
                rMax = Math.max(rMax, outRows[r]);
            }

            final int[] outCols = new int[h2.getCols()];
            int cMax = 0;
            for (int c = 0; c < outCols.length; c++) {
                outCols[c] = probRound(nnzOut / nnz2 * h2.cNnz[c], rand);
                cMax = Math.max(cMax, outCols[c]);
            }
            return new MatrixHistogram(outRows, null, outCols, null, rMax, cMax);
        }

        /**
         * Derives the histogram of the element-wise product h1 * h2. Each output row (column)
         * count is the product of the input row (column) counts, scaled by an overlap factor
         * computed from the column (row) counts, and rounded probabilistically.
         *
         * @param h1 left input histogram
         * @param h2 right input histogram (same shape as h1)
         * @return output histogram
         */
        private static MatrixHistogram deriveMultHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            final long nnz1 = h1.getNonZeros();
            final long nnz2 = h2.getNonZeros();
            // an empty input cannot overlap with anything; guarding avoids 0/0 = NaN scale factors
            final boolean anyEmpty = nnz1 == 0 || nnz2 == 0;
            // products are computed in double: int*int overflows for large count vectors
            final double scaler = anyEmpty ? 0 : IntStream.range(0, h1.getCols())
                .mapToDouble(c -> (double) h1.cNnz[c] * h2.cNnz[c]).sum() / nnz1 / nnz2;
            final double scalec = anyEmpty ? 0 : IntStream.range(0, h1.getRows())
                .mapToDouble(r -> (double) h1.rNnz[r] * h2.rNnz[r]).sum() / nnz1 / nnz2;
            final Random rand = new Random();

            final int[] outRows = new int[h1.getRows()];
            int rMax = 0;
            for (int i = 0; i < outRows.length; i++) {
                outRows[i] = probRound((double) h1.rNnz[i] * h2.rNnz[i] * scaler, rand);
                rMax = Math.max(rMax, outRows[i]);
            }

            final int[] outCols = new int[h1.getCols()];
            int cMax = 0;
            for (int j = 0; j < outCols.length; j++) {
                outCols[j] = probRound((double) h1.cNnz[j] * h2.cNnz[j] * scalec, rand);
                cMax = Math.max(cMax, outCols[j]);
            }
            return new MatrixHistogram(outRows, null, outCols, null, rMax, cMax);
        }

        /**
         * Derives the histogram of the element-wise sum h1 + h2. Each output count is the sum
         * of the input counts minus their estimated overlap (product scaled by an overlap factor),
         * rounded probabilistically.
         *
         * @param h1 left input histogram
         * @param h2 right input histogram (same shape as h1)
         * @return output histogram
         */
        private static MatrixHistogram derivePlusHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            final long nnz1 = h1.getNonZeros();
            final long nnz2 = h2.getNonZeros();
            // an empty input has no overlap; without this guard the scale is 0/0 = NaN and the
            // non-empty operand's counts are lost
            final boolean anyEmpty = nnz1 == 0 || nnz2 == 0;
            // products are computed in double: int*int overflows for large count vectors
            final double scaler = anyEmpty ? 0 : IntStream.range(0, h1.getCols())
                .mapToDouble(c -> (double) h1.cNnz[c] * h2.cNnz[c]).sum() / nnz1 / nnz2;
            final double scalec = anyEmpty ? 0 : IntStream.range(0, h1.getRows())
                .mapToDouble(r -> (double) h1.rNnz[r] * h2.rNnz[r]).sum() / nnz1 / nnz2;
            final Random rand = new Random();

            final int[] outRows = new int[h1.getRows()];
            int rMax = 0;
            for (int i = 0; i < outRows.length; i++) {
                final double overlap = (double) h1.rNnz[i] * h2.rNnz[i] * scaler;
                outRows[i] = probRound(h1.rNnz[i] + h2.rNnz[i] - overlap, rand);
                rMax = Math.max(rMax, outRows[i]);
            }

            final int[] outCols = new int[h1.getCols()];
            int cMax = 0;
            for (int j = 0; j < outCols.length; j++) {
                final double overlap = (double) h1.cNnz[j] * h2.cNnz[j] * scalec;
                outCols[j] = probRound(h1.cNnz[j] + h2.cNnz[j] - overlap, rand);
                cMax = Math.max(cMax, outCols[j]);
            }
            return new MatrixHistogram(outRows, null, outCols, null, rMax, cMax);
        }

        /**
         * Derives the histogram of rbind(h1, h2): row counts are concatenated, column counts added.
         *
         * @param h1 upper input histogram
         * @param h2 lower input histogram (same number of columns)
         * @return output histogram
         */
        private static MatrixHistogram deriveRbindHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            final int[] outRows = ArrayUtils.addAll(h1.rNnz, h2.rNnz);
            final int rMax = Math.max(h1.rMaxNnz, h2.rMaxNnz);
            final int numCols = h1.getCols();
            final int[] outCols = new int[numCols];
            int cMax = 0;
            for (int c = 0; c < numCols; c++) {
                final int sum = h1.cNnz[c] + h2.cNnz[c];
                outCols[c] = sum;
                if (sum > cMax) {
                    cMax = sum;
                }
            }
            return new MatrixHistogram(outRows, null, outCols, null, rMax, cMax);
        }

        /**
         * Derives the histogram of cbind(h1, h2): row counts are added, column counts concatenated.
         *
         * @param h1 left input histogram
         * @param h2 right input histogram (same number of rows)
         * @return output histogram
         */
        private static MatrixHistogram deriveCbindHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            final int numRows = h1.getRows();
            final int[] outRows = new int[numRows];
            int rMax = 0;
            for (int r = 0; r < numRows; r++) {
                final int sum = h1.rNnz[r] + h2.rNnz[r];
                outRows[r] = sum;
                if (sum > rMax) {
                    rMax = sum;
                }
            }
            final int[] outCols = ArrayUtils.addAll(h1.cNnz, h2.cNnz);
            final int cMax = Math.max(h1.cMaxNnz, h2.cMaxNnz);
            return new MatrixHistogram(outRows, null, outCols, null, rMax, cMax);
        }

        /**
         * Derives the histogram of (X == 0): every count becomes its complement with respect to
         * the opposite dimension.
         *
         * @param h1 input histogram
         * @return output histogram
         */
        private static MatrixHistogram deriveEq0Histogram(MatrixHistogram h1) {
            final int numRows = h1.getRows();
            final int numCols = h1.getCols();

            final int[] zeroRows = new int[numRows];
            int rMax = 0;
            for (int r = 0; r < numRows; r++) {
                zeroRows[r] = numCols - h1.rNnz[r];
                rMax = Math.max(rMax, zeroRows[r]);
            }

            final int[] zeroCols = new int[numCols];
            int cMax = 0;
            for (int c = 0; c < numCols; c++) {
                zeroCols[c] = numRows - h1.cNnz[c];
                cMax = Math.max(cMax, zeroCols[c]);
            }
            return new MatrixHistogram(zeroRows, null, zeroCols, null, rMax, cMax);
        }

        /**
         * Derives the histogram of diag(X).
         * <p>
         * For a column vector, the output is a square diagonal matrix whose row and column counts
         * both equal the input row counts. For a matrix, the output is a column vector. Each of its
         * rows holds a single cell, so its expected count is the average row count capped at one.
         * This agrees with the {@code min(rows, nnz)} used by {@code estimIntern} for square inputs.
         * The count is then rounded probabilistically.
         *
         * @param h1 input histogram
         * @return output histogram
         */
        private static MatrixHistogram deriveDiagHistogram(MatrixHistogram h1) {
            if (h1.getCols() == 1) {
                // the row-count array is shared for both rows and columns
                return new MatrixHistogram(h1.rNnz, null, h1.rNnz, null, h1.rMaxNnz, h1.rMaxNnz);
            }
            final int rows = h1.getRows();
            // hoisted out of the loop (getNonZeros is O(n)); floating-point division so that the
            // probabilistic rounding sees the fractional part; capped at one cell per output row
            final double expectedPerRow = Math.min(1.0, (double) h1.getNonZeros() / h1.getCols());
            final int[] outRows = new int[rows];
            final int[] outCols = new int[1];
            int rMax = 0;
            final Random rand = new Random();
            for (int i = 0; i < rows; i++) {
                outRows[i] = probRound(expectedPerRow, rand);
                rMax = Math.max(rMax, outRows[i]);
                outCols[0] += outRows[i];
            }
            return new MatrixHistogram(outRows, null, outCols, null, rMax, outCols[0]);
        }

        /**
         * Derives the histogram of t(X) by swapping the row and column statistics.
         *
         * @param h1 input histogram
         * @return output histogram (shares the input's count arrays)
         */
        private static MatrixHistogram deriveTransHistogram(MatrixHistogram h1) {
            final int[] newRows = h1.cNnz;
            final int[] newRows1e = h1.cNnz1e;
            final int[] newCols = h1.rNnz;
            final int[] newCols1e = h1.rNnz1e;
            return new MatrixHistogram(newRows, newRows1e, newCols, newCols1e, h1.cMaxNnz, h1.rMaxNnz);
        }

        /**
         * Derives the histogram of a row-major reshape to {@code rows x cols}.
         * <ul>
         * <li>If every input row maps to a whole number of output rows, each row count is split
         * evenly over those rows and column counts are folded modulo {@code cols}.</li>
         * <li>If whole groups of input rows form one output row, row counts are summed per group.
         * Each input column count is split evenly over the output columns it maps to.</li>
         * <li>Otherwise a uniform distribution of the non-zeros is assumed.</li>
         * </ul>
         * Remainders of the even splits go to the first parts, so the total non-zero count is
         * preserved.
         *
         * @param h1   input histogram
         * @param rows number of output rows
         * @param cols number of output columns
         * @return histogram of the reshaped matrix (never null)
         */
        private static MatrixHistogram deriveReshapeHistogram(MatrixHistogram h1, int rows, int cols) {
            final int inRows = h1.getRows();
            final int inCols = h1.getCols();
            if (inRows == rows) {
                // reshape preserves the cell count, so equal row counts imply an identical shape
                return h1;
            }
            final int[] outRows = new int[rows];
            final int[] outCols = new int[cols];
            if (inCols % cols == 0) {
                // input row i becomes output rows [i*scale, (i+1)*scale);
                // input column j lands in output column j % cols
                final int scale = inCols / cols;
                for (int i = 0; i < inRows; i++) {
                    final int base = h1.rNnz[i] / scale;
                    final int rem = h1.rNnz[i] % scale;
                    for (int k = 0; k < scale; k++) {
                        outRows[i * scale + k] = base + (k < rem ? 1 : 0);
                    }
                }
                for (int j = 0; j < inCols; j++) {
                    outCols[j % cols] += h1.cNnz[j];
                }
            } else if (inRows % rows == 0) {
                // input rows [r*scale, (r+1)*scale) are concatenated into output row r, so input
                // cell (i, j) lands in output column (i % scale) * inCols + j
                final int scale = inRows / rows;
                for (int j = 0; j < inCols; j++) {
                    final int base = h1.cNnz[j] / scale;
                    final int rem = h1.cNnz[j] % scale;
                    for (int k = 0; k < scale; k++) {
                        outCols[k * inCols + j] = base + (k < rem ? 1 : 0);
                    }
                }
                for (int i = 0; i < inRows; i++) {
                    outRows[i / scale] += h1.rNnz[i];
                }
            } else {
                // no aligned row mapping: assume uniformly distributed non-zeros
                final long nnz = h1.getNonZeros();
                for (int i = 0; i < rows; i++) {
                    outRows[i] = (int) (nnz / rows + (i < nnz % rows ? 1 : 0));
                }
                for (int j = 0; j < cols; j++) {
                    outCols[j] = (int) (nnz / cols + (j < nnz % cols ? 1 : 0));
                }
            }
            int rMax = 0;
            for (int count : outRows) {
                rMax = Math.max(rMax, count);
            }
            int cMax = 0;
            for (int count : outCols) {
                cMax = Math.max(cMax, count);
            }
            return new MatrixHistogram(outRows, null, outCols, null, rMax, cMax);
        }

        /**
         * Rounds {@code d} up with probability equal to its fractional part, and down otherwise.
         *
         * @param d      value to round
         * @param random source of randomness (one draw per call)
         * @return the rounded value
         */
        private static int probRound(double d, Random random) {
            final double lower = Math.floor(d);
            final double fraction = d - lower;
            if (fraction > random.nextDouble()) {
                return (int) (lower + 1.0d);
            }
            return (int) lower;
        }

        /**
         * Computes summary statistics of a count vector.
         *
         * @param counts   per-row (or per-column) non-zero counts
         * @param otherDim size of the opposite dimension
         * @return {max count (Integer.MIN_VALUE if empty), #counts == 1, #counts != 0,
         *         #counts > otherDim/2}
         */
        private static int[] deriveSummaryStatistics(int[] counts, int otherDim) {
            final int half = otherDim / 2;
            int max = Integer.MIN_VALUE;
            int ones = 0;
            int nonEmpty = 0;
            int overHalf = 0;
            for (int count : counts) {
                if (count > max) {
                    max = count;
                }
                if (count == 1) {
                    ones++;
                }
                if (count != 0) {
                    nonEmpty++;
                }
                if (count > half) {
                    overHalf++;
                }
            }
            return new int[]{max, ones, nonEmpty, overHalf};
        }
    }

    /** Creates an estimator using the default setting for extended statistics. */
    public EstimatorMatrixHistogram() {
        this(DEFAULT_USE_EXTENDED);
    }

    /**
     * Creates an estimator.
     *
     * @param useExtended whether to use extended statistics (1-entry histograms and
     *                    non-empty / more-than-half-full counts)
     */
    public EstimatorMatrixHistogram(boolean useExtended) {
        _useExtended = useExtended;
    }

    /** Estimates the output characteristics of the root of an expression DAG. */
    @Override
    public DataCharacteristics estim(MMNode root) {
        final boolean topLevel = true;
        return estim(root, topLevel);
    }

    /**
     * Estimates the output characteristics of a DAG node, recursively materializing the
     * synopses of its inputs.
     *
     * @param root     node to estimate
     * @param topLevel if true, only the output characteristics are returned; otherwise the
     *                 output histogram is derived and stored as the node's synopsis
     * @return the estimated output characteristics
     */
    public DataCharacteristics estim(MMNode root, boolean topLevel) {
        final MatrixHistogram h1 = getCachedSynopsis(root.getLeft());
        final MatrixHistogram h2 = getCachedSynopsis(root.getRight());
        final double spOut = estimIntern(h1, h2, root.getOp(), root.getMisc());
        if (!topLevel) {
            if (h2 != null && root.getRight() != null) {
                if (root.getRight().isLeaf()) {
                    h2.setData(root.getRight().getData());
                } else {
                    h2.setData(null);
                }
            }
            final MatrixHistogram out = MatrixHistogram.deriveOutputHistogram(h1, h2, spOut, root.getOp(), root.getMisc());
            root.setSynopsis(out);
            return root.setDataCharacteristics(new MatrixCharacteristics(out.getRows(), out.getCols(), out.getNonZeros()));
        }
        return MatrixHistogram.deriveOutputCharacteristics(h1, h2, spOut, root.getOp(), root.getMisc());
    }

    /** Estimates the sparsity of the matrix product m1 %*% m2. */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        final SparsityEstimator.OpCode op = SparsityEstimator.OpCode.MM;
        return estim(m1, m2, op);
    }

    /**
     * Estimates the output sparsity of a binary operation on two matrix blocks. Operations whose
     * result is fully determined by metadata are answered exactly; otherwise histograms are built
     * (a single one when both operands are the same block).
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        if (!isExactMetadataOp(op)) {
            final MatrixHistogram h1 = new MatrixHistogram(m1, _useExtended);
            final MatrixHistogram h2 = (m1 == m2) ? h1 : new MatrixHistogram(m2, _useExtended);
            return estimIntern(h1, h2, op, null);
        }
        return estimExactMetaData(m1.getDataCharacteristics(), m2.getDataCharacteristics(), op).getSparsity();
    }

    /** Estimates the output sparsity of a unary operation on a matrix block. */
    @Override
    public double estim(MatrixBlock m1, SparsityEstimator.OpCode op) {
        if (!isExactMetadataOp(op)) {
            return estimIntern(new MatrixHistogram(m1, _useExtended), null, op, null);
        }
        return estimExactMetaData(m1.getDataCharacteristics(), null, op).getSparsity();
    }

    /**
     * Returns the histogram synopsis of a node, creating it on demand. A leaf gets a histogram
     * built from its data, unless one is already cached. An inner node is (re-)estimated
     * recursively, which stores its derived synopsis.
     *
     * @param node DAG node, may be null
     * @return the node's histogram, or null for a null node
     */
    private MatrixHistogram getCachedSynopsis(MMNode node) {
        if (node == null) {
            return null;
        }
        if (!node.isLeaf()) {
            estim(node, false);
        } else if (node.getSynopsis() == null) {
            node.setSynopsis(new MatrixHistogram(node.getData(), _useExtended));
        }
        return (MatrixHistogram) node.getSynopsis();
    }

    /**
     * Estimates the output sparsity of an operation from the input histograms.
     *
     * @param h1   histogram of the left (or only) input
     * @param h2   histogram of the right input; may be null for unary operations
     * @param op   operation code
     * @param misc operation-specific arguments (not used by these estimates)
     * @return estimated output sparsity
     * @throws NotImplementedException for unsupported operations
     */
    public double estimIntern(MatrixHistogram h1, MatrixHistogram h2, SparsityEstimator.OpCode op, long[] misc) {
        switch (op) {
            case MM:
                return estimInternMM(h1, h2);
            case MULT:
            case PLUS: {
                // cell count in double: the int product rows*cols overflows for large inputs
                final double cells = (double) h1.getRows() * h1.getCols();
                final long nnz1 = h1.getNonZeros();
                final long nnz2 = h2.getNonZeros();
                // column-overlap scale factor; an empty input cannot overlap with anything
                // (guarding avoids 0/0 = NaN propagating into the estimate)
                final double scale = (nnz1 == 0 || nnz2 == 0) ? 0
                    : IntStream.range(0, h1.getCols())
                        .mapToDouble(c -> (double) h1.cNnz[c] * h2.cNnz[c]).sum() / nnz1 / nnz2;
                final boolean isMult = op == SparsityEstimator.OpCode.MULT;
                final double nnzOut = IntStream.range(0, h1.getRows()).mapToDouble(r -> {
                    // estimated number of positions where both inputs are non-zero in row r
                    final double overlap = (double) h1.rNnz[r] * h2.rNnz[r] * scale;
                    return isMult ? overlap : h1.rNnz[r] + h2.rNnz[r] - overlap;
                }).sum();
                return nnzOut / cells;
            }
            case EQZERO:
                // widen before multiplying to avoid int overflow of rows*cols
                return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), (long) h1.getRows() * h1.getCols() - h1.getNonZeros());
            case DIAG:
                return h1.getCols() == 1
                    ? OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros())
                    : OptimizerUtils.getSparsity(h1.getRows(), 1L, Math.min(h1.getRows(), h1.getNonZeros()));
            case CBIND:
                return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols() + h2.getCols(), h1.getNonZeros() + h2.getNonZeros());
            case RBIND:
                return OptimizerUtils.getSparsity(h1.getRows() + h2.getRows(), h1.getCols(), h1.getNonZeros() + h2.getNonZeros());
            case NEQZERO:
            case TRANS:
            case RESHAPE:
                // structure-preserving (up to layout): the non-zero count is unchanged
                return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
            default:
                throw new NotImplementedException();
        }
    }

    /**
     * Estimates the sparsity of the matrix product h1 %*% h2 from the input histograms.
     * <ul>
     * <li>If h1 has at most one non-zero per row, or h2 at most one per column, no output cell
     * receives more than one product, so the count is a plain sum over the common dimension.</li>
     * <li>Otherwise the per-column fractions are combined with the rule p + q - p*q over the
     * (optionally non-empty-only) output cells.</li>
     * <li>When 1-entry histograms exist, their contributions are counted exactly and only the
     * remainder is combined probabilistically.</li>
     * </ul>
     * With extended statistics, the count is bounded below by rNdiv2 * cNdiv2. Those are rows
     * and columns that are more than half full and hence must overlap.
     *
     * @param h1 left input histogram
     * @param h2 right input histogram
     * @return estimated output sparsity
     */
    private double estimInternMM(MatrixHistogram h1, MatrixHistogram h2) {
        long nnz = 0;
        if (h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1) {
            for (int j = 0; j < h1.getCols(); j++) {
                // widen before multiplying: int*int overflows for large counts
                nnz += (long) h1.cNnz[j] * h2.rNnz[j];
            }
        } else if (h1.cNnz1e == null && h2.rNnz1e == null) {
            final long mnOut = _useExtended
                ? (long) h1.rNonEmpty * h2.cNonEmpty
                : (long) h1.getRows() * h2.getCols();
            double spOut = 0;
            for (int j = 0; j < h1.getCols(); j++) {
                // floating-point division: integer division would truncate the fraction to 0
                final double lsp = (double) h1.cNnz[j] * h2.rNnz[j] / mnOut;
                spOut = spOut + lsp - spOut * lsp;
            }
            nnz = (long) (spOut * mnOut);
        } else {
            final long mnOut = _useExtended
                ? (long) (h1.rNonEmpty - h1.rN1) * (h2.cNonEmpty - h2.cN1)
                : (long) (h1.getRows() - h1.rN1) * (h2.getCols() - h2.cN1);
            double spOutRest = 0;
            for (int j = 0; j < h1.getCols(); j++) {
                final int h1c1ej = h1.cNnz1e != null ? h1.cNnz1e[j] : 0;
                final int h2r1ej = h2.rNnz1e != null ? h2.rNnz1e[j] : 0;
                // exact contributions of 1-entry rows/columns, without double counting
                nnz += (long) h1c1ej * h2.rNnz[j];
                nnz += (long) (h1.cNnz[j] - h1c1ej) * h2r1ej;
                // approximate contribution of the remaining entries
                final double lsp = (double) (h1.cNnz[j] - h1c1ej) * (h2.rNnz[j] - h2r1ej) / mnOut;
                spOutRest = spOutRest + lsp - spOutRest * lsp;
            }
            nnz += (long) (spOutRest * mnOut);
        }
        if (_useExtended && h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) {
            // more-than-half-full rows and columns must share a non-zero position
            nnz = Math.max((long) h1.rNdiv2 * h2.cNdiv2, nnz);
        }
        return OptimizerUtils.getSparsity(h1.getRows(), h2.getCols(), nnz);
    }
}
