package org.deeplearning4j.util;

import java.util.Arrays;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/util/ConvolutionUtils.class */
public class ConvolutionUtils {

    private ConvolutionUtils() {
        // Static utility class; not instantiable.
    }

    /**
     * Computes the spatial output size of a 2D convolution / subsampling operation.
     *
     * <p>Dimensions 2 and 3 of {@code inputData} are read as input height and width, i.e. the input is
     * expected in [numExamples, inputDepth, inputHeight, inputWidth] layout. Index 0 of the
     * {@code kernel}/{@code strides}/{@code padding} arrays refers to height, index 1 to width.
     *
     * <ul>
     *   <li>{@link ConvolutionMode#Same}: output = ceil(input / stride); kernel and padding are not used.</li>
     *   <li>{@link ConvolutionMode#Strict}: output = (input - kernel + 2*padding) / stride + 1, which must
     *       divide exactly, otherwise an exception is thrown.</li>
     *   <li>Any other mode: the same formula as Strict, with the division truncated.</li>
     * </ul>
     *
     * @param inputData       input activations; only {@code size(2)} and {@code size(3)} are read
     * @param kernel          kernel size as [height, width]
     * @param strides         strides as [height, width]; must be positive
     * @param padding         padding as [height, width]; ignored in Same mode
     * @param convolutionMode convolution mode determining how the output size is derived
     * @return output size as [height, width]
     * @throws DL4JInvalidConfigException if a stride is not positive, or if in Strict mode the output
     *                                    size is not an integer
     * @throws DL4JInvalidInputException  if, outside Same mode, a kernel dimension does not satisfy
     *                                    0 &lt; kernel &lt;= input + 2 * padding
     */
    public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding,
                    ConvolutionMode convolutionMode) {
        int inH = inputData.size(2);
        int inW = inputData.size(3);

        // Validate strides first: a zero stride would otherwise surface as an ArithmeticException from
        // the divisions below, and a negative stride would produce a meaningless (negative) size.
        if (strides[0] <= 0 || strides[1] <= 0) {
            throw new DL4JInvalidConfigException(
                            "Invalid configuration: strides must be positive (> 0) for all dimensions. Got: "
                                            + Arrays.toString(strides)
                                            + getCommonErrorMsg(inputData, kernel, strides, padding));
        }

        if (convolutionMode != ConvolutionMode.Same) {
            checkKernelFitsInput(0, "height", inH, inputData, kernel, strides, padding);
            checkKernelFitsInput(1, "width", inW, inputData, kernel, strides, padding);
        }

        if (convolutionMode == ConvolutionMode.Same) {
            return new int[] {ceilDiv(inH, strides[0]), ceilDiv(inW, strides[1])};
        }

        if (convolutionMode == ConvolutionMode.Strict) {
            checkStrictDivisible(0, "height", inH, inputData, kernel, strides, padding);
            checkStrictDivisible(1, "width", inW, inputData, kernel, strides, padding);
        }

        // The numerators are >= 0 after checkKernelFitsInput, so integer division here is a floor.
        return new int[] {(inH - kernel[0] + 2 * padding[0]) / strides[0] + 1,
                        (inW - kernel[1] + 2 * padding[1]) / strides[1] + 1};
    }

    /**
     * Returns ceil(numerator / denominator), assuming {@code denominator > 0}.
     */
    private static int ceilDiv(int numerator, int denominator) {
        // Divide in floating point: int/int would already floor, making Math.ceil a no-op.
        return (int) Math.ceil(numerator / (double) denominator);
    }

    /**
     * Throws if {@code kernel[dim]} is not in the range (0, inSize + 2 * padding[dim]].
     *
     * @param dim     0 for height, 1 for width
     * @param dimName human-readable dimension name used in the error message
     * @param inSize  input size along {@code dim}
     */
    private static void checkKernelFitsInput(int dim, String dimName, int inSize, INDArray inputData,
                    int[] kernel, int[] strides, int[] padding) {
        int maxKernel = inSize + 2 * padding[dim];
        if (kernel[dim] <= 0 || kernel[dim] > maxKernel) {
            throw new DL4JInvalidInputException("Invalid input data or configuration: kernel " + dimName
                            + " and input " + dimName + " must satisfy 0 < kernel " + dimName + " <= input "
                            + dimName + " + 2 * padding " + dimName + ". \nGot kernel " + dimName + " = "
                            + kernel[dim] + ", input " + dimName + " = " + inSize + " and padding " + dimName
                            + " = " + padding[dim] + " which do not satisfy 0 < " + kernel[dim] + " <= "
                            + maxKernel + getCommonErrorMsg(inputData, kernel, strides, padding));
        }
    }

    /**
     * Throws if, in Strict mode, (inSize - kernel + 2*padding) is not divisible by the stride along
     * {@code dim}, i.e. the output size would not be an integer.
     *
     * @param dim     0 for height, 1 for width
     * @param dimName human-readable dimension name used in the error message
     * @param inSize  input size along {@code dim}
     */
    private static void checkStrictDivisible(int dim, String dimName, int inSize, INDArray inputData,
                    int[] kernel, int[] strides, int[] padding) {
        int span = inSize - kernel[dim] + 2 * padding[dim];
        if (span % strides[dim] == 0) {
            return;
        }
        // Floating-point division so the message shows the actual fractional output size.
        double exactOut = span / (double) strides[dim] + 1.0;
        String formatted = String.format("%.2f", exactOut);
        throw new DL4JInvalidConfigException("Invalid input data or configuration: Combination of kernel size,"
                        + " stride and padding are not valid for given input " + dimName
                        + ", using ConvolutionMode.Strict\nConvolutionMode.Strict requires: output " + dimName
                        + " = (input " + dimName + " - kernelSize + 2*padding)/stride + 1 to be an integer. Got: ("
                        + inSize + " - " + kernel[dim] + " + 2*" + padding[dim] + ")/" + strides[dim] + " + 1 = "
                        + formatted
                        + "\nSee \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/"
                        + " and ConvolutionMode enumeration Javadoc.\nTo truncate/crop the input, such that output "
                        + dimName + " = floor(" + formatted + ") = " + (int) exactOut
                        + ", use ConvolutionMode.Truncate.\nAlternatively use ConvolutionMode.Same, which will use"
                        + " padding to give an output " + dimName + " of ceil(" + inSize + "/" + strides[dim]
                        + ")=" + ceilDiv(inSize, strides[dim])
                        + getCommonErrorMsg(inputData, kernel, strides, padding));
    }

    /**
     * Builds the shared suffix for error messages: the input shape plus kernel, strides and padding.
     */
    private static String getCommonErrorMsg(INDArray inputData, int[] kernel, int[] strides, int[] padding) {
        return "\nInput size: [numExamples,inputDepth,inputHeight,inputWidth]=" + Arrays.toString(inputData.shape())
                        + ", kernel=" + Arrays.toString(kernel) + ", strides=" + Arrays.toString(strides)
                        + ", padding=" + Arrays.toString(padding);
    }

    /**
     * Computes the top/left padding for Same mode, per dimension:
     * {@code ((outSize - 1) * stride + kernel - inSize) / 2} (Java integer division).
     *
     * <p>No range check is performed; inconsistent arguments can yield negative values.
     *
     * @param outSize output size as [height, width]
     * @param inSize  input size as [height, width]
     * @param kernel  kernel size as [height, width]
     * @param strides strides as [height, width]
     * @return padding as [top, left]
     */
    public static int[] getSameModeTopLeftPadding(int[] outSize, int[] inSize, int[] kernel, int[] strides) {
        int totalPadH = (outSize[0] - 1) * strides[0] + kernel[0] - inSize[0];
        int totalPadW = (outSize[1] - 1) * strides[1] + kernel[1] - inSize[1];
        return new int[] {totalPadH / 2, totalPadW / 2};
    }

    /**
     * Computes the bottom/right padding for Same mode, per dimension:
     * {@code ((outSize - 1) * stride + kernel - inSize + 1) / 2} (Java integer division).
     * When the total padding is odd, this side receives the extra row/column relative to
     * {@link #getSameModeTopLeftPadding}.
     *
     * <p>No range check is performed; inconsistent arguments can yield negative values.
     *
     * @param outSize output size as [height, width]
     * @param inSize  input size as [height, width]
     * @param kernel  kernel size as [height, width]
     * @param strides strides as [height, width]
     * @return padding as [bottom, right]
     */
    public static int[] getSameModeBottomRightPadding(int[] outSize, int[] inSize, int[] kernel, int[] strides) {
        int totalPadH = (outSize[0] - 1) * strides[0] + kernel[0] - inSize[0];
        int totalPadW = (outSize[1] - 1) * strides[1] + kernel[1] - inSize[1];
        return new int[] {(totalPadH + 1) / 2, (totalPadW + 1) / 2};
    }

    /**
     * Returns the result of {@link #getHeightAndWidth(int[])} applied to the kernel size of the
     * configuration's layer.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}; any other layer type
     *             causes a {@link ClassCastException}
     */
    public static int[] getHeightAndWidth(NeuralNetConfiguration conf) {
        return getHeightAndWidth(((ConvolutionLayer) conf.getLayer()).getKernelSize());
    }

    /**
     * Returns the number of output feature maps ({@code nOut}) of the configuration's layer.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}; any other layer type
     *             causes a {@link ClassCastException}
     */
    public static int numFeatureMap(NeuralNetConfiguration conf) {
        return ((ConvolutionLayer) conf.getLayer()).getNOut();
    }

    /**
     * Returns the last two entries of {@code shape} as {@code [shape[n-1], shape[n-2]]}.
     *
     * <p>NOTE(review): despite the method name, the last entry comes first. For a trailing
     * [..., height, width] layout this is [width, height]. The order is kept for compatibility with
     * existing callers; confirm the intended semantics before relying on it.
     *
     * @param shape array of length at least 2
     * @throws IllegalArgumentException if {@code shape.length < 2}
     */
    public static int[] getHeightAndWidth(int[] shape) {
        int n = shape.length;
        if (n < 2) {
            throw new IllegalArgumentException("No width and height able to be found: array must be at least length 2");
        }
        return new int[] {shape[n - 1], shape[n - 2]};
    }

    /**
     * Returns {@code shape[1]} for arrays of length 4 or more, otherwise 1.
     * Presumably {@code shape[1]} is the channel count of an NCHW-style shape; confirm against callers.
     */
    public static int numChannels(int[] shape) {
        return shape.length < 4 ? 1 : shape[1];
    }

    /**
     * Validates a 2D kernel/stride/padding configuration.
     *
     * @param kernel  kernel size; must be non-null, length 2, all values &gt; 0
     * @param strides strides; must be non-null, length 2, all values &gt; 0
     * @param padding padding; must be non-null, length 2, all values &gt;= 0
     * @throws IllegalStateException on the first violated constraint (lengths are checked before values)
     */
    public static void validateCnnKernelStridePadding(int[] kernel, int[] strides, int[] padding) {
        requireLengthTwo(kernel, "Invalid kernel size");
        requireLengthTwo(strides, "Invalid stride configuration");
        requireLengthTwo(padding, "Invalid padding configuration");

        if (kernel[0] <= 0 || kernel[1] <= 0) {
            throw new IllegalStateException(
                            "Invalid kernel size: values must be positive (> 0) for all dimensions. Got: "
                                            + Arrays.toString(kernel));
        }
        if (strides[0] <= 0 || strides[1] <= 0) {
            throw new IllegalStateException(
                            "Invalid stride configuration: values must be positive (> 0) for all dimensions. Got: "
                                            + Arrays.toString(strides));
        }
        if (padding[0] < 0 || padding[1] < 0) {
            throw new IllegalStateException(
                            "Invalid padding configuration: values must be >= 0 for all dimensions. Got: "
                                            + Arrays.toString(padding));
        }
    }

    /**
     * Throws {@link IllegalStateException} unless {@code values} is non-null with length 2.
     * {@code Arrays.toString(null)} yields "null", so null arrays are reported as "got null".
     */
    private static void requireLengthTwo(int[] values, String errorPrefix) {
        if (values == null || values.length != 2) {
            throw new IllegalStateException(errorPrefix + ": expected int[] of length 2, got " + Arrays.toString(values));
        }
    }
}
