package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/GravesLSTM.class */
/**
 * Recurrent (LSTM-style) layer whose forward and backward passes are written out step by step
 * over the third dimension of the input.
 *
 * <p>Parameter layout as read by this class, with {@code h = getParam("RW").size(0)}:
 * <ul>
 *   <li>{@code "W"}: four column groups of width {@code h}: [0,h), [h,2h), [2h,3h), [3h,4h).</li>
 *   <li>{@code "RW"}: the same four column groups, followed by three single columns at
 *       {@code 4h}, {@code 4h+1} and {@code 4h+2} that are applied as diagonal matrices
 *       (via {@code Nd4j.diag}) to cell activations.</li>
 *   <li>{@code "b"}: row 0 is sliced into the same four column groups.</li>
 * </ul>
 * In the forward pass group 0 uses the configured activation function and groups 1-3 use
 * {@code "sigmoid"}; group 1 multiplies the previous cell activations, group 3 multiplies the
 * group 0 values, and group 2 multiplies the new cell activations to produce the output.
 *
 * <p>{@link #gradient()}, {@link #calcGradient(Gradient, INDArray)} and {@link #transpose()}
 * are not implemented and always throw.
 */
public class GravesLSTM extends BaseLayer<org.deeplearning4j.nn.conf.layers.GravesLSTM> {
    /**
     * Creates the layer from a configuration only; delegates to the superclass constructor.
     *
     * @param neuralNetConfiguration configuration passed to {@code BaseLayer}
     */
    public GravesLSTM(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates the layer from a configuration and an array; delegates to the superclass constructor.
     *
     * @param neuralNetConfiguration configuration passed to {@code BaseLayer}
     * @param iNDArray array passed to {@code BaseLayer} (presumably the initial input - confirm)
     */
    public GravesLSTM(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Runs the forward pass again on the currently stored {@code this.input} and then walks the
     * third dimension backwards, accumulating per-step gradients for "W", "RW" and "b".
     *
     * <p>NOTE(review): several activation slices read in the loop do not mirror how the forward
     * pass ({@code activateHelper}) combines the column groups (see the inline notes). Confirm
     * the result with a numerical gradient check before relying on it.
     *
     * @param iNDArray incoming gradient with respect to this layer's output (presumably
     *                 [miniBatch, h, timeSteps], or 2d for a single step - confirm)
     * @return pair of (gradients keyed "W", "RW", "b"; array of shape
     *         [iNDArray.size(0), W.size(0), steps] computed from the "W" column groups and the
     *         per-step deltas)
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        // Forward pass results: [0] outputs, [1] cell activations, [2] values stored before the
        // transform ("z"), [3] values stored after the transform ("a").
        INDArray[] activateHelper = activateHelper(true);
        INDArray iNDArray2 = activateHelper[0];
        INDArray iNDArray3 = activateHelper[1];
        INDArray iNDArray4 = activateHelper[2];
        INDArray iNDArray5 = activateHelper[3];
        INDArray param = getParam("W");
        INDArray param2 = getParam("RW");
        // size = h (rows of RW), size2 = rows of W, size3 = rows of the incoming gradient,
        // size4 = number of steps along dimension 2 (1 when the gradient has rank < 3).
        int size = param2.size(0);
        int size2 = param.size(0);
        int size3 = iNDArray.size(0);
        boolean z = iNDArray.rank() < 3;
        int size4 = z ? 1 : iNDArray.size(2);
        // Column groups 0..3 of W (iNDArray6/8/11/14) and RW (iNDArray7/9/12/15), plus the three
        // single RW columns 4h (iNDArray10), 4h+1 (iNDArray13) and 4h+2 (iNDArray16).
        INDArray iNDArray6 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size)});
        INDArray iNDArray7 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size)});
        INDArray iNDArray8 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)});
        INDArray iNDArray9 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)});
        INDArray iNDArray10 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(4 * size, (4 * size) + 1)});
        INDArray iNDArray11 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)});
        INDArray iNDArray12 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)});
        INDArray iNDArray13 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((4 * size) + 1, (4 * size) + 2)});
        INDArray iNDArray14 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)});
        INDArray iNDArray15 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)});
        INDArray iNDArray16 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((4 * size) + 2, (4 * size) + 3)});
        // Per-step accumulators: zeros = deltas of the four groups, zeros2 = "W" gradient,
        // zeros3 = "RW" gradient, zeros4 = second element of the returned pair.
        INDArray zeros = Nd4j.zeros(new int[]{size3, 4 * size, size4});
        INDArray zeros2 = Nd4j.zeros(new int[]{size2, 4 * size, size4});
        INDArray zeros3 = Nd4j.zeros(new int[]{size, (4 * size) + 3, size4});
        INDArray zeros4 = Nd4j.zeros(new int[]{size3, size2, size4});
        // Cell gradient carried from step i+1 to step i; starts at zero for the last step.
        INDArray zeros5 = Nd4j.zeros(size3, size);
        int i = size4 - 1;
        while (i >= 0) {
            // Previous-step cell activations and outputs (zeros at the first step).
            INDArray zeros6 = i == 0 ? Nd4j.zeros(size3, size) : iNDArray3.tensorAlongDimension(i - 1, new int[]{1, 0});
            INDArray zeros7 = i == 0 ? Nd4j.zeros(size3, size) : iNDArray2.tensorAlongDimension(i - 1, new int[]{1, 0});
            // Current-step cell activations.
            INDArray tensorAlongDimension = z ? iNDArray3 : iNDArray3.tensorAlongDimension(i, new int[]{1, 0});
            // Deltas of groups 0..3 stored for step i+1 (zeros at the last step).
            INDArray zeros8 = i == size4 - 1 ? Nd4j.zeros(size3, size) : zeros.tensorAlongDimension(i + 1, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.interval(0, size3), NDArrayIndex.interval(0, size)});
            INDArray zeros9 = i == size4 - 1 ? Nd4j.zeros(size3, size) : zeros.tensorAlongDimension(i + 1, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.interval(0, size3), NDArrayIndex.interval(size, 2 * size)});
            INDArray zeros10 = i == size4 - 1 ? Nd4j.zeros(size3, size) : zeros.tensorAlongDimension(i + 1, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.interval(0, size3), NDArrayIndex.interval(2 * size, 3 * size)});
            INDArray zeros11 = i == size4 - 1 ? Nd4j.zeros(size3, size) : zeros.tensorAlongDimension(i + 1, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.interval(0, size3), NDArrayIndex.interval(3 * size, 4 * size)});
            // Output error: copy of the incoming gradient slice plus the next-step deltas sent
            // back through the transposed RW column groups.
            INDArray addi = (z ? iNDArray : iNDArray.tensorAlongDimension(i, new int[]{1, 0})).dup().addi(zeros8.mmul(iNDArray7.transpose())).addi(zeros9.mmul(iNDArray9.transpose())).addi(zeros10.mmul(iNDArray12.transpose())).addi(zeros11.mmul(iNDArray15.transpose()));
            // Group 2 delta: output error * activation(copy of cell) * sigmoid derivative of z[2h,3h).
            INDArray muli = addi.mul(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), tensorAlongDimension.dup()))).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", z ? iNDArray4.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)}) : iNDArray4.tensorAlongDimension(i, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)})).derivative()));
            // Cell error: output error * a[2h,3h) * activation derivative of the cell, plus the
            // carried cell error scaled by a step i+1 activation slice, plus next-step deltas
            // through the three diagonal RW columns.
            // NOTE(review): the step i+1 slice below reads columns [2h,3h), but the forward pass
            // multiplies the previous cell by group 1, columns [h,2h) - confirm which is intended.
            // NOTE(review): column 4h+1 is paired with the step i+1 delta of group 2, while the
            // forward pass applies column 4h+1 to the current-step cell - confirm.
            INDArray addi2 = addi.mul(z ? iNDArray5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)}) : iNDArray5.tensorAlongDimension(i, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)})).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), tensorAlongDimension.dup()).derivative())).addi((i == size4 - 1 ? Nd4j.zeros(size3, size) : iNDArray5.tensorAlongDimension(i + 1, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)})).mul(zeros5)).addi(zeros9.mmul(Nd4j.diag(iNDArray10))).addi(zeros10.mmul(Nd4j.diag(iNDArray13))).addi(zeros11.mmul(Nd4j.diag(iNDArray16)));
            // Carry the cell error to the next (earlier) step.
            zeros5 = addi2;
            // Group 1 delta: cell error * previous cell * sigmoid derivative of z[h,2h).
            INDArray muli2 = addi2.mul(zeros6).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", z ? iNDArray4.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}) : iNDArray4.tensorAlongDimension(i, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)})).derivative()));
            // Group 3 delta: cell error * a[h,2h) * sigmoid derivative of z[3h,4h).
            // NOTE(review): the forward pass multiplies group 3 by group 0, so a[0,h) might be
            // expected here instead of a[h,2h) - confirm.
            INDArray muli3 = addi2.mul(z ? iNDArray5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}) : iNDArray5.tensorAlongDimension(i, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)})).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", z ? iNDArray4.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)}) : iNDArray4.tensorAlongDimension(i, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)})).derivative()));
            // Group 0 delta: cell error * a[3h,4h) * activation derivative of z[0,h).
            INDArray muli4 = addi2.mul(z ? iNDArray5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)}) : iNDArray5.tensorAlongDimension(i, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)})).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), z ? iNDArray4.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size)}) : iNDArray4.tensorAlongDimension(i, new int[]{1, 0}).get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size)})).derivative()));
            // Layer input at this step.
            INDArray tensorAlongDimension2 = z ? this.input : this.input.tensorAlongDimension(i, new int[]{1, 0});
            // "W" gradient for this step: (delta^T * input)^T written into each column group.
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size)}, muli4.transpose().mmul(tensorAlongDimension2).transpose());
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}, muli2.transpose().mmul(tensorAlongDimension2).transpose());
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)}, muli.transpose().mmul(tensorAlongDimension2).transpose());
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)}, muli3.transpose().mmul(tensorAlongDimension2).transpose());
            // Terms that use the previous step's output/cell are skipped at step 0, leaving zeros.
            if (i > 0) {
                // "RW" gradient for this step: (delta^T * previous output)^T per column group.
                zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size)}, muli4.transpose().mmul(zeros7).transpose());
                zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}, muli2.transpose().mmul(zeros7).transpose());
                zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)}, muli.transpose().mmul(zeros7).transpose());
                zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)}, muli3.transpose().mmul(zeros7).transpose());
                // Columns 4h and 4h+2: group 1 / group 3 delta times previous cell, summed over dimension 0.
                zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), new NDArrayIndex(new int[]{4 * size})}, muli2.mul(zeros6).sum(new int[]{0}).transpose());
                zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), new NDArrayIndex(new int[]{(4 * size) + 2})}, muli3.mul(zeros6).sum(new int[]{0}).transpose());
            }
            // Column 4h+1: group 2 delta times current cell, summed over dimension 0.
            zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.all(), new NDArrayIndex(new int[]{(4 * size) + 1})}, muli.mul(tensorAlongDimension).sum(new int[]{0}).transpose());
            // Store this step's deltas; they are read back as the "next-step" deltas at step i-1
            // and summed for the bias gradient after the loop.
            INDArray tensorAlongDimension3 = z ? zeros : zeros.tensorAlongDimension(i, new int[]{1, 0});
            tensorAlongDimension3.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size)}, muli4);
            tensorAlongDimension3.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}, muli2);
            tensorAlongDimension3.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size)}, muli);
            tensorAlongDimension3.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size)}, muli3);
            // Returned array for this step: sum over groups of (W group * delta^T)^T.
            zeros4.tensorAlongDimension(i, new int[]{1, 0}).assign(iNDArray6.mmul(muli4.transpose()).transpose().addi(iNDArray8.mmul(muli2.transpose()).transpose()).addi(iNDArray11.mmul(muli.transpose()).transpose()).addi(iNDArray14.mmul(muli3.transpose()).transpose()));
            i--;
        }
        // Collapse the step dimension (2); the bias gradient is additionally summed over dimension 0.
        DefaultGradient defaultGradient = new DefaultGradient();
        defaultGradient.gradientForVariable().put("W", zeros2.sum(new int[]{2}));
        defaultGradient.gradientForVariable().put("RW", zeros3.sum(new int[]{2}));
        defaultGradient.gradientForVariable().put("b", zeros.sum(new int[]{2}).sum(new int[]{0}));
        return new Pair<>(defaultGradient, zeros4);
    }

    /**
     * Delegates to {@link #activate(INDArray, boolean)} with the flag set to {@code true}, so the
     * value returned is the layer output, not a pre-activation.
     *
     * @param iNDArray input to store and run the forward pass on
     * @return layer outputs
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        return activate(iNDArray, true);
    }

    /**
     * Delegates to {@link #activate(INDArray, boolean)}; returns the layer output, not a
     * pre-activation.
     *
     * @param iNDArray input to store and run the forward pass on
     * @param z flag forwarded to {@code setInput} and to the forward pass
     * @return layer outputs
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        return activate(iNDArray, z);
    }

    /**
     * Stores the input via {@code setInput(iNDArray, z)} and runs the forward pass.
     *
     * @param iNDArray input to store
     * @param z flag forwarded to {@code setInput}; in the forward pass it only gates DropConnect on "W"
     * @return layer outputs (element 0 of the forward pass result)
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z) {
        setInput(iNDArray, z);
        return activateHelper(z)[0];
    }

    /**
     * Stores the input via {@code setInput(iNDArray)} and runs the forward pass with the flag
     * fixed to {@code true}.
     *
     * @param iNDArray input to store
     * @return layer outputs
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        setInput(iNDArray);
        return activateHelper(true)[0];
    }

    /**
     * Runs the forward pass on the already stored {@code this.input}.
     *
     * @param z flag forwarded to the forward pass (gates DropConnect on "W")
     * @return layer outputs
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z) {
        return activateHelper(z)[0];
    }

    /**
     * Runs the forward pass on the already stored {@code this.input} with the flag set to
     * {@code false} (unlike {@link #activate(INDArray)}, which uses {@code true}).
     *
     * @return layer outputs
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        return activateHelper()[0];
    }

    /** Forward pass with the flag set to {@code false}; see {@link #activateHelper(boolean)}. */
    private INDArray[] activateHelper() {
        return activateHelper(false);
    }

    /**
     * Forward pass over {@code this.input}, iterating over dimension 2 (a single step when the
     * input has rank below 3).
     *
     * <p>NOTE(review): several expressions below reuse an array after it was passed to
     * {@code execAndReturn}. If the transform op writes into its argument, those later uses see
     * the transformed values; if it does not, they see the untransformed sums. Confirm against
     * the ND4J version in use.
     *
     * @param z when {@code true}, and DropConnect is enabled with a drop-out value above 0, "W" is
     *          replaced by {@code Dropout.applyDropConnect(this, "W")}; it has no other effect here
     * @return four arrays: [0] outputs of shape [input.size(0), h, steps], [1] cell activations of
     *         the same shape, [2] the four column groups as stored before the transform,
     *         [3] the four column groups as returned by the transform
     */
    private INDArray[] activateHelper(boolean z) {
        INDArray param = getParam("RW");
        INDArray param2 = getParam("W");
        INDArray param3 = getParam("b");
        // z2: input is treated as a single step; size = steps, size2 = h, size3 = input rows.
        boolean z2 = this.input.rank() < 3;
        int size = z2 ? 1 : this.input.size(2);
        int size2 = param.size(0);
        int size3 = this.input.size(0);
        // NOTE(review): the result of this call is discarded; it looks like a leftover.
        param2.size(0);
        if (this.conf.isUseDropConnect() && z && this.conf.getLayer().getDropOut() > 0.0d) {
            param2 = Dropout.applyDropConnect(this, "W");
        }
        // Group 0, columns [0,h): W, RW and bias slices.
        INDArray iNDArray = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size2)});
        INDArray iNDArray2 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, size2)});
        INDArray iNDArray3 = param3.get(new INDArrayIndex[]{new NDArrayIndex(new int[]{0}), NDArrayIndex.interval(0, size2)});
        // Group 1, columns [h,2h), plus the single RW column 4h.
        INDArray iNDArray4 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size2, 2 * size2)});
        INDArray iNDArray5 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(size2, 2 * size2)});
        INDArray iNDArray6 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(4 * size2, (4 * size2) + 1)});
        INDArray iNDArray7 = param3.get(new INDArrayIndex[]{new NDArrayIndex(new int[]{0}), NDArrayIndex.interval(size2, 2 * size2)});
        // Group 2, columns [2h,3h), plus the single RW column 4h+1.
        INDArray iNDArray8 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size2, 3 * size2)});
        INDArray iNDArray9 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * size2, 3 * size2)});
        INDArray iNDArray10 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((4 * size2) + 1, (4 * size2) + 2)});
        INDArray iNDArray11 = param3.get(new INDArrayIndex[]{new NDArrayIndex(new int[]{0}), NDArrayIndex.interval(2 * size2, 3 * size2)});
        // Group 3, columns [3h,4h), plus the single RW column 4h+2.
        INDArray iNDArray12 = param2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size2, 4 * size2)});
        INDArray iNDArray13 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * size2, 4 * size2)});
        INDArray iNDArray14 = param.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((4 * size2) + 2, (4 * size2) + 3)});
        INDArray iNDArray15 = param3.get(new INDArrayIndex[]{new NDArrayIndex(new int[]{0}), NDArrayIndex.interval(3 * size2, 4 * size2)});
        // Per-step results: zeros = outputs, zeros2 = pre-transform values, zeros3 = transformed
        // values, zeros4 = cell activations.
        INDArray zeros = Nd4j.zeros(new int[]{size3, size2, size});
        INDArray zeros2 = Nd4j.zeros(new int[]{size3, 4 * size2, size});
        INDArray zeros3 = Nd4j.zeros(new int[]{size3, 4 * size2, size});
        INDArray zeros4 = Nd4j.zeros(new int[]{size3, size2, size});
        int i = 0;
        while (i < size) {
            // Input slice for this step, and previous-step outputs / cell activations (zeros at step 0).
            INDArray tensorAlongDimension = z2 ? this.input : this.input.tensorAlongDimension(i, new int[]{1, 0});
            INDArray zeros5 = i == 0 ? Nd4j.zeros(new int[]{size3, size2}) : zeros.tensorAlongDimension(i - 1, new int[]{1, 0});
            INDArray zeros6 = i == 0 ? Nd4j.zeros(new int[]{size3, size2}) : zeros4.tensorAlongDimension(i - 1, new int[]{1, 0});
            // Group 0: input*W + prevOutput*RW + bias, then the configured activation function.
            INDArray addiRowVector = tensorAlongDimension.mmul(iNDArray).addi(zeros5.mmul(iNDArray2)).addiRowVector(iNDArray3);
            INDArrayIndex[] iNDArrayIndexArr = {NDArrayIndex.all(), NDArrayIndex.interval(0, size2)};
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr, addiRowVector);
            zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr, Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), addiRowVector)));
            // Group 1: as above plus prevCell*diag(RW column 4h), then sigmoid.
            INDArray addiRowVector2 = tensorAlongDimension.mmul(iNDArray4).addi(zeros5.mmul(iNDArray5)).addi(zeros6.mmul(Nd4j.diag(iNDArray6))).addiRowVector(iNDArray7);
            INDArrayIndex[] iNDArrayIndexArr2 = {NDArrayIndex.all(), NDArrayIndex.interval(size2, 2 * size2)};
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr2, addiRowVector2);
            zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr2, Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", addiRowVector2)));
            // Group 3: as above plus prevCell*diag(RW column 4h+2), then sigmoid.
            INDArray addiRowVector3 = tensorAlongDimension.mmul(iNDArray12).addi(zeros5.mmul(iNDArray13)).addi(zeros6.mmul(Nd4j.diag(iNDArray14))).addiRowVector(iNDArray15);
            INDArrayIndex[] iNDArrayIndexArr3 = {NDArrayIndex.all(), NDArrayIndex.interval(3 * size2, 4 * size2)};
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr3, addiRowVector3);
            zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr3, Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", addiRowVector3)));
            // New cell activations: activation(group1 * prevCell + group3 * group0).
            // NOTE(review): uses the group arrays after they were passed to the transform above;
            // this only yields gate activations if the transform works in place - confirm.
            INDArray execAndReturn = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), addiRowVector2.mul(zeros6).addi(addiRowVector3.mul(addiRowVector))));
            // Group 2: input*W + prevOutput*RW + newCell*diag(RW column 4h+1) + bias, then sigmoid.
            INDArray addiRowVector4 = tensorAlongDimension.mmul(iNDArray8).addi(zeros5.mmul(iNDArray9)).addi(execAndReturn.mmul(Nd4j.diag(iNDArray10))).addiRowVector(iNDArray11);
            INDArrayIndex[] iNDArrayIndexArr4 = {NDArrayIndex.all(), NDArrayIndex.interval(2 * size2, 3 * size2)};
            zeros2.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr4, addiRowVector4);
            zeros3.tensorAlongDimension(i, new int[]{1, 0}).put(iNDArrayIndexArr4, Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", addiRowVector4)));
            // Step output = group 2 * new cell activations; both are kept for the next step.
            zeros.tensorAlongDimension(i, new int[]{1, 0}).assign(addiRowVector4.mul(execAndReturn));
            zeros4.tensorAlongDimension(i, new int[]{1, 0}).assign(execAndReturn);
            i++;
        }
        return new INDArray[]{zeros, zeros4, zeros2, zeros3};
    }

    /**
     * Returns the same value as {@link #activate()}.
     *
     * @return layer outputs for the stored input
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activationMean() {
        return activate();
    }

    /**
     * @return {@code Layer.Type.RECURRENT}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * L2 penalty over "RW" and "W" only ("b" is not included):
     * {@code 0.5 * l2 * (sum(RW^2) + sum(W^2))}.
     *
     * @return the penalty, or 0 when regularization is disabled or the L2 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getL2() <= 0.0d) {
            return 0.0d;
        }
        // sum(Integer.MAX_VALUE) presumably reduces over all elements - confirm for the ND4J version in use.
        return 0.5d * this.conf.getL2() * (Transforms.pow(getParam("RW"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0) + Transforms.pow(getParam("W"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0));
    }

    /**
     * L1 penalty over "RW" and "W" only ("b" is not included):
     * {@code l1 * (sum(|RW|) + sum(|W|))}.
     *
     * @return the penalty, or 0 when regularization is disabled or the L1 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getL1() <= 0.0d) {
            return 0.0d;
        }
        // sum(Integer.MAX_VALUE) presumably reduces over all elements - confirm for the ND4J version in use.
        return this.conf.getL1() * (Transforms.abs(getParam("RW")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0) + Transforms.abs(getParam("W")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0));
    }
}
