package org.apache.sysds.runtime.matrix.data;

import java.util.Arrays;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNNHelper;

/* loaded from: input_file:org/apache/sysds/runtime/matrix/data/LibMatrixDNNIm2Col.class */
/**
 * im2col / col2im kernels over single images stored as rows of a {@link MatrixBlock}.
 *
 * <p>Geometry parameters follow the {@code DnnParameters} field names:
 * <ul>
 *   <li>{@code C, H, W}: the input row for image {@code r} is laid out as {@code [C, H, W]}.
 *       The dense kernels address it as {@code r*C*H*W + (c*H + h)*W + w}.</li>
 *   <li>{@code R, S}: kernel offsets added to each output position.</li>
 *   <li>{@code P, Q}: number of output positions.</li>
 * </ul>
 *
 * <p>Output layouts:
 * <ul>
 *   <li>Non-transposed im2col output is indexed as {@code CRS x PQ} (row-major).</li>
 *   <li>Transposed ({@code trans == true}) im2col output is indexed as {@code PQ x CRS}.</li>
 * </ul>
 */
public class LibMatrixDNNIm2Col {

    /**
     * Converts row {@code r} of {@code in} into its im2col form in {@code out}.
     * All geometry is taken from {@code params}.
     *
     * @param in     input block; row {@code r} holds one image
     * @param out    destination block (dense array or sparse rows, depending on {@code in})
     * @param r      row of {@code in} to convert
     * @param params convolution geometry (C, R, S, H, W, P, Q, strides, paddings)
     * @param trans  if true, write the {@code PQ x CRS} layout instead of {@code CRS x PQ}
     */
    public static void im2col(MatrixBlock in, MatrixBlock out, int r, DnnParameters params, boolean trans) {
        im2col(in, out, r, params.C, params.R, params.S, params.H, params.W, params.P, params.Q,
            params.stride_h, params.stride_w, params.pad_h, params.pad_w, trans);
    }

    /**
     * Converts row {@code r} of {@code in} into its im2col form in {@code out}, dispatching to:
     * <ul>
     *   <li>the sparse kernel when {@code in} is sparse;</li>
     *   <li>the dense stride-1/pad-0 fast path when applicable and not transposed;</li>
     *   <li>the general dense kernel otherwise.</li>
     * </ul>
     */
    public static void im2col(MatrixBlock in, MatrixBlock out, int r, int C, int R, int S, int H, int W,
            int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
        final boolean stride1Pad0 = stride_h == 1 && stride_w == 1 && pad_h == 0 && pad_w == 0;
        if (in.sparse) {
            im2colSparse(in, out, r, C, R, S, H, W, P, Q, stride_h, stride_w, pad_h, pad_w, trans);
        } else if (stride1Pad0 && !trans) {
            im2colDenseStride1Pad0(in.getDenseBlockValues(), out.getDenseBlockValues(),
                r * C * H * W, C, R, S, H, W, P, Q);
        } else {
            im2colDense(in.getDenseBlockValues(), out.getDenseBlockValues(),
                r, C, R, S, H, W, P, Q, stride_h, stride_w, pad_h, pad_w, trans);
        }
    }

    /**
     * Dense im2col for stride 1 and no padding, producing the non-transposed {@code CRS x PQ} layout.
     *
     * <p>Each output row segment of length {@code Q} is block-copied from the input row. Every cell
     * of {@code out[0 .. C*R*S*P*Q)} is therefore written, and no prior clearing is needed. Only the
     * last column of each segment is re-checked against {@code H}/{@code W} and zeroed if it falls
     * outside. NOTE(review): this assumes {@code P <= H-R+1} and {@code Q <= W-S+1} so the bulk copy
     * stays within the image row; confirm with callers.
     *
     * @param in  dense input values
     * @param out dense output values
     * @param ai  offset of the image inside {@code in} (the dispatcher passes {@code r*C*H*W})
     */
    public static void im2colDenseStride1Pad0(double[] in, double[] out, int ai, int C, int R, int S,
            int H, int W, int P, int Q) {
        final int CRS = C * R * S;
        for (int c = 0; c < CRS; c++) {
            final int wOffset = c % S;
            final int hOffset = (c / S) % R;
            final int cInput = c / R / S;
            for (int h = 0; h < P; h++) {
                final int hPadded = h + hOffset;
                final int outOffset = (c * P + h) * Q;
                final int inputOffset = ai + (cInput * H + hPadded) * W;
                System.arraycopy(in, inputOffset + wOffset, out, outOffset, Q);
                // Rewrite the last column with an explicit bounds check (zero if outside the image).
                final int w = Q - 1;
                final int wPadded = w + wOffset;
                out[outOffset + w] = (hPadded < H && wPadded < W) ? in[inputOffset + wPadded] : 0;
            }
        }
    }

    /**
     * General dense im2col with arbitrary strides and zero padding.
     *
     * <p>{@code out} is first filled with zeros. Only positions that map inside the image are then
     * written, so padded positions remain zero.
     *
     * @param in    dense input values; image {@code r} starts at {@code r*C*H*W}
     * @param out   dense output values ({@code CRS x PQ}, or {@code PQ x CRS} if {@code trans})
     * @param r     image (row) index within {@code in}
     * @param trans if true, write the transposed {@code PQ x CRS} layout
     */
    public static void im2colDense(double[] in, double[] out, int r, int C, int R, int S, int H, int W,
            int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
        Arrays.fill(out, 0);
        final int CRS = C * R * S;
        final int nOffset = r * C * H * W;
        for (int c = 0; c < CRS; c++) {
            final int wOffset = c % S;
            final int hOffset = (c / S) % R;
            final int cInput = c / R / S;
            for (int h = 0; h < P; h++) {
                final int hPadded = h * stride_h - pad_h + hOffset;
                if (hPadded < 0 || hPadded >= H) {
                    continue; // entire output segment lies in the padding and stays zero
                }
                final int outOffset = trans ? c + h * Q * CRS : (c * P + h) * Q;
                final int inputOffset = nOffset + (cInput * H + hPadded) * W;
                for (int w = 0; w < Q; w++) {
                    final int wPadded = w * stride_w - pad_w + wOffset;
                    if (wPadded >= 0 && wPadded < W) {
                        out[outOffset + (trans ? w * CRS : w)] = in[inputOffset + wPadded];
                    }
                }
            }
        }
    }

    /**
     * Sparse im2col.
     *
     * <p>Resets {@code out}, then scatters every non-zero of row {@code r} of {@code in} to all
     * im2col cells that read it, using {@code appendValue}. Finally sorts the sparse rows of
     * {@code out}. An empty input row leaves {@code out} reset.
     *
     * @param in    sparse input block
     * @param out   output block receiving appended values
     * @param r     row of {@code in} to convert
     * @param trans if true, append into the {@code PQ x CRS} layout
     */
    public static void im2colSparse(MatrixBlock in, MatrixBlock out, int r, int C, int R, int S, int H,
            int W, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
        out.reset();
        final SparseBlock sblock = in.sparseBlock;
        if (sblock.isEmpty(r)) {
            return;
        }
        final int apos = sblock.pos(r);
        final int alen = sblock.size(r);
        final int[] aix = sblock.indexes(r);
        final double[] avals = sblock.values(r);
        // Specialized scatter when stride is 1, there is no padding, W == S and Q == 1.
        final boolean simple = stride_h == 1 && stride_w == 1 && pad_h == 0 && pad_w == 0 && W == S && Q == 1;
        final int RS = R * S;
        LibMatrixDNNHelper.CellIndex3 ix = new LibMatrixDNNHelper.CellIndex3();
        for (int j = apos; j < apos + alen; j++) {
            // Presumably splits the flat column index into (c, h, w) for dims [*, H, W]; see helper.
            ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], H, W, ix);
            if (simple) {
                appendInputValueToIm2colOutputSimple(out, ix.ix1, ix.ix2, ix.ix3, avals[j], R, S, RS, P, trans);
            } else {
                appendInputValueToIm2colOutput(out, ix.ix1, ix.ix2, ix.ix3, avals[j], R, S, RS, P, Q,
                    stride_h, stride_w, pad_h, pad_w, trans);
            }
        }
        out.sortSparseRows();
    }

    /**
     * Appends one input cell {@code (cInput, hInput, wInput)} with the given value to every im2col
     * output cell that reads it, for general strides and paddings.
     *
     * <p>An output position p reads kernel row r of this cell when
     * {@code hInput = p*stride_h - pad_h + r} (and likewise for q, s and the width dimension).
     * The loops below visit exactly the stride-aligned kernel offsets inside the valid range.
     */
    private static void appendInputValueToIm2colOutput(MatrixBlock output, int cInput, int hInput, int wInput,
            double value, int R, int S, int RS, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w,
            boolean trans) {
        final int hOffsetMin = Math.max(0, hInput + pad_h - P * stride_h + 1);
        final int hOffsetMax = Math.min(R - 1, hInput + pad_h);
        final int wOffsetMin = Math.max(0, wInput + pad_w - Q * stride_w + 1);
        final int wOffsetMax = Math.min(S - 1, wInput + pad_w);
        // Advance the lower bounds to the first offset congruent with the stride. The min(...) pushes
        // the start past the max when the range is empty, so the loops simply do not execute.
        final int hStart = hOffsetMin
            + Math.min((hInput - hOffsetMin + pad_h) % stride_h, hOffsetMax - hOffsetMin + 1);
        final int wStart = wOffsetMin
            + Math.min((wInput - wOffsetMin + pad_w) % stride_w, wOffsetMax - wOffsetMin + 1);

        for (int hOffset = hStart, crsRowBase = cInput * RS + hStart * S; hOffset <= hOffsetMax;
                hOffset += stride_h, crsRowBase += stride_h * S) {
            final int pqBase = ((hInput - hOffset + pad_h) / stride_h) * Q;
            for (int wOffset = wStart, wDelta = wInput - wStart + pad_w; wOffset <= wOffsetMax;
                    wOffset += stride_w, wDelta -= stride_w) {
                final int crsIndex = crsRowBase + wOffset;
                final int pqIndex = pqBase + wDelta / stride_w;
                if (trans) {
                    output.appendValue(pqIndex, crsIndex, value);
                } else {
                    output.appendValue(crsIndex, pqIndex, value);
                }
            }
        }
    }

    /**
     * Specialized scatter of one input cell {@code (c, h, w)} for stride 1, no padding,
     * {@code W == S} and {@code Q == 1}. Each valid kernel row contributes exactly one output cell.
     */
    private static void appendInputValueToIm2colOutputSimple(MatrixBlock output, int c, int h, int w,
            double value, int R, int S, int RS, int P, boolean trans) {
        final int rMin = Math.max(0, h - P + 1);
        final int rMax = Math.min(R - 1, h);
        int cix = c * RS + w + rMin * S;
        for (int p = h - rMin; p >= h - rMax; p--) {
            output.appendValue(trans ? p : cix, trans ? cix : p, value);
            cix += S;
        }
    }

    /**
     * col2im for one image.
     *
     * <p>Accumulates ({@code +=}) the {@code PQ x CRS} block {@code input} into image
     * {@code outputN} of {@code params.output}. The output starts at {@code outputN*C*H*W} and is
     * not cleared here. Contributions that map into the padding are dropped.
     *
     * @param outputN image index within {@code params.output}
     * @param input   block of shape {@code (P*Q) x (C*R*S)}
     * @param params  geometry; {@code params.output} must be dense
     * @throws DMLRuntimeException if {@code input} has the wrong dimensions, {@code params.output}
     *         is sparse, or a row index decomposes to a non-zero leading tensor index
     */
    public static void col2imOverSingleImage(int outputN, MatrixBlock input, DnnParameters params) {
        if (input.rlen != params.P * params.Q || input.clen != params.C * params.R * params.S) {
            throw new DMLRuntimeException("Incorrect input dimensions");
        }
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Only dense output is implemented");
        }
        final double[] outputArray = params.output.getDenseBlockValues();
        if (!input.isInSparseFormat()) {
            col2IMDenseInput(0, outputN, input.getDenseBlockValues(), outputArray, params);
            return;
        }
        if (input.isEmptyBlock()) {
            return;
        }
        final int outOffset = outputN * params.C * params.H * params.W;
        final int HW = params.H * params.W;
        LibMatrixDNNHelper.CellIndex3 ix = new LibMatrixDNNHelper.CellIndex3();
        final SparseBlock sblock = input.sparseBlock;
        for (int i = 0; i < input.getNumRows(); i++) {
            if (sblock.isEmpty(i)) {
                continue;
            }
            // Row i is an output position (p, q); presumably ix1 must be 0 because i < P*Q.
            ix = LibMatrixDNNHelper.computeTensorIndexes(i, params.P, params.Q, ix);
            final int hBase = ix.ix2 * params.stride_h - params.pad_h;
            final int wBase = ix.ix3 * params.stride_w - params.pad_w;
            if (ix.ix1 != 0) {
                throw new DMLRuntimeException("Incorrect tensor indexes: " + ix + ", " + params.P + " " + params.Q);
            }
            final int apos = sblock.pos(i);
            final int alen = sblock.size(i);
            final int[] aix = sblock.indexes(i);
            final double[] avals = sblock.values(i);
            for (int j = apos; j < apos + alen; j++) {
                // Column j is a (c, r, s) kernel cell.
                ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], params.R, params.S, ix);
                final int h = hBase + ix.ix2;
                final int w = wBase + ix.ix3;
                if (h >= 0 && h < params.H && w >= 0 && w < params.W) {
                    outputArray[outOffset + ix.ix1 * HW + h * params.W + w] += avals[j];
                }
            }
        }
    }

    /**
     * Dense-input col2im.
     *
     * <p>Accumulates image {@code inputN} of the dense {@code PQ x CRS} array {@code inputArray}
     * into image {@code outputN} of {@code outputArray}. Kernel ranges are clipped per output
     * position so that no bounds checks are needed in the inner loops.
     */
    private static void col2IMDenseInput(int inputN, int outputN, double[] inputArray, double[] outputArray,
            DnnParameters params) {
        final int outputNOffset = outputN * params.C * params.H * params.W;
        final int HW = params.H * params.W;
        final int inputNPQ = inputN * params.P * params.Q;
        final int CRS = params.C * params.R * params.S;
        final int RS = params.R * params.S;
        for (int p = 0; p < params.P; p++) {
            // h = hOffset + r must satisfy 0 <= h < H, i.e. max(0, -hOffset) <= r < min(R, H - hOffset).
            final int hOffset = p * params.stride_h - params.pad_h;
            final int rStart = Math.max(0, -hOffset);
            final int rEnd = Math.min(params.R, params.H - hOffset);
            for (int q = 0; q < params.Q; q++) {
                final int wOffset = q * params.stride_w - params.pad_w;
                final int sStart = Math.max(0, -wOffset);
                final int sEnd = Math.min(params.S, params.W - wOffset);
                final int rowOffset = (inputNPQ + p * params.Q + q) * CRS;
                for (int c = 0; c < params.C; c++) {
                    final int outOffset = outputNOffset + c * HW;
                    final int inOffset = rowOffset + c * RS;
                    for (int r = rStart; r < rEnd; r++) {
                        for (int s = sStart; s < sEnd; s++) {
                            outputArray[outOffset + (hOffset + r) * params.W + wOffset + s] +=
                                inputArray[inOffset + r * params.S + s];
                        }
                    }
                }
            }
        }
    }

    /**
     * Pre-allocates every sparse row of {@code out} when {@code in} is sparse; otherwise does nothing.
     *
     * <p>Each row's capacity is {@code max(min(ceil(4 * sparsity(in) * out.clen), out.clen), 16)}.
     * NOTE(review): the 4x factor over average sparsity is presumably headroom for skewed rows.
     */
    public static void preallocateSparseOutput(MatrixBlock in, MatrixBlock out) {
        if (!in.sparse) {
            return;
        }
        final int estnnz = (int) Math.ceil(4.0d * in.getSparsity() * out.clen);
        final int capacity = Math.max(Math.min(estnnz, out.clen), 16);
        for (int r = 0; r < out.rlen; r++) {
            out.getSparseBlock().allocate(r, capacity);
        }
    }
}
