package org.apache.sysds.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DnnUtils;

/**
 * CPU kernels for deep-neural-network operators on {@link MatrixBlock}s: 2D convolution and its
 * backward passes, max/avg pooling and its backward pass, relu backward, bias add/multiply,
 * per-channel sums, and 2D batch normalization (forward and backward).
 *
 * <p>Image batches are handled as matrices with one row per example whose columns hold the
 * channels one after another, i.e. an {@code N x (C*H*W)} layout; the per-channel helpers in this
 * class index dense values as {@code n*C*HW + c*HW + hw}. Multi-threaded operators are split into
 * {@link Callable} tasks whose {@code Long} results are summed by {@link #execute}; most callers
 * use that sum as the output's non-zero count.
 */
public class LibMatrixDNN {
    protected static final Log LOG = LogFactory.getLog(LibMatrixDNN.class.getName());

    /** Function applied over each pooling window. */
    public enum PoolingType {
        MAX,
        AVG
    }

    /**
     * Computes a 2D convolution of {@code input} with {@code filter} into {@code outputBlock}.
     *
     * @param input the input images, expected to be {@code N x (C*H*W)}
     * @param filter the filters, expected to be {@code K x (C*R*S)}
     * @param outputBlock the output block; its non-zero count is set from the workers' results
     * @param params convolution parameters; input, filter and output are registered on it as
     *     {@code input1}, {@code input2} and {@code output}
     * @throws DMLRuntimeException if dimensions do not match {@code params} or strides are not positive
     */
    public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, DnnParameters params) {
        checkInputsConv2d(input, filter, outputBlock, params);
        // an optional bias is densified up front so the workers can read it as a dense array
        if (params.bias != null && params.bias.isInSparseFormat()) {
            params.bias.sparseToDense();
        }
        outputBlock.setNonZeros(execute(LibMatrixDNNConv2d.getConv2dWorkers(params), params));
        outputBlock.examSparsity();
    }

    /**
     * Computes the gradient of a 2D convolution with respect to its input data.
     *
     * @param filter the filters, expected to be {@code K x (C*R*S)}
     * @param dout the output errors, expected to be {@code N x (K*P*Q)}
     * @param outputBlock the output block; its non-zero count is set from the workers' results
     * @param params convolution parameters; the blocks are registered on it
     * @throws DMLRuntimeException if dimensions do not match {@code params} or strides are not positive
     */
    public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        checkInputsConv2dBackwardData(filter, dout, outputBlock, params);
        outputBlock.setNonZeros(execute(LibMatrixDNNConv2d.getConv2dBackwardDataWorkers(params), params));
        outputBlock.examSparsity();
    }

    /**
     * Computes the gradient of a 2D convolution with respect to its filters.
     *
     * <p>Unlike the other operators, the workers' summed result is ignored and the non-zeros of
     * {@code outputBlock} are recomputed after execution.
     *
     * @param input the input images, expected to be {@code N x (C*H*W)}
     * @param dout the output errors, expected to be {@code N x (K*P*Q)}
     * @param outputBlock the output block
     * @param params convolution parameters; the blocks are registered on it
     * @throws DMLRuntimeException if dimensions do not match {@code params} or strides are not positive
     */
    public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        checkInputsConv2dBackwardFilter(input, dout, outputBlock, params);
        execute(LibMatrixDNNConv2d.getConv2dBackwardFilterWorkers(params), params);
        outputBlock.recomputeNonZeros();
        outputBlock.examSparsity();
    }

    /**
     * Applies max or average pooling to {@code input}, writing into {@code outputBlock}.
     *
     * @param input the input images, expected to be {@code N x (C*H*W)}
     * @param outputBlock the output block; its non-zero count is set from the workers' results
     * @param params pooling parameters; input and output are registered on it
     * @param poolType the pooling function
     * @throws DMLRuntimeException if the input dimensions do not match {@code params}
     */
    public static void pooling(MatrixBlock input, MatrixBlock outputBlock, DnnParameters params, PoolingType poolType) {
        params.input1 = input;
        params.output = outputBlock;
        if (input.getNumColumns() != params.C * params.H * params.W || input.getNumRows() != params.N) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling:" + input.getNumRows() + " "
                + input.getNumColumns() + " " + params.N + " " + (params.C * params.H * params.W));
        }
        // window bounds are precomputed unless stride is 1 with no padding on a dense input
        if (!params.isStride1Pad0() || input.sparse) {
            fillIndexesArray(params);
        }
        outputBlock.setNonZeros(execute(LibMatrixDNNPooling.getPoolingWorkers(params, poolType), params));
        outputBlock.examSparsity();
    }

    /**
     * Computes the gradient of max or average pooling.
     *
     * @param input the forward input images; only validated ({@code N x (C*H*W)}) for MAX pooling
     * @param dout the output errors, expected to be {@code N x (C*P*Q)}
     * @param outputBlock the output block; must not be sparse unless {@code poolType} is MAX
     * @param params pooling parameters; the blocks are registered on it
     * @param performReluBackward forwarded to the pooling-backward workers (see LibMatrixDNNPooling)
     * @param poolType the pooling function used in the forward pass
     * @throws DMLRuntimeException on dimension mismatch or an unsupported sparse output
     */
    public static void poolingBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock,
            DnnParameters params, boolean performReluBackward, PoolingType poolType) {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        if (poolType == PoolingType.MAX
                && (input.getNumColumns() != params.C * params.H * params.W || input.getNumRows() != params.N)) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling_backward:" + input.getNumRows()
                + " " + input.getNumColumns() + " " + params.N + " " + (params.C * params.H * params.W));
        }
        if (dout.getNumColumns() != params.C * params.P * params.Q || dout.getNumRows() != params.N) {
            throw new DMLRuntimeException("Incorrect dout dimensions in pooling_backward:" + dout.getNumRows()
                + " " + dout.getNumColumns() + " " + params.N + " " + (params.C * params.P * params.Q));
        }
        if (params.output.isInSparseFormat() && poolType != PoolingType.MAX) {
            throw new DMLRuntimeException("Sparse pooling_backward is not supported");
        }
        if (poolType == PoolingType.AVG) {
            fillIndexesArray(params);
        } else {
            // a dense dout forces the forward input into dense format as well
            if (!params.input2.isInSparseFormat()) {
                params.input1.sparseToDense();
            }
            // NOTE(review): after the densification above, the one combination that skips the
            // window bounds (sparse input with dense dout) looks unreachable -- confirm intent.
            if (!params.input1.isInSparseFormat() || params.input2.isInSparseFormat()) {
                fillIndexesArray(params);
            }
        }
        outputBlock.setNonZeros(execute(
            LibMatrixDNNPooling.getPoolingBackwardWorkers(params, performReluBackward, poolType), params));
        outputBlock.examSparsity();
    }

    /**
     * Computes the relu backward pass of {@code input} and {@code dout} into {@code outputBlock}.
     *
     * @param input the forward input; must have the same dimensions as {@code dout}
     * @param dout the output errors
     * @param outputBlock the output block; its non-zero count is set from the workers' results
     * @param numThreads requested degree of parallelism
     * @throws DMLRuntimeException if {@code input} and {@code dout} dimensions differ
     */
    public static void reluBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, int numThreads) {
        DnnParameters params = new DnnParameters(input.getNumRows(), -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        if (input.getNumRows() != dout.getNumRows() || input.getNumColumns() != dout.getNumColumns()) {
            throw new DMLRuntimeException("Incorrect dimensions for relu_backward:" + input.getNumRows() + " != "
                + dout.getNumRows() + " || " + input.getNumColumns() + " != " + dout.getNumColumns());
        }
        outputBlock.setNonZeros(execute(LibMatrixDNNRelu.getReluBackwardWorkers(params), params));
        outputBlock.examSparsity();
    }

    /**
     * Adds a per-channel bias to every element of the matching channel.
     *
     * @param input the input, {@code N x (K*PQ)}
     * @param bias the bias column vector, {@code K x 1}
     * @param outputBlock the dense output block (its dense array is written directly)
     * @param numThreads unused by this implementation
     * @throws DMLRuntimeException if {@code bias} is not a non-empty column vector or the input's
     *     column count is not a multiple of its row count
     */
    public static void biasAdd(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) {
        int N = input.getNumRows();
        int K = bias.getNumRows();
        // K == 0 is rejected here instead of failing with an ArithmeticException on the division below
        if (K == 0 || bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
            throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + N + " X " + input.getNumColumns()
                + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
        }
        int PQ = input.getNumColumns() / K;
        if (input.isEmptyBlock()) {
            double[] outputArray = outputBlock.getDenseBlockValues();
            for (int n = 0; n < N; n++) {
                DnnUtils.fillBias(bias, outputArray, n, n + 1, N, K, PQ);
            }
        } else {
            // copy (dense or sparse input) into dense output, then add the bias in place
            outputBlock.copy(input, false);
            ensureDense(bias);
            // fetch the dense array only after the copy, which may (re)allocate it
            addBias(outputBlock.getDenseBlockValues(), bias.getDenseBlockValues(), 1.0, N, K, PQ);
        }
        outputBlock.recomputeNonZeros();
        outputBlock.examSparsity();
    }

    /**
     * Sums the values of each channel over all rows and spatial positions.
     *
     * <p>NOTE(review): the dense path overwrites {@code outputBlock}'s values, while the sparse path
     * accumulates into them (so it presumably assumes a zeroed output) -- confirm with callers.
     *
     * @param input the input, {@code N x (C*HW)}
     * @param outputBlock dense output with at least {@code C} values
     * @param C number of channels
     * @param HW number of spatial positions per channel
     */
    public static void channelSums(MatrixBlock input, MatrixBlock outputBlock, int C, int HW) {
        double[] out = outputBlock.getDenseBlockValues();
        if (input.isInSparseFormat()) {
            SparseBlock sblock = input.getSparseBlock();
            for (int n = 0; n < input.getNumRows(); n++) {
                if (sblock.isEmpty(n)) {
                    continue;
                }
                int apos = sblock.pos(n);
                int alen = sblock.size(n);
                int[] aix = sblock.indexes(n);
                double[] avals = sblock.values(n);
                for (int j = apos; j < apos + alen; j++) {
                    out[aix[j] / HW] += avals[j];
                }
            }
        } else {
            double[] in = input.getDenseBlockValues();
            if (in != null) {
                // Kahan summation limits floating-point error on large channel sums
                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                for (int c = 0; c < C; c++) {
                    KahanObject sum = new KahanObject(0.0, 0.0);
                    for (int n = 0; n < input.getNumRows(); n++) {
                        int index = (n * C * HW) + (c * HW);
                        for (int hw = 0; hw < HW; hw++, index++) {
                            kplus.execute2(sum, in[index]);
                        }
                    }
                    out[c] = sum._sum;
                }
            }
        }
        outputBlock.recomputeNonZeros();
    }

    /**
     * Backward pass of 2D batch normalization.
     *
     * @param image the forward input, {@code N x (K*PQ)}; converted to dense in place
     * @param dout the output errors, same shape as {@code image}; converted to dense in place
     * @param scale the per-channel scale (gamma), {@code K x 1}
     * @param epsilon unused by this implementation
     * @param resultSaveMean per-channel mean saved by the forward pass
     * @param resultSaveInvVariance per-channel inverse standard deviation saved by the forward pass
     * @param dX output: gradient w.r.t. {@code image}
     * @param dScale output: gradient w.r.t. {@code scale}
     * @param dBias output: gradient w.r.t. the bias, i.e. the per-channel sum of {@code dout}
     * @throws DMLRuntimeException if {@code image} or {@code dout} remain sparse after densification
     */
    public static void batchNorm2DBackward(MatrixBlock image, MatrixBlock dout, MatrixBlock scale, double epsilon,
            MatrixBlock resultSaveMean, MatrixBlock resultSaveInvVariance,
            MatrixBlock dX, MatrixBlock dScale, MatrixBlock dBias) {
        int N = image.getNumRows();
        int K = scale.getNumRows();
        int PQ = image.getNumColumns() / K;
        ensureDense(dBias);
        ensureDense(dScale);
        ensureDense(dX);
        ensureDense(resultSaveMean);
        ensureDense(resultSaveInvVariance);
        ensureDense(scale);
        double[] dBiasArr = dBias.getDenseBlockValues();
        double[] dScaleArr = dScale.getDenseBlockValues();
        double[] dXArr = dX.getDenseBlockValues();
        // empty per-channel vectors have no dense array; treat them as all zeros
        double[] meanArr = resultSaveMean.getDenseBlockValues() == null ? new double[K] : resultSaveMean.getDenseBlockValues();
        double[] invVarArr = resultSaveInvVariance.getDenseBlockValues() == null ? new double[K] : resultSaveInvVariance.getDenseBlockValues();
        double[] scaleArr = scale.getDenseBlockValues() == null ? new double[K] : scale.getDenseBlockValues();
        ensureDense(image);
        ensureDense(dout);
        if (image.isInSparseFormat() || dout.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse format is not yet supported for batch norm backward");
        }
        double[] imageArr = image.getDenseBlockValues();
        double[] doutArr = dout.getDenseBlockValues();
        double invNPQ = Math.pow(N * PQ, -1.0);
        double twoInvNPQ = 2.0 * invNPQ;
        int KPQ = K * PQ;
        for (int k = 0; k < K; k++) {
            double meanK = meanArr[k];
            double invVarK = invVarArr[k];
            double scaleK = scaleArr[k];
            double invVarCubed = Math.pow(invVarK, 3.0);
            double dVar = 0.0;      // gradient w.r.t. the channel variance
            double dMeanSum = 0.0;  // -sum(dXhat * invVar)
            double dMeanVar = 0.0;  // -sum(2/(N*PQ) * (x - mean))
            double dBiasSum = 0.0;
            double dScaleSum = 0.0;
            for (int n = 0; n < N; n++) {
                int index = (n * KPQ) + (k * PQ);
                for (int pq = 0; pq < PQ; pq++, index++) {
                    double doutVal = doutArr != null ? doutArr[index] : 0.0;
                    double centered = (imageArr != null ? imageArr[index] : 0.0) - meanK;
                    double dXhat = doutVal * scaleK;
                    dVar -= ((0.5 * centered) * invVarCubed) * dXhat;
                    dMeanSum -= dXhat * invVarK;
                    dScaleSum += centered * invVarK * doutVal;
                    dBiasSum += doutVal;
                    dMeanVar -= twoInvNPQ * centered;
                }
            }
            dBiasArr[k] = dBiasSum;
            dScaleArr[k] = dScaleSum;
            // gradient w.r.t. the channel mean, already divided by N*PQ
            double dMeanTerm = invNPQ * (dMeanSum + (dMeanVar * dVar));
            for (int n = 0; n < N; n++) {
                int index = (n * KPQ) + (k * PQ);
                for (int pq = 0; pq < PQ; pq++, index++) {
                    double doutVal = doutArr != null ? doutArr[index] : 0.0;
                    double centered = (imageArr != null ? imageArr[index] : 0.0) - meanK;
                    dXArr[index] = (doutVal * scaleK * invVarK) + dMeanTerm + (twoInvNPQ * centered * dVar);
                }
            }
        }
        dBias.recomputeNonZeros();
        dScale.recomputeNonZeros();
        dX.recomputeNonZeros();
    }

    /**
     * Forward pass of 2D batch normalization.
     *
     * <p>In {@code "train"} phase the per-channel batch mean and (biased) variance are computed from
     * {@code image} and the running statistics are updated as
     * {@code ret = mu * running + (1 - mu) * batchStat}. In {@code "test"} phase the running
     * statistics are used as-is. The output is {@code (x - mean) * invStd * scale + bias}.
     *
     * @param image the input, {@code N x (K*PQ)}
     * @param scale per-channel scale (gamma); densified in place
     * @param bias per-channel bias (beta), {@code K x 1}; its row count defines {@code K}
     * @param runningMean current running mean; densified in place
     * @param runningVar current running variance; densified in place
     * @param phase {@code "train"} or {@code "test"} (case-insensitive)
     * @param epsilon added to the variance before taking the square root
     * @param mu weight of the previous running statistics in the update
     * @param ret dense output block for the normalized image
     * @param retRunningMean dense output for the updated running mean
     * @param retRunningVar dense output for the updated running variance
     * @param resultSaveMean dense output for the mean used; in train phase assumed zero-initialized
     *     (sums are accumulated into it)
     * @param resultSaveInvVariance dense output for {@code 1/sqrt(var + epsilon)}; in train phase
     *     assumed zero-initialized
     * @throws DMLRuntimeException if {@code phase} is neither train nor test
     */
    public static void batchNorm2D(MatrixBlock image, MatrixBlock scale, MatrixBlock bias, MatrixBlock runningMean,
            MatrixBlock runningVar, String phase, double epsilon, double mu, MatrixBlock ret,
            MatrixBlock retRunningMean, MatrixBlock retRunningVar, MatrixBlock resultSaveMean,
            MatrixBlock resultSaveInvVariance) {
        ensureDense(bias);
        double[] biasArr = bias.getDenseBlockValues();
        ensureDense(scale);
        double[] scaleArr = scale.getDenseBlockValues();
        ensureDense(runningMean);
        double[] runningMeanArr = runningMean.getDenseBlockValues();
        ensureDense(runningVar);
        double[] runningVarArr = runningVar.getDenseBlockValues();
        double[] retRunningMeanArr = retRunningMean.getDenseBlockValues();
        double[] retRunningVarArr = retRunningVar.getDenseBlockValues();
        double[] saveMeanArr = resultSaveMean.getDenseBlockValues();
        double[] saveInvVarArr = resultSaveInvVariance.getDenseBlockValues();
        int N = image.getNumRows();
        int K = bias.getNumRows();
        int PQ = image.getNumColumns() / K;
        if (phase.equalsIgnoreCase("train")) {
            // accumulate per-channel sums and sums of squares into the save arrays
            computeBiasSumAndSumSquares(image, saveMeanArr, saveInvVarArr, K, PQ);
            int NPQ = N * PQ;
            for (int k = 0; k < K; k++) {
                double mean = saveMeanArr[k] / NPQ;
                double var = (saveInvVarArr[k] / NPQ) - Math.pow(mean, 2.0);
                saveMeanArr[k] = mean;
                saveInvVarArr[k] = Math.pow(Math.sqrt(var + epsilon), -1.0);
                retRunningMeanArr[k] = (mu * (runningMeanArr != null ? runningMeanArr[k] : 0.0)) + ((1.0 - mu) * mean);
                // the running variance is blended with the batch variance (not the batch mean)
                retRunningVarArr[k] = (mu * (runningVarArr != null ? runningVarArr[k] : 0.0)) + ((1.0 - mu) * var);
            }
        } else if (phase.equalsIgnoreCase("test")) {
            copy(runningMean, retRunningMeanArr);
            copy(runningVar, retRunningVarArr);
            copy(runningMean, saveMeanArr);
            if (runningVarArr != null) {
                for (int i = 0; i < runningVarArr.length; i++) {
                    saveInvVarArr[i] = Math.pow(Math.sqrt(runningVarArr[i] + epsilon), -1.0);
                }
            } else {
                // empty running variance means all zeros
                Arrays.fill(saveInvVarArr, Math.pow(Math.sqrt(epsilon), -1.0));
            }
        } else {
            throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
        }
        double[] retArr = ret.getDenseBlockValues();
        copy(image, retArr);
        if (biasArr == null || scaleArr == null) {
            // generic path: multBias treats a null (empty) scale as zeros, addBias skips a null bias
            addBias(retArr, saveMeanArr, -1.0, N, K, PQ);
            multBias(retArr, saveInvVarArr, N, K, PQ);
            multBias(retArr, scaleArr, N, K, PQ);
            addBias(retArr, biasArr, 1.0, N, K, PQ);
        } else {
            int index = 0;
            for (int n = 0; n < N; n++) {
                for (int k = 0; k < K; k++) {
                    for (int pq = 0; pq < PQ; pq++, index++) {
                        retArr[index] = ((retArr[index] - saveMeanArr[k]) * saveInvVarArr[k] * scaleArr[k]) + biasArr[k];
                    }
                }
            }
        }
        ret.recomputeNonZeros();
        retRunningMean.recomputeNonZeros();
        retRunningVar.recomputeNonZeros();
        resultSaveMean.recomputeNonZeros();
        resultSaveInvVariance.recomputeNonZeros();
    }

    /**
     * Copies the values of {@code src} into the row-major array {@code dest}. Nothing is written
     * for an empty dense source, and only non-zeros are written for a sparse source, so
     * {@code dest} is presumably expected to be zeroed beforehand.
     */
    private static void copy(MatrixBlock src, double[] dest) {
        if (!src.isInSparseFormat()) {
            double[] srcArr = src.getDenseBlockValues();
            if (srcArr != null) {
                System.arraycopy(srcArr, 0, dest, 0, srcArr.length);
            }
            return;
        }
        SparseBlock sblock = src.getSparseBlock();
        int cols = src.getNumColumns();
        for (int r = 0; r < src.getNumRows(); r++) {
            if (sblock.isEmpty(r)) {
                continue;
            }
            int apos = sblock.pos(r);
            int alen = sblock.size(r);
            int[] aix = sblock.indexes(r);
            double[] avals = sblock.values(r);
            for (int j = apos; j < apos + alen; j++) {
                dest[(r * cols) + aix[j]] = avals[j];
            }
        }
    }

    /**
     * Adds {@code factor * bias[k]} to every element of channel {@code k} in a dense
     * {@code N x (K*PQ)} array. Does nothing if {@code bias} is {@code null}.
     *
     * @param a the dense data, modified in place
     * @param bias per-channel bias of length {@code K}, or {@code null}
     * @param factor multiplier applied to each bias value
     * @param N number of rows
     * @param K number of channels
     * @param PQ number of elements per channel
     */
    public static void addBias(double[] a, double[] bias, double factor, int N, int K, int PQ) {
        if (bias == null) {
            return;
        }
        int index = 0;
        for (int n = 0; n < N; n++) {
            for (int k = 0; k < K; k++) {
                double val = factor * bias[k];
                for (int pq = 0; pq < PQ; pq++, index++) {
                    a[index] += val;
                }
            }
        }
    }

    /**
     * Multiplies every element of channel {@code k} in a dense {@code N x (K*PQ)} array by
     * {@code bias[k]}. A {@code null} bias is treated as all zeros, so the whole array is zeroed.
     *
     * @param a the dense data, modified in place
     * @param bias per-channel multipliers of length {@code K}, or {@code null}
     * @param N number of rows
     * @param K number of channels
     * @param PQ number of elements per channel
     */
    public static void multBias(double[] a, double[] bias, int N, int K, int PQ) {
        if (bias == null) {
            Arrays.fill(a, 0.0);
            return;
        }
        int index = 0;
        for (int n = 0; n < N; n++) {
            for (int k = 0; k < K; k++) {
                double val = bias[k];
                for (int pq = 0; pq < PQ; pq++, index++) {
                    a[index] *= val;
                }
            }
        }
    }

    /**
     * Accumulates the per-channel sum into {@code sum} and the per-channel sum of squares into
     * {@code sumSquares}. Values are added to the existing array contents.
     *
     * @throws DMLRuntimeException if either array's length differs from {@code K}
     */
    private static void computeBiasSumAndSumSquares(MatrixBlock image, double[] sum, double[] sumSquares, int K, int PQ) {
        if (sum.length != K) {
            throw new DMLRuntimeException("Expected the length of array to be " + K + ", but instead is " + sum.length);
        }
        if (sumSquares.length != K) {
            throw new DMLRuntimeException("Expected the length of array to be " + K + ", but instead is " + sumSquares.length);
        }
        if (image.isInSparseFormat()) {
            SparseBlock sblock = image.getSparseBlock();
            for (int r = 0; r < image.getNumRows(); r++) {
                if (sblock.isEmpty(r)) {
                    continue;
                }
                int apos = sblock.pos(r);
                int alen = sblock.size(r);
                int[] aix = sblock.indexes(r);
                double[] avals = sblock.values(r);
                for (int j = apos; j < apos + alen; j++) {
                    int k = aix[j] / PQ;
                    sum[k] += avals[j];
                    sumSquares[k] += Math.pow(avals[j], 2.0);
                }
            }
            return;
        }
        double[] in = image.getDenseBlockValues();
        if (in == null) {
            return;
        }
        int N = image.getNumRows();
        int index = 0;
        for (int n = 0; n < N; n++) {
            for (int k = 0; k < K; k++) {
                for (int pq = 0; pq < PQ; pq++, index++) {
                    sum[k] += in[index];
                    sumSquares[k] += Math.pow(in[index], 2.0);
                }
            }
        }
    }

    /**
     * Multiplies every element of channel {@code k} by {@code bias[k]}.
     *
     * @param input the input, {@code N x (K*PQ)}
     * @param bias the bias column vector, {@code K x 1}; densified in place
     * @param outputBlock the output block; receives a copy of {@code input} that is scaled in place
     * @param numThreads recorded on the internal parameters; the computation is single-threaded
     * @throws DMLRuntimeException if {@code bias} is not a non-empty column vector or the input's
     *     column count is not a multiple of its row count
     */
    public static void biasMultiply(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) {
        int N = input.getNumRows();
        int K = bias.getNumRows();
        // K == 0 is rejected here instead of failing with an ArithmeticException on the division below
        if (K == 0 || bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
            throw new DMLRuntimeException("Incorrect inputs for bias_multiply: input[" + N + " X "
                + input.getNumColumns() + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
        }
        int PQ = input.getNumColumns() / K;
        DnnParameters params = new DnnParameters(N, PQ, -1, -1, K, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = bias;
        params.output = outputBlock;
        if (input.isEmptyBlock() || bias.isEmptyBlock()) {
            params.output.setNonZeros(0L);
            return;
        }
        outputBlock.copy(input);
        ensureDense(bias);
        double[] biasArr = bias.getDenseBlockValues();
        // branch on the format of the block actually being modified
        if (outputBlock.isInSparseFormat()) {
            SparseBlock sblock = outputBlock.getSparseBlock();
            // channels with a zero multiplier become all-zero: drop their entries entirely
            for (int k = 0; k < K; k++) {
                if (biasArr[k] == 0.0) {
                    for (int n = 0; n < N; n++) {
                        if (!sblock.isEmpty(n)) {
                            sblock.deleteIndexRange(n, k * PQ, (k + 1) * PQ);
                        }
                    }
                }
            }
            for (int n = 0; n < N; n++) {
                if (sblock.isEmpty(n)) {
                    continue;
                }
                int apos = sblock.pos(n);
                int alen = sblock.size(n);
                int[] aix = sblock.indexes(n);
                double[] avals = sblock.values(n);
                for (int j = apos; j < apos + alen; j++) {
                    double b = biasArr[aix[j] / PQ];
                    if (b != 0.0) {
                        avals[j] *= b;
                    }
                }
            }
        } else {
            multBias(outputBlock.getDenseBlockValues(), biasArr, N, K, PQ);
        }
        params.output.recomputeNonZeros();
        params.output.examSparsity();
    }

    /**
     * Runs the tasks and returns the sum of their results. The tasks run sequentially when only one
     * thread is available, otherwise on a pool of {@code min(threads, params.N)} threads.
     *
     * @throws DMLRuntimeException wrapping any failure of a task or of the pool
     */
    private static long execute(ArrayList<Callable<Long>> tasks, DnnParameters params) {
        int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        try {
            long total = 0;
            if (k == 1) {
                for (Callable<Long> task : tasks) {
                    total += task.call();
                }
            } else {
                ExecutorService pool = CommonThreadPool.get(Math.min(k, params.N));
                try {
                    for (Future<Long> future : pool.invokeAll(tasks)) {
                        total += future.get();
                    }
                } finally {
                    // release the pool even if invokeAll or get fails
                    pool.shutdown();
                }
            }
            return total;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Error while executing multi-threaded tasks", e);
        } catch (Exception e) {
            throw new DMLRuntimeException("Error while executing multi-threaded tasks", e);
        }
    }

    /** Throws a {@link DMLRuntimeException} with {@code msg} if {@code lhs != rhs}. */
    private static void checkOrThrowException(String msg, long lhs, long rhs) {
        if (lhs != rhs) {
            throw new DMLRuntimeException(msg + ":" + lhs + " != " + rhs);
        }
    }

    /** Throws a {@link DMLRuntimeException} with {@code msg} if {@code lhs != rhs1 * rhs2 * rhs3}. */
    private static void checkOrThrowException(String msg, long lhs, long rhs1, long rhs2, long rhs3) {
        if (lhs != rhs1 * rhs2 * rhs3) {
            throw new DMLRuntimeException(msg + ":" + lhs + " != (" + rhs1 + " * " + rhs2 + " * " + rhs3 + ")");
        }
    }

    /** Throws a {@link DMLRuntimeException} unless both strides in {@code params} are positive. */
    private static void checkPositiveStrides(DnnParameters params) {
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + params.stride_h + ", " + params.stride_w);
        }
    }

    /**
     * Registers the blocks on {@code params} and validates them for conv2d backward-data.
     *
     * @throws DMLRuntimeException on dimension mismatch or non-positive strides
     */
    public static void checkInputsConv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        params.input1 = filter;
        params.input2 = dout;
        params.output = outputBlock;
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input filter != number of filters in filter_shape",
            filter.getNumRows(), params.K);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input filter != channels*filter_height*filter_width in filter_shape",
            filter.getNumColumns(), params.C, params.R, params.S);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input errors != batch size in input_shape",
            dout.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input errors != expected input error channels*height*width",
            dout.getNumColumns(), params.K, params.P, params.Q);
        checkPositiveStrides(params);
    }

    /**
     * Registers the blocks on {@code params} and validates them for conv2d backward-filter.
     *
     * @throws DMLRuntimeException on dimension mismatch or non-positive strides
     */
    public static void checkInputsConv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input data != batch size in input_shape",
            input.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input data != channels*input_height*input_width in input_shape",
            input.getNumColumns(), params.C, params.H, params.W);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input errors != batch size in input_shape",
            dout.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input errors != expected input error channels*height*width",
            dout.getNumColumns(), params.K, params.P, params.Q);
        checkPositiveStrides(params);
    }

    /**
     * Registers the blocks on {@code params} and validates them for conv2d.
     *
     * @throws DMLRuntimeException on dimension mismatch or non-positive strides
     */
    public static void checkInputsConv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, DnnParameters params) {
        params.input1 = input;
        params.input2 = filter;
        params.output = outputBlock;
        checkOrThrowException("Incorrect input to conv2d: Number of rows of input filter != number of filters in filter_shape",
            filter.getNumRows(), params.K);
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input filter != channels*filter_height*filter_width in filter_shape",
            filter.getNumColumns(), params.C, params.R, params.S);
        checkOrThrowException("Incorrect input to conv2d: Number of rows of input data != batch size in input_shape",
            input.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input data != channels*input_height*input_width in input_shape",
            input.getNumColumns(), params.C, params.H, params.W);
        checkPositiveStrides(params);
    }

    /**
     * Precomputes, for every output row {@code p} and column {@code q}, the window bounds into the
     * input clipped to {@code [0, H)} / {@code [0, W)} (start inclusive, end exclusive).
     */
    private static void fillIndexesArray(DnnParameters params) {
        params.start_indexes_h = new int[params.P];
        params.end_indexes_h = new int[params.P];
        params.start_indexes_w = new int[params.Q];
        params.end_indexes_w = new int[params.Q];
        for (int p = 0, h = -params.pad_h; p < params.P; p++, h += params.stride_h) {
            params.start_indexes_h[p] = Math.max(h, 0);
            params.end_indexes_h[p] = Math.min(h + params.R, params.H);
        }
        for (int q = 0, w = -params.pad_w; q < params.Q; q++, w += params.stride_w) {
            params.start_indexes_w[q] = Math.max(w, 0);
            params.end_indexes_w[q] = Math.min(w + params.S, params.W);
        }
    }

    /** Converts {@code mb} to dense format in place if it is currently sparse. */
    private static void ensureDense(MatrixBlock mb) {
        if (mb.isInSparseFormat()) {
            mb.sparseToDense();
        }
    }
}
