package org.apache.sysds.hops.estim;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.distribution.ExponentialDistribution;
import org.apache.commons.math3.random.Well1024a;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/hops/estim/EstimatorLayeredGraph.class */
/**
 * Sparsity estimator for matrix products that builds a layered graph over the
 * input matrices and estimates the number of non-zeros of the product from
 * random vectors propagated through that graph.
 *
 * <p>Only matrix multiplication is supported; the other operation entry points
 * throw {@link NotImplementedException}.
 */
public class EstimatorLayeredGraph extends SparsityEstimator {
    /** Default number of random samples kept per graph node. */
    private static final int ROUNDS = 32;

    /** Number of random samples kept per graph node for this estimator. */
    private final int _rounds;

    /**
     * Layered graph over a chain of matrices: one layer of nodes for the rows of
     * the first matrix and one additional layer for the columns of every
     * non-empty matrix added afterwards. Each non-zero cell (r, c) of a matrix
     * becomes an edge from row node r of the previous layer to column node c of
     * the new layer.
     */
    public static class LayeredGraph {
        /** Node layers in the order they were created. */
        private final List<Node[]> _nodes = new ArrayList<>();

        /** Length of the random vector attached to every node. */
        private final int _rounds;

        /** Set once an empty matrix was passed to {@link #buildNext(MatrixBlock)}. */
        private boolean _hasEmptyInput = false;

        /**
         * Graph node holding its incoming edges and a lazily computed vector.
         */
        public static class Node {
            /** Nodes of the previous layer that have an edge into this node. */
            private final List<Node> _input = new ArrayList<>();

            /** Vector of this node; {@code null} until assigned or computed. */
            private double[] _rvect;

            private Node() {
            }

            /**
             * @return the list of input nodes (the internal list, not a copy)
             */
            public List<Node> getInput() {
                return _input;
            }

            /**
             * @return the vector of this node, or {@code null} if none is set yet
             */
            public double[] getVector() {
                return _rvect;
            }

            /**
             * Assigns the vector of this node.
             *
             * @param vector the vector to store (kept by reference)
             */
            public void setVector(double[] vector) {
                _rvect = vector;
            }

            /**
             * Registers an incoming edge.
             *
             * @param node the source node of the edge
             */
            public void addInput(Node node) {
                _input.add(node);
            }

            /**
             * Returns the vector of this node, computing and caching it on first
             * use as the element-wise minimum over the vectors of all inputs.
             *
             * @param rounds number of leading vector entries to combine
             * @return the node vector, or {@code null} if this node has no
             *         inputs or none of its inputs yields a vector
             */
            public double[] computeVector(int rounds) {
                if (_rvect != null || getInput().isEmpty()) {
                    return _rvect;
                }
                // Recursively resolve the inputs; inputs without a vector are ignored.
                List<double[]> inputVectors = getInput().stream()
                    .map(node -> node.computeVector(rounds))
                    .filter(vector -> vector != null)
                    .collect(Collectors.toList());
                if (inputVectors.isEmpty()) {
                    return null;
                }
                double[] result = inputVectors.get(0);
                if (inputVectors.size() > 1) {
                    // Copy before merging so the first input's vector is not modified.
                    result = result.clone();
                    for (int k = 1; k < inputVectors.size(); k++) {
                        double[] other = inputVectors.get(k);
                        for (int j = 0; j < rounds; j++) {
                            result[j] = Math.min(result[j], other[j]);
                        }
                    }
                }
                _rvect = result;
                return result;
            }
        }

        /**
         * Builds the graph by adding the given matrices in list order.
         *
         * @param matrices the matrices of the product chain, left to right
         * @param rounds length of the random vector attached to every node
         */
        public LayeredGraph(List<MatrixBlock> matrices, int rounds) {
            _rounds = rounds;
            matrices.forEach(this::buildNext);
        }

        /**
         * Appends one matrix to the graph. The first non-empty matrix creates
         * the row layer; every non-empty matrix creates a new column layer and
         * connects it to the previous layer, one edge per non-zero cell.
         *
         * <p>An empty matrix adds no layer. It is recorded so that
         * {@link #estimateNnz()} returns 0 instead of estimating from a graph
         * that is missing this factor.
         *
         * @param matrixBlock the next matrix of the product chain
         */
        public void buildNext(MatrixBlock matrixBlock) {
            if (matrixBlock.isEmpty()) {
                _hasEmptyInput = true;
                return;
            }
            final int numRows = matrixBlock.getNumRows();
            final int numCols = matrixBlock.getNumColumns();

            // Row nodes: created for the first matrix, otherwise the last layer is reused.
            Node[] rows;
            if (_nodes.isEmpty()) {
                rows = new Node[numRows];
                for (int r = 0; r < numRows; r++) {
                    rows[r] = new Node();
                }
                _nodes.add(rows);
            } else {
                rows = _nodes.get(_nodes.size() - 1);
            }

            // Column nodes: always a fresh layer.
            Node[] cols = new Node[numCols];
            for (int c = 0; c < numCols; c++) {
                cols[c] = new Node();
            }
            _nodes.add(cols);

            // Edges: one per non-zero cell, from row node to column node.
            if (matrixBlock.isInSparseFormat()) {
                SparseBlock sparse = matrixBlock.getSparseBlock();
                for (int r = 0; r < numRows; r++) {
                    if (sparse.isEmpty(r)) {
                        continue;
                    }
                    int start = sparse.pos(r);
                    int end = start + sparse.size(r);
                    int[] colIndexes = sparse.indexes(r);
                    for (int k = start; k < end; k++) {
                        cols[colIndexes[k]].addInput(rows[r]);
                    }
                }
            } else {
                DenseBlock dense = matrixBlock.getDenseBlock();
                for (int r = 0; r < numRows; r++) {
                    double[] values = dense.values(r);
                    int offset = dense.pos(r);
                    for (int c = 0; c < numCols; c++) {
                        if (values[offset + c] != 0) {
                            cols[c].addInput(rows[r]);
                        }
                    }
                }
            }
        }

        /**
         * Estimates the number of non-zeros of the product represented by the
         * graph. Every node of the first layer gets a vector of samples drawn
         * from an exponential distribution; the vectors are propagated to the
         * last layer via {@link Node#computeVector(int)}, and each last-layer
         * node contributes {@code (rounds - 1) / sum(vector)} to the total.
         *
         * @return the rounded estimate, or 0 if the graph has no layers or an
         *         empty matrix was added
         */
        public long estimateNnz() {
            if (_hasEmptyInput || _nodes.isEmpty()) {
                return 0;
            }
            // Step 1: assign random vectors to all nodes of the first layer.
            ExponentialDistribution random = new ExponentialDistribution(new Well1024a(), 1.0d);
            for (Node leaf : _nodes.get(0)) {
                double[] samples = new double[_rounds];
                for (int k = 0; k < _rounds; k++) {
                    samples[k] = random.sample();
                }
                leaf.setVector(samples);
            }
            // Step 2: propagate the vectors and aggregate per-node estimates.
            Node[] lastLayer = _nodes.get(_nodes.size() - 1);
            double total = Arrays.stream(lastLayer)
                .mapToDouble(node -> calcNNZ(node.computeVector(_rounds), _rounds))
                .sum();
            return Math.round(total);
        }

        /**
         * Computes the estimate contributed by a single node.
         *
         * @param vector the node vector, may be {@code null}
         * @param rounds number of samples per vector
         * @return {@code (rounds - 1) / sum(vector)}, or 0 for a {@code null}
         *         or zero-length vector
         */
        private static double calcNNZ(double[] vector, int rounds) {
            if (vector == null || vector.length <= 0) {
                return 0;
            }
            return (rounds - 1) / Arrays.stream(vector).sum();
        }
    }

    /** Creates an estimator with the default number of rounds. */
    public EstimatorLayeredGraph() {
        this(ROUNDS);
    }

    /**
     * Creates an estimator with a custom number of rounds.
     *
     * @param rounds number of random samples kept per graph node; the per-node
     *        estimate uses {@code rounds - 1} as numerator, so a value of 1
     *        always yields an estimate of 0
     */
    public EstimatorLayeredGraph(int rounds) {
        _rounds = rounds;
    }

    /**
     * Estimates the output characteristics of the product chain rooted at the
     * given node and stores them on that node.
     *
     * @param mMNode root of the matrix-multiplication tree
     * @return the value returned by {@code mMNode.setDataCharacteristics(...)}
     */
    @Override
    public DataCharacteristics estim(MMNode mMNode) {
        // The leaf list is needed both for the output dimensions and for the graph.
        List<MatrixBlock> leaves = getMatrices(mMNode, new ArrayList<>());
        long nnz = new LayeredGraph(leaves, _rounds).estimateNnz();
        return mMNode.setDataCharacteristics(new MatrixCharacteristics(
            leaves.get(0).getNumRows(), leaves.get(leaves.size() - 1).getNumColumns(), nnz));
    }

    /**
     * Estimates the output sparsity of a binary operation.
     *
     * @param matrixBlock left input
     * @param matrixBlock2 right input
     * @param opCode the operation; only {@code MM} is supported
     * @return the estimated sparsity of the product
     * @throws NotImplementedException for any operation other than {@code MM}
     */
    @Override
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, SparsityEstimator.OpCode opCode) {
        if (opCode == SparsityEstimator.OpCode.MM) {
            return estim(matrixBlock, matrixBlock2);
        }
        throw new NotImplementedException();
    }

    /**
     * Unary operations are not supported by this estimator.
     *
     * @throws NotImplementedException always
     */
    @Override
    public double estim(MatrixBlock matrixBlock, SparsityEstimator.OpCode opCode) {
        throw new NotImplementedException();
    }

    /**
     * Estimates the output sparsity of the product of two matrices.
     *
     * @param matrixBlock left input
     * @param matrixBlock2 right input
     * @return the sparsity derived from the estimated number of non-zeros and
     *         the output dimensions (rows of the left, columns of the right)
     */
    @Override
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2) {
        long nnz = new LayeredGraph(Arrays.asList(matrixBlock, matrixBlock2), _rounds).estimateNnz();
        return OptimizerUtils.getSparsity(matrixBlock.getNumRows(), matrixBlock2.getNumColumns(), nnz);
    }

    /**
     * Collects the leaf matrices of the tree in left-to-right order.
     *
     * @param mMNode current node of the tree
     * @param list accumulator the leaf matrices are appended to
     * @return the accumulator passed in
     */
    private List<MatrixBlock> getMatrices(MMNode mMNode, List<MatrixBlock> list) {
        if (mMNode.isLeaf()) {
            list.add(mMNode.getData());
        } else {
            getMatrices(mMNode.getLeft(), list);
            getMatrices(mMNode.getRight(), list);
        }
        return list;
    }
}
