package org.apache.sysds.runtime.matrix.data;

import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.DoubleConsumer;
import org.antlr.v4.runtime.atn.PredictionContext;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.utils.Hash;

/**
 * Counts, exactly or approximately, the number of distinct values in a {@link MatrixBlock}.
 *
 * <p>The estimator is selected through the operator's {@code operatorType}:
 * <ul>
 * <li>{@code COUNT}: exact count via a hash set of all values.</li>
 * <li>{@code KMV}: "K Minimum Values" sketch that keeps the k smallest distinct hashes and
 * extrapolates the number of distinct values from the k-th smallest one.</li>
 * </ul>
 * Blocks with fewer than {@link #minimumSize} non-zeros are always counted exactly.
 */
public class LibMatrixCountDistinct {
    private static final Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());

    /** Non-zero count below which the exact count is used regardless of the requested estimator. */
    public static int minimumSize = 1024;

    /** Upper bound on the sketch size k used by the KMV estimator. */
    private static final int MAX_KMV_SKETCH_SIZE = 64;

    /** Largest base d for which d * d still fits into an int (46340^2 < 2^31 - 1 < 46341^2). */
    private static final long MAX_SQUARE_BASE = 46340L;

    /**
     * Bounded collection that retains the {@code k} smallest distinct integers offered to it.
     *
     * <p>Backed by a max-heap of at most {@code k} elements: its head is the largest retained
     * value (the k-th smallest once full) and is evicted whenever a smaller value that is not
     * yet retained arrives. A companion set rejects duplicates.
     */
    public static class SmallestPriorityQueue {
        private final Set<Integer> containedSet = new HashSet<>();
        private final PriorityQueue<Integer> smallestHashes;
        private final int k;

        /**
         * @param k maximum number of values to retain; the {@link PriorityQueue} constructor
         *          rejects values below 1
         */
        public SmallestPriorityQueue(int k) {
            this.smallestHashes = new PriorityQueue<>(k, Collections.reverseOrder());
            this.k = k;
        }

        /**
         * Offers {@code v}; it is retained if it is not already retained and is among the
         * {@code k} smallest values seen so far.
         *
         * @param v value to offer
         */
        public void add(int v) {
            if (containedSet.contains(v)) {
                return;
            }
            if (smallestHashes.size() < k) {
                smallestHashes.add(v);
                containedSet.add(v);
            } else if (v < smallestHashes.peek()) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace(smallestHashes.peek() + " -- " + v);
                }
                // evict the current largest retained value to make room for v
                containedSet.remove(smallestHashes.poll());
                smallestHashes.add(v);
                containedSet.add(v);
            }
        }

        /** @return number of retained values */
        public int size() {
            return smallestHashes.size();
        }

        /**
         * @return the largest retained value (the k-th smallest once full)
         * @throws NullPointerException if nothing is retained
         */
        public int peek() {
            return smallestHashes.peek();
        }

        /**
         * Removes and returns the largest retained value.
         *
         * @return the removed value
         * @throws NullPointerException if nothing is retained
         */
        public int poll() {
            final int top = smallestHashes.poll();
            // keep the duplicate filter in sync with the heap
            containedSet.remove(top);
            return top;
        }

        @Override
        public String toString() {
            return smallestHashes.toString();
        }
    }

    private LibMatrixCountDistinct() {
    }

    /**
     * Returns the exact or estimated number of distinct values in {@code in}.
     *
     * @param in matrix to inspect
     * @param op selects the estimator ({@code COUNT} or {@code KMV}) and, for KMV, the hash
     * @return number of distinct values; at least 1
     * @throws DMLException if KMV is combined with {@code ExpHash} or {@code StandardJava}
     *                      hashing, or the estimator type is unsupported
     * @throws NotImplementedException for the {@code HLL} estimator
     * @throws DMLRuntimeException if the computed count is 0
     */
    public static int estimateDistinctValues(MatrixBlock in, CountDistinctOperator op) {
        if (op.operatorType == CountDistinctOperator.CountDistinctTypes.KMV
            && (op.hashType == Hash.HashType.ExpHash || op.hashType == Hash.HashType.StandardJava)) {
            throw new DMLException("Invalid hashing configuration using " + op.hashType + " and " + op.operatorType);
        }
        if (op.operatorType == CountDistinctOperator.CountDistinctTypes.HLL) {
            throw new NotImplementedException("HyperLogLog not implemented");
        }
        // a single cell, or an empty block (presumably all zeros), has exactly one distinct value
        if (in.getLength() == 1 || in.isEmpty()) {
            return 1;
        }

        final int estimate;
        if (in.getNonZeros() < minimumSize) {
            // small inputs: counting exactly is cheap, so never approximate
            estimate = countDistinctValuesNaive(in);
        } else {
            switch (op.operatorType) {
                case COUNT:
                    estimate = countDistinctValuesNaive(in);
                    break;
                case KMV:
                    estimate = countDistinctValuesKVM(in, op);
                    break;
                default:
                    throw new DMLException("Invalid or not implemented Estimator Type");
            }
        }
        if (estimate == 0) {
            throw new DMLRuntimeException("Impossible estimate of distinct values");
        }
        return estimate;
    }

    /** Counts distinct values exactly by inserting every value into a hash set. */
    private static int countDistinctValuesNaive(MatrixBlock in) {
        final MatrixBlock mb = decompressIfOverlapping(in);
        final Set<Double> distinct = new HashSet<>();
        // zero cells are not stored in sparse or compressed layouts, so add the zero explicitly
        if (hasZeroCells(mb)) {
            distinct.add(0.0);
        }
        forEachStoredValue(mb, v -> distinct.add(v));
        return distinct.size();
    }

    /** Estimates the number of distinct values with a K-Minimum-Values sketch. */
    private static int countDistinctValuesKVM(MatrixBlock in, CountDistinctOperator op) {
        // D: upper bound on the distinct count -- every non-zero may be unique, plus one for zero
        final long d = in.getNonZeros() + 1;
        // hash into ~D^2 slots to make collisions unlikely, capped at the int range;
        // compare the base first because d * d itself can overflow a long
        final int m = d > MAX_SQUARE_BASE ? Integer.MAX_VALUE : (int) (d * d);
        // the estimator is asymptotically unbiased in k but memory grows with k; want D >> k >> 0
        final int k = (int) Math.min(d, MAX_KMV_SKETCH_SIZE);
        if (LOG.isDebugEnabled()) {
            LOG.debug("M not forced to int size: " + (d * d));
            LOG.debug("M: " + m);
        }

        final SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
        final Hash.HashType hashType = op.hashType;
        final MatrixBlock mb = decompressIfOverlapping(in);
        // zero cells are not stored in sparse or compressed layouts, so hash the zero explicitly
        if (hasZeroCells(mb)) {
            spq.add(hashToRange(0.0, hashType, m));
        }
        forEachStoredValue(mb, v -> spq.add(hashToRange(v, hashType, m)));

        if (LOG.isDebugEnabled()) {
            LOG.debug("M: " + m);
            if (spq.size() > 0) {
                LOG.debug("smallest hash:" + spq.peek());
            }
            LOG.debug("spq: " + spq.toString());
        }

        // fewer than k distinct hashes: the sketch already holds every distinct value
        if (spq.size() < k) {
            return spq.size();
        }

        // U_k: k-th smallest hash normalised by the hash-space size (floating-point division)
        final double uK = (double) spq.poll() / m;
        final double estimate = (k - 1) / uK;
        if (LOG.isDebugEnabled()) {
            LOG.debug("U_k : " + uK);
            LOG.debug("Estimate: " + estimate);
            LOG.debug("Ceil worst case: " + d);
        }
        // never report more distinct values than there can possibly be
        return (int) Math.min(estimate, d);
    }

    /**
     * Maps {@code value} to a hash in {@code [1, m - 1]} (in {@code [1, 1]} if {@code m < 2}).
     *
     * <p>The remainder is taken before {@link Math#abs(int)} because
     * {@code Math.abs(Integer.MIN_VALUE)} is still negative; for every other hash the result
     * equals {@code Math.abs(h) % (m - 1) + 1}.
     */
    private static int hashToRange(double value, Hash.HashType hashType, int m) {
        final int range = Math.max(m - 1, 1);
        return Math.abs(Hash.hash(value, hashType) % range) + 1;
    }

    /**
     * Returns a decompressed copy of an overlapping compressed block, otherwise {@code in} itself.
     * Presumably overlapping column groups combine into the cell values, so their dictionaries
     * alone do not represent the cells.
     */
    private static MatrixBlock decompressIfOverlapping(MatrixBlock in) {
        if (in instanceof CompressedMatrixBlock) {
            final CompressedMatrixBlock cmb = (CompressedMatrixBlock) in;
            if (cmb.isOverlapping()) {
                return cmb.decompress();
            }
        }
        return in;
    }

    /** True if the known non-zero count is below the number of cells, i.e. some cell is zero. */
    private static boolean hasZeroCells(MatrixBlock mb) {
        final long nnz = mb.getNonZeros();
        // widen before multiplying: rows * cols overflows an int for large matrices
        return nnz != -1 && nnz < (long) mb.getNumRows() * mb.getNumColumns();
    }

    /**
     * Passes every stored value of {@code mb} to {@code action}: column-group dictionary values
     * for (non-overlapping) compressed blocks, the used entries of each sparse row, or the used
     * prefix of each dense block. Implicit zeros of sparse/compressed layouts are not visited.
     */
    private static void forEachStoredValue(MatrixBlock mb, DoubleConsumer action) {
        if (mb instanceof CompressedMatrixBlock) {
            for (AColGroup group : ((CompressedMatrixBlock) mb).getColGroups()) {
                final double[] vals = group.getValues();
                if (vals != null) {
                    forEachInRange(vals, 0, vals.length, action);
                }
            }
            return;
        }
        if (mb.sparseBlock != null) {
            final SparseBlock sb = mb.sparseBlock;
            for (int r = 0; r < mb.getNumRows(); r++) {
                if (sb.isEmpty(r)) {
                    continue;
                }
                // restrict to [pos, pos + size): backing arrays may have slack or stale slots
                final int pos = sb.pos(r);
                forEachInRange(sb.values(r), pos, pos + sb.size(r), action);
            }
        } else if (mb.denseBlock != null) {
            final DenseBlock db = mb.denseBlock;
            // blocks are indexed 0 .. numBlocks() - 1; only the first size(bix) entries are cells
            for (int bix = 0; bix < db.numBlocks(); bix++) {
                forEachInRange(db.valuesAt(bix), 0, db.size(bix), action);
            }
        }
    }

    /** Applies {@code action} to {@code vals[from]} .. {@code vals[to - 1]}. */
    private static void forEachInRange(double[] vals, int from, int to, DoubleConsumer action) {
        for (int i = from; i < to; i++) {
            action.accept(vals[i]);
        }
    }
}
