package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.class */
public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.ConvolutionLayer> {
    /**
     * im2col buffer of the most recent input.
     *
     * <p>It is filled by {@link #activate(boolean)} or injected via {@link #setCol(INDArray)}.
     * It is consumed by {@link #preOutput(boolean)} and {@link #backpropGradient(INDArray)}.
     */
    protected INDArray col;

    public ConvolutionLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * @param conf  layer configuration
     * @param input forwarded unchanged to BaseLayer's (configuration, INDArray) constructor;
     *              presumably the initial layer input (NOTE(review): confirm against BaseLayer)
     */
    public ConvolutionLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /** Replaces the cached im2col buffer used by the forward and backward passes. */
    public void setCol(INDArray col) {
        this.col = col;
    }

    /**
     * Returns the L2 penalty {@code 0.5 * l2 * sum(W^2)} over the weight parameter {@code "W"}.
     *
     * @return 0 when regularization is disabled or the L2 coefficient is not positive
     */
    @Override
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getL2() <= 0.0d) {
            return 0.0d;
        }
        // Integer.MAX_VALUE is passed as the reduction dimension; the scalar is then read with
        // getDouble(0), i.e. this is presumably ND4J's "reduce over everything" convention.
        double sumOfSquares = Transforms.pow(getParam("W"), 2).sum(new int[] {Integer.MAX_VALUE}).getDouble(0);
        return 0.5d * this.conf.getL2() * sumOfSquares;
    }

    /**
     * Returns the L1 penalty {@code l1 * sum(|W|)} over the weight parameter {@code "W"}.
     *
     * @return 0 when regularization is disabled or the L1 coefficient is not positive
     */
    @Override
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getL1() <= 0.0d) {
            return 0.0d;
        }
        double sumOfAbs = Transforms.abs(getParam("W")).sum(new int[] {Integer.MAX_VALUE}).getDouble(0);
        return this.conf.getL1() * sumOfAbs;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.CONVOLUTIONAL;
    }

    /**
     * Scales the incoming error signal by the activation derivative at the pre-activation output.
     *
     * <p>NOTE(review): this recomputes {@code preOutput(true)}. With drop-connect enabled, that
     * draws a fresh weight mask rather than reusing the forward-pass mask. Confirm this is intended.
     *
     * @param epsilon error signal from the layer above; multiplied in place and returned
     * @return {@code epsilon}, multiplied element-wise by the activation derivative
     * @throws IllegalStateException if the derivative's shape differs from the pre-output's shape
     */
    public INDArray calculateDelta(INDArray epsilon) {
        INDArray z = preOutput(true);
        INDArray activationDerivative = Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(conf().getLayer().getActivationFunction(), z).derivative());
        if (!Arrays.equals(z.shape(), activationDerivative.shape())) {
            throw new IllegalStateException("Shapes must be same");
        }
        return epsilon.muli(activationDerivative);
    }

    /**
     * Computes the parameter gradients and the error signal to pass to the previous layer.
     *
     * <p>The gradients are:
     * <ul>
     *   <li>{@code "b"}: delta summed over axes 0, 2 and 3.</li>
     *   <li>{@code "W"}: delta contracted with the im2col buffer.</li>
     * </ul>
     * The returned input gradient is the weights contracted with delta, folded back via col2im.
     *
     * @param epsilon error signal from the layer above (modified in place by calculateDelta)
     * @return pair of (parameter gradients, error signal with respect to this layer's input)
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        // Sizes of the input's last two axes, needed by col2im to rebuild the input layout.
        // Presumably height/width of an NCHW input — confirm against the caller.
        int inputHeight = input().size(-2);
        int inputWidth = input().size(-1);
        INDArray weights = getParam("W");
        INDArray delta = calculateDelta(epsilon);

        DefaultGradient gradient = new DefaultGradient();
        gradient.setGradientFor("b", delta.sum(new int[] {0, 2, 3}));
        gradient.setGradientFor("W", Nd4j.tensorMmul(delta, this.col, new int[][] {{0, 2, 3}, {0, 4, 5}}));

        // Contract W's axis 0 with delta's axis 1. Rolling axis 3 to the front lines the result
        // up with col's layout (batch axis first) before col2im.
        INDArray colGradient = Nd4j.rollAxis(Nd4j.tensorMmul(weights, delta, new int[][] {{0}, {1}}), 3);
        INDArray inputGradient = Convolution.col2im(
                colGradient, layerConf().getStride(), layerConf().getPadding(), inputHeight, inputWidth);
        return new Pair<>(gradient, inputGradient);
    }

    /**
     * Computes the pre-activation output from the cached im2col buffer, weights and bias.
     *
     * <p>The output axis order is (batch, nOut, outH, outW).
     *
     * @param training when true and drop-connect is configured with a positive drop-out rate,
     *                the weights are replaced by {@code Dropout.applyDropConnect(this, "W")}
     * @return the pre-activation output
     * @throws IllegalStateException if no im2col buffer is available yet
     */
    public INDArray preOutput(boolean training) {
        if (this.col == null) {
            throw new IllegalStateException(
                    "No im2col buffer available: call activate() or setCol() before preOutput()");
        }
        INDArray weights = getParam("W");
        INDArray bias = getParam("b");
        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0d) {
            weights = Dropout.applyDropConnect(this, "W");
        }
        // Contracting col axes {1,2,3} with W axes {1,2,3} leaves [col0, col4, col5, W0].
        // Rolling axis 3 to position 1 moves W's leftover axis (the output channels) to axis 1.
        INDArray z = Nd4j.rollAxis(Nd4j.tensorMmul(this.col, weights, new int[][] {{1, 2, 3}, {1, 2, 3}}), 3, 1);
        // The bias is shuffled to {'x', 0, 'x', 'x'}, i.e. onto axis 1. It must therefore be
        // broadcast against the rolled shape, where axis 1 is the channel axis. Broadcasting it
        // before the roll would align it with the wrong axis.
        z.addi(bias.dimShuffle(new Object[] {'x', 0, 'x', 'x'}, new int[] {0, 1}, new boolean[] {true, true})
                .broadcast(z.shape()));
        return z;
    }

    /**
     * Runs the forward pass.
     *
     * <p>Steps: apply drop-out if configured, cache the im2col buffer of the input, then apply
     * the configured activation to {@link #preOutput(boolean)}.
     *
     * @param training forwarded to drop-out handling and to {@link #preOutput(boolean)}
     * @return the activated output
     * @throws IllegalArgumentException if no input has been set
     */
    @Override
    public INDArray activate(boolean training) {
        if (this.input == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        // NOTE(review): any return value of applyDropOutIfNecessary is ignored here, so this
        // relies on it modifying this.input in place — confirm against BaseLayer.
        applyDropOutIfNecessary(this.input, training);
        this.col = Convolution.im2col(
                this.input, layerConf().getKernelSize(), layerConf().getStride(), layerConf().getPadding());
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), preOutput(training)));
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Gradient calcGradient(Gradient gradient, INDArray input) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException();
    }
}
