package org.apache.sysds.hops.estim;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Random;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.math3.random.Well1024a;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
 * Sampling-based sparsity estimator for matrix products {@code A %*% B}.
 *
 * <p>Each run draws uniform random hash values for the rows of {@code A} ({@code h1}, {@code h4})
 * and the columns of {@code B} ({@code h2}, {@code h3}). An output cell (i,j) is hashed to
 * {@code (h1[i] - h2[j]) mod 1}. Only cells whose row and column pass the sample filter
 * ({@code h4[i] <= f} and {@code h3[j] <= f}) are considered. The {@code k} smallest distinct
 * hash values of the sampled cells are kept (a k-minimum-values sketch), and the result is scaled
 * back by {@code 1/f^2}. The estimates of all runs are combined by taking the median.
 * NOTE(review): this appears to follow Amossen, Campagna and Pagh, "Better Size Estimation for
 * Sparse Matrix Products" — confirm against the original implementation.
 *
 * <p>Instances keep per-run hash arrays and a random generator as mutable state and are therefore
 * not safe for concurrent use.
 */
public class EstimatorSampleRa extends SparsityEstimator {
    /** Negative run count: derive the number of runs from {@link #DELTA}. */
    private static final int RUNS = -1;
    private static final double SAMPLE_FRACTION = 0.1d;
    private static final double EPSILON = 0.05d;
    private static final double DELTA = 0.1d;
    /** Negative sketch size: derive {@code k} from {@link #EPSILON}. */
    private static final int K = -1;
    /** Two consecutive sorted hash values closer than this ratio are treated as the same cell. */
    private static final double DEDUP_RATIO = 1.0000000001d;

    private final int _runs;
    private final double _sampleFrac;
    private final double _eps;
    private final double _delta;
    private final int _k;
    private final Well1024a _bigrand;
    // per-run hash values; h1/h4 indexed by rows of A, h2/h3 by columns of B
    private double[] h1;
    private double[] h2;
    private double[] h3;
    private double[] h4;

    /**
     * Per-row or per-column lists of the positions of the non-zero cells of a matrix block.
     */
    public static class AdjacencyLists {
        private final ArrayList<Integer>[] indexes;

        /**
         * Builds the adjacency lists of a matrix block.
         *
         * @param mb  matrix block to index
         * @param row if {@code true}, list {@code i} holds the column indexes of the non-zeros of
         *            row {@code i}; otherwise list {@code j} holds the row indexes of the non-zeros
         *            of column {@code j}
         */
        @SuppressWarnings("unchecked") // generic array creation; elements are always ArrayList<Integer>
        public AdjacencyLists(MatrixBlock mb, boolean row) {
            int len = row ? mb.getNumRows() : mb.getNumColumns();
            indexes = new ArrayList[len];
            for (int i = 0; i < len; i++) {
                indexes[i] = new ArrayList<>();
            }
            if (mb.isEmptyBlock(false)) {
                return;
            }
            if (!mb.isInSparseFormat()) {
                for (int i = 0; i < mb.getNumRows(); i++) {
                    for (int j = 0; j < mb.getNumColumns(); j++) {
                        if (mb.quickGetValue(i, j) != 0) {
                            add(row, i, j);
                        }
                    }
                }
                return;
            }
            SparseBlock sb = mb.getSparseBlock();
            for (int i = 0; i < sb.numRows(); i++) {
                if (sb.isEmpty(i)) {
                    continue;
                }
                int pos = sb.pos(i);
                int end = pos + sb.size(i);
                int[] cix = sb.indexes(i);
                for (int p = pos; p < end; p++) {
                    add(row, i, cix[p]);
                }
            }
        }

        /** Records the non-zero cell (i,j) in the list selected by the orientation. */
        private void add(boolean row, int i, int j) {
            indexes[row ? i : j].add(row ? j : i);
        }

        /** Returns the (mutable, internal) list of non-zero positions for row/column {@code i}. */
        public ArrayList<Integer> getList(int i) {
            return indexes[i];
        }
    }

    /** Creates an estimator with default sample fraction, epsilon and delta. */
    public EstimatorSampleRa() {
        this(RUNS, SAMPLE_FRACTION, EPSILON, DELTA, K);
    }

    /**
     * Creates an estimator with the given sample fraction and default epsilon/delta.
     *
     * @param sampleFrac fraction of rows/columns sampled, in (0, 1]
     */
    public EstimatorSampleRa(double sampleFrac) {
        this(RUNS, sampleFrac, EPSILON, DELTA, K);
    }

    /**
     * Creates a fully parameterized estimator.
     *
     * @param runs       number of independent runs (median is returned); negative derives
     *                   {@code floor(log2(1/delta))}, and the result is clamped to at least 1
     * @param sampleFrac fraction of rows/columns sampled, in (0, 1]
     * @param eps        error parameter used to derive {@code k} when {@code k < 0}
     * @param delta      failure probability used to derive {@code runs} when {@code runs < 0};
     *                   must then lie in (0, 1)
     * @param k          sketch size; negative derives {@code ceil(1/eps^2)}, clamped to at least 1
     * @throws DMLRuntimeException if the sample fraction or (when used) delta is out of range
     */
    public EstimatorSampleRa(int runs, double sampleFrac, double eps, double delta, int k) {
        if (sampleFrac <= 0 || sampleFrac > 1.0d) {
            throw new DMLRuntimeException("Invalid sample fraction: " + sampleFrac);
        }
        if (runs < 0 && !(delta > 0 && delta < 1.0d)) {
            // delta <= 0 would request ~Integer.MAX_VALUE runs; delta >= 1 zero or negative runs
            throw new DMLRuntimeException("Invalid delta: " + delta);
        }
        _sampleFrac = sampleFrac;
        _eps = eps;
        _delta = delta;
        int derivedRuns = runs < 0 ? (int) (Math.log(1.0d / delta) / Math.log(2.0d)) : runs;
        // at least one run, otherwise estim() indexes an empty result array
        _runs = Math.max(1, derivedRuns);
        int derivedK = k < 0 ? (int) Math.ceil(1.0d / (eps * eps)) : k;
        // k = 0 would make the "sketch full" check true on an empty sketch (get(-1))
        _k = Math.max(1, derivedK);
        // NOTE(review): _k is passed to setupSeedsForRand; if that argument is a seed, estimates
        // are deterministic for a given k — confirm this is intended.
        _bigrand = LibMatrixDatagen.setupSeedsForRand(_k);
    }

    @Override
    public DataCharacteristics estim(MMNode root) {
        LOG.warn("Recursive estimates not supported by EstimatorSampleRa, falling back to EstimatorBasicAvg.");
        return new EstimatorBasicAvg().estim(root);
    }

    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        if (op == SparsityEstimator.OpCode.MM) {
            return estim(m1, m2);
        }
        throw new NotImplementedException();
    }

    @Override
    public double estim(MatrixBlock m, SparsityEstimator.OpCode op) {
        throw new NotImplementedException();
    }

    /**
     * Estimates the sparsity of the product {@code m1 %*% m2}.
     *
     * @param m1 left input
     * @param m2 right input; its row count must equal the column count of {@code m1}
     * @return estimated output sparsity (median over all runs)
     * @throws DMLRuntimeException if the inner dimensions do not match
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        if (m1.getNumColumns() != m2.getNumRows()) {
            throw new DMLRuntimeException("Incompatible dimensions for matrix multiplication: "
                + m1.getNumRows() + "x" + m1.getNumColumns() + " %*% "
                + m2.getNumRows() + "x" + m2.getNumColumns());
        }
        // the adjacency structure depends only on the inputs, so build it once for all runs
        AdjacencyLists colsA = new AdjacencyLists(m1, false);
        AdjacencyLists rowsB = new AdjacencyLists(m2, true);
        double[] estimates = new double[_runs];
        for (int r = 0; r < _runs; r++) {
            initHashArrays(m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns());
            estimates[r] = estimateSize(m1.getNumColumns(), colsA, rowsB);
        }
        Arrays.sort(estimates);
        long nnz = (long) estimates[_runs / 2];
        return OptimizerUtils.getSparsity(m1.getNumRows(), m2.getNumColumns(), nnz);
    }

    /**
     * Draws fresh hash values for one run, (re)allocating the arrays when the dimensions differ
     * from the previous call.
     *
     * @param m rows of the left input
     * @param n common dimension (unused, kept for symmetry)
     * @param l columns of the right input
     */
    private void initHashArrays(int m, int n, int l) {
        if (h1 == null || h1.length != m || h2.length != l) {
            h1 = new double[m];
            h2 = new double[l];
            h3 = new double[l];
            h4 = new double[m];
        }
        Random rand = new Random(_bigrand.nextLong());
        // fill order h1, h2, h3, h4 defines which random values each array receives
        fillUniform(h1, rand);
        fillUniform(h2, rand);
        fillUniform(h3, rand);
        fillUniform(h4, rand);
    }

    /** Fills {@code arr} with consecutive {@code nextDouble()} draws. */
    private static void fillUniform(double[] arr, Random rand) {
        for (int i = 0; i < arr.length; i++) {
            arr[i] = rand.nextDouble();
        }
    }

    /**
     * Computes one run's estimate of the number of non-zeros in the product using the current
     * hash arrays.
     *
     * @param commonDim inner dimension of the product
     * @param colsA     per-column row indexes of the left input
     * @param rowsB     per-row column indexes of the right input
     * @return estimated nnz of the output
     */
    private double estimateSize(int commonDim, AdjacencyLists colsA, AdjacencyLists rowsB) {
        ArrayList<Double> sketch = new ArrayList<>();
        double threshold = 1.0d;
        int sinceTruncate = 0;
        for (int c = 0; c < commonDim; c++) {
            ArrayList<Integer> listA = colsA.getList(c);
            ArrayList<Integer> listB = rowsB.getList(c);
            if (listA.isEmpty() || listB.isEmpty()) {
                continue;
            }
            int[] a = sortByHash(listA, h1);
            int[] b = sortByHash(listB, h2);
            int n = a.length;
            int pos = 0;
            for (int bj : b) {
                // advance pos circularly to the entry of a with the smallest h(a[pos], bj);
                // b is sorted by h2, so this start position only moves forward
                int last = pos > 0 ? pos - 1 : n - 1;
                while (h(a[pos], bj) > h(a[last], bj)) {
                    pos = (pos + 1) % n;
                }
                if (h3[bj] > _sampleFrac) {
                    continue; // column bj not sampled
                }
                // visit a in increasing hash order while below the current sketch threshold
                int cur = pos;
                for (int visited = 0; visited < n && h(a[cur], bj) < threshold; visited++) {
                    int ai = a[cur];
                    if (h4[ai] <= _sampleFrac) {
                        sketch.add(h(ai, bj));
                        if (++sinceTruncate > _k) {
                            sortAndTruncate(sketch);
                            if (sketch.size() == _k) {
                                threshold = sketch.get(_k - 1);
                            }
                            sinceTruncate = 0;
                        }
                    }
                    cur = (cur + 1) % n;
                }
            }
        }
        sortAndTruncate(sketch);
        if (sketch.size() == _k) {
            return _k / (sketch.get(_k - 1) * _sampleFrac * _sampleFrac);
        }
        return sketch.size() / (_sampleFrac * _sampleFrac);
    }

    /** Returns the indexes of {@code idx} ordered (stably) by ascending {@code hash} value. */
    private static int[] sortByHash(ArrayList<Integer> idx, double[] hash) {
        return idx.stream()
            .sorted(Comparator.comparingDouble((Integer i) -> hash[i]))
            .mapToInt(Integer::intValue)
            .toArray();
    }

    /**
     * Sorts the sketch ascending, drops values that are (relatively) equal to the previously kept
     * value, and truncates it to at most {@code k} entries. Modifies {@code sketch} in place.
     *
     * @param sketch list of hash values
     */
    public void sortAndTruncate(ArrayList<Double> sketch) {
        Collections.sort(sketch);
        // in-place compaction instead of repeated ArrayList.remove (which is O(n^2))
        int keep = 0;
        for (int r = 0; r < sketch.size(); r++) {
            Double v = sketch.get(r);
            if (keep > 0 && v / sketch.get(keep - 1) < DEDUP_RATIO) {
                continue;
            }
            sketch.set(keep++, v);
        }
        sketch.subList(Math.min(keep, _k), sketch.size()).clear();
    }

    /**
     * Hash of output cell (i,j) for the current run: {@code (h1[i] - h2[j]) mod 1}, in [0, 1).
     *
     * @param i row index of the left input
     * @param j column index of the right input
     * @return cell hash value
     */
    public double h(int i, int j) {
        double d = h1[i] - h2[j];
        return d < 0 ? d + 1.0d : d;
    }
}
