package org.nd4j.linalg.convolution;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static convolution helpers for ND4J arrays.
 *
 * <p>{@link #im2col} and {@link #col2im} move 2-d patches between a 4-d image array and a 6-d
 * column array using strided interval views. The {@code conv2d}/{@code convn} overloads simply
 * delegate to the backend returned by {@link Nd4j#getConvolution()}. This class is a
 * non-instantiable utility holder.
 */
public final class Convolution {
    // NOTE(review): not referenced by any method in this class.
    private static final Logger log = LoggerFactory.getLogger(Convolution.class);

    /**
     * Convolution mode forwarded unchanged to the backend's {@code conv2d}/{@code convn}.
     * The names presumably follow the usual full/valid/same output-size conventions; the
     * backend defines the exact semantics.
     */
    public enum Type {
        FULL,
        VALID,
        SAME
    }

    private Convolution() {
    }

    /**
     * Folds a 6-d column array back into a 4-d image. This is a convenience overload of
     * {@link #col2im(INDArray, int, int, int, int, int, int)}.
     *
     * @param col     column array of shape {@code [n, c, kh, kw, outH, outW]}
     * @param stride  {@code {sy, sx}}; only the first two entries are read
     * @param padding {@code {ph, pw}}; only the first two entries are read
     * @param h       height of the image to reconstruct
     * @param w       width of the image to reconstruct
     * @return array of shape {@code [n, c, h, w]}
     * @throws IllegalArgumentException if {@code stride} or {@code padding} has fewer than two
     *     entries, or {@code col} is not rank 6
     */
    public static INDArray col2im(INDArray col, int[] stride, int[] padding, int h, int w) {
        requirePair(stride, "stride");
        requirePair(padding, "padding");
        return col2im(col, stride[0], stride[1], padding[0], padding[1], h, w);
    }

    /**
     * Folds a 6-d column array back into a 4-d image, summing values where patches overlap.
     *
     * @param col column array of shape {@code [n, c, kh, kw, outH, outW]}
     * @param sy  stride along axis 2
     * @param sx  stride along axis 3
     * @param ph  padding that was applied before axis 2; cropped off the result
     * @param pw  padding that was applied before axis 3; cropped off the result
     * @param h   height of the image to reconstruct
     * @param w   width of the image to reconstruct
     * @return array of shape {@code [n, c, h, w]}
     * @throws IllegalArgumentException if {@code col} is not rank 6
     */
    public static INDArray col2im(INDArray col, int sy, int sx, int ph, int pw, int h, int w) {
        requireRank(col, 6, "col");
        int n = col.size(0);
        int c = col.size(1);
        int kh = col.size(2);
        int kw = col.size(3);
        int outH = col.size(4);
        int outW = col.size(5);

        // Padded canvas. The trailing (sy - 1) rows and (sx - 1) columns match the extra
        // trailing padding that im2col adds.
        INDArray img = Nd4j.create(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1);
        for (int ky = 0; ky < kh; ky++) {
            int rowEnd = ky + sy * outH;
            for (int kx = 0; kx < kw; kx++) {
                int colEnd = kx + sx * outW;
                INDArrayIndex[] window = {
                        NDArrayIndex.all(),
                        NDArrayIndex.all(),
                        NDArrayIndex.interval(ky, sy, rowEnd),
                        NDArrayIndex.interval(kx, sx, colEnd)};
                INDArray acc = img.get(window);
                // addi accumulates, so pixels covered by several overlapping windows are summed.
                acc.addi(col.get(NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.point(ky), NDArrayIndex.point(kx),
                        NDArrayIndex.all(), NDArrayIndex.all()));
                // Write back explicitly in case get() returned a copy rather than a view.
                img.put(window, acc);
            }
        }
        // Crop the padding away again.
        return img.get(NDArrayIndex.all(), NDArrayIndex.all(),
                NDArrayIndex.interval(ph, ph + h), NDArrayIndex.interval(pw, pw + w));
    }

    /**
     * Unfolds a 4-d image into a 6-d column array. It uses a padding value of {@code 0} and
     * {@code coverAll = false}.
     *
     * @param img     array of shape {@code [n, c, h, w]}
     * @param kernel  {@code {kh, kw}}; only the first two entries are read
     * @param stride  {@code {sy, sx}}; only the first two entries are read
     * @param padding {@code {ph, pw}}; only the first two entries are read
     * @return array of shape {@code [n, c, kh, kw, outH, outW]}
     * @throws IllegalArgumentException if any of the int arrays has fewer than two entries,
     *     or {@code img} is not rank 4
     */
    public static INDArray im2col(INDArray img, int[] kernel, int[] stride, int[] padding) {
        requirePair(kernel, "kernel");
        requirePair(stride, "stride");
        requirePair(padding, "padding");
        return im2col(img, kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false);
    }

    /**
     * Unfolds a 4-d image into a 6-d column array. Element {@code [.., ky, kx, oy, ox]} of the
     * result holds padded pixel {@code (ky + sy * oy, kx + sx * ox)}.
     *
     * @param img      array of shape {@code [n, c, h, w]}; axes 2 and 3 are padded and strided
     * @param kh       kernel extent along axis 2
     * @param kw       kernel extent along axis 3
     * @param sy       stride along axis 2
     * @param sx       stride along axis 3
     * @param ph       padding before axis 2
     * @param pw       padding before axis 3
     * @param pval     intended padding value; NOTE(review): currently never read (see below)
     * @param coverAll if true, output sizes round up (see {@link #outSize})
     * @return array of shape {@code [n, c, kh, kw, outH, outW]}
     * @throws IllegalArgumentException if {@code img} is not rank 4
     */
    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int pval,
            boolean coverAll) {
        requireRank(img, 4, "img");
        int n = img.size(0);
        int c = img.size(1);
        int h = img.size(2);
        int w = img.size(3);
        int outH = outSize(h, kh, sy, ph, coverAll);
        int outW = outSize(w, kw, sx, pw, coverAll);

        // Pad only axes 2 and 3: ph/pw before them, and ph+sy-1 / pw+sx-1 after them. The extra
        // (stride - 1) gives room for the last strided window when coverAll rounds the output up.
        // NOTE(review): pval is not used; the border comes from Nd4j.pad in CONSTANT mode.
        // Confirm whether pval should be honored.
        int[][] padSpec = {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}};
        INDArray padded = Nd4j.pad(img, padSpec, Nd4j.PadMode.CONSTANT);

        INDArray ret = Nd4j.create(n, c, kh, kw, outH, outW);
        for (int ky = 0; ky < kh; ky++) {
            int rowEnd = ky + sy * outH;
            for (int kx = 0; kx < kw; kx++) {
                int colEnd = kx + sx * outW;
                INDArray patch = padded.get(NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.interval(ky, sy, rowEnd), NDArrayIndex.interval(kx, sx, colEnd));
                ret.put(new INDArrayIndex[] {
                        NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.point(ky), NDArrayIndex.point(kx),
                        NDArrayIndex.all(), NDArrayIndex.all()}, patch);
            }
        }
        return ret;
    }

    /**
     * Computes the number of window positions along one axis.
     *
     * @param size     input extent along the axis
     * @param k        kernel extent
     * @param s        stride; must be non-zero (a zero stride throws {@link ArithmeticException})
     * @param p        padding applied on each side
     * @param coverAll if true, a final partially covered window is also counted
     * @return {@code (size + 2p - k) / s + 1}, or {@code (size + 2p - k + s - 1) / s + 1}
     *     when {@code coverAll} is set (Java integer division)
     */
    public static int outSize(int size, int k, int s, int p, boolean coverAll) {
        int span = size + 2 * p - k;
        if (coverAll) {
            // Adding s - 1 before dividing turns the floor division into a ceiling for span >= 0.
            span += s - 1;
        }
        return span / s + 1;
    }

    /**
     * Delegates to {@code Nd4j.getConvolution().conv2d(input, kernel, type)}.
     *
     * @param input  input array
     * @param kernel kernel array
     * @param type   convolution mode forwarded to the backend
     * @return the backend's result
     */
    public static INDArray conv2d(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().conv2d(input, kernel, type);
    }

    /**
     * Complex-valued variant. It delegates to
     * {@code Nd4j.getConvolution().conv2d(input, kernel, type)}.
     *
     * @param input  complex input array
     * @param kernel complex kernel array
     * @param type   convolution mode forwarded to the backend
     * @return the backend's result
     */
    public static INDArray conv2d(IComplexNDArray input, IComplexNDArray kernel, Type type) {
        return Nd4j.getConvolution().conv2d(input, kernel, type);
    }

    /**
     * Delegates to {@code Nd4j.getConvolution().convn(input, kernel, type, axes)}.
     *
     * @param input  input array
     * @param kernel kernel array
     * @param type   convolution mode forwarded to the backend
     * @param axes   presumably the axes to convolve over; interpreted by the backend
     * @return the backend's result
     */
    public static INDArray convn(INDArray input, INDArray kernel, Type type, int[] axes) {
        return Nd4j.getConvolution().convn(input, kernel, type, axes);
    }

    /**
     * Complex-valued variant. It delegates to
     * {@code Nd4j.getConvolution().convn(input, kernel, type, axes)}.
     *
     * @param input  complex input array
     * @param kernel complex kernel array
     * @param type   convolution mode forwarded to the backend
     * @param axes   presumably the axes to convolve over; interpreted by the backend
     * @return the backend's result
     */
    public static IComplexNDArray convn(IComplexNDArray input, IComplexNDArray kernel, Type type, int[] axes) {
        return Nd4j.getConvolution().convn(input, kernel, type, axes);
    }

    /**
     * Delegates to {@code Nd4j.getConvolution().convn(input, kernel, type)}.
     *
     * @param input  input array
     * @param kernel kernel array
     * @param type   convolution mode forwarded to the backend
     * @return the backend's result
     */
    public static INDArray convn(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().convn(input, kernel, type);
    }

    /**
     * Complex-valued variant. It delegates to
     * {@code Nd4j.getConvolution().convn(input, kernel, type)}.
     *
     * @param input  complex input array
     * @param kernel complex kernel array
     * @param type   convolution mode forwarded to the backend
     * @return the backend's result
     */
    public static IComplexNDArray convn(IComplexNDArray input, IComplexNDArray kernel, Type type) {
        return Nd4j.getConvolution().convn(input, kernel, type);
    }

    /**
     * Checks that {@code values} has at least two entries and throws otherwise. A null array
     * still fails with {@link NullPointerException}.
     */
    private static void requirePair(int[] values, String name) {
        if (values.length < 2) {
            throw new IllegalArgumentException(
                    name + " must contain at least 2 values but has " + values.length);
        }
    }

    /**
     * Checks that {@code arr} has exactly {@code expectedRank} dimensions and throws otherwise.
     * A null array still fails with {@link NullPointerException}.
     */
    private static void requireRank(INDArray arr, int expectedRank, String name) {
        int rank = arr.rank();
        if (rank != expectedRank) {
            throw new IllegalArgumentException(
                    name + " must have rank " + expectedRank + " but has rank " + rank);
        }
    }
}
