package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/GravesLSTM.class */
public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.GravesLSTM> {
    /** Key under which the output activations of the last processed time step are kept in the layer's state maps. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";
    /** Key under which the memory cell state of the last processed time step is kept in the layer's state maps. */
    public static final String STATE_KEY_PREV_MEMCELL = "prevMem";

    /**
     * Holder for everything a forward pass produces.
     *
     * <p>Which fields are populated depends on the mode the forward pass ran in: in backprop mode the
     * per-time-step arrays and {@code paramsMmulCompatible} are filled and {@code fwdPassOutput} stays
     * null; otherwise only {@code fwdPassOutput} is filled. {@code lastAct} and {@code lastMemCell} are
     * written on every processed time step in both modes.
     *
     * <p>Decompiler note (JADX): access modifiers of this class were changed from private.
     */
    public static class FwdPassReturn {
        // Output for all time steps in one array, shape [input.size(0), RW.size(0), timeSeriesLength].
        private INDArray fwdPassOutput;
        // The 11 parameter views used by the forward pass, in the order: input/recurrent weights of the
        // block input, input/recurrent/cell-state weights of the fa gate, then of the oa gate, then of the ga gate.
        private INDArray[] paramsMmulCompatible;
        // Layer output, one array per time step.
        private INDArray[] fwdPassOutputAsArrays;
        // Memory cell state, one array per time step.
        private INDArray[] memCellState;
        // Configured activation function applied to the memory cell state, one array per time step.
        private INDArray[] memCellActivations;
        // Block input before the activation function is applied (a copy), one array per time step.
        private INDArray[] iz;
        // Block input after the activation function, one array per time step.
        private INDArray[] ia;
        // Sigmoid gate that scales the previous memory cell state (forget gate), one array per time step.
        private INDArray[] fa;
        // Sigmoid gate that scales the activated memory cell state (output gate), one array per time step.
        private INDArray[] oa;
        // Sigmoid gate that scales the block input (input modulation gate), one array per time step.
        private INDArray[] ga;
        // Output of the final time step that was processed.
        private INDArray lastAct;
        // Memory cell state of the final time step that was processed.
        private INDArray lastMemCell;

        // Only the enclosing layer creates instances.
        private FwdPassReturn() {
        }
    }

    /**
     * Creates the layer from its configuration; all set-up is delegated to the superclass.
     *
     * @param neuralNetConfiguration configuration handed to the superclass constructor
     */
    public GravesLSTM(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates the layer from its configuration together with an array that is passed straight to the
     * superclass constructor.
     *
     * @param neuralNetConfiguration configuration handed to the superclass constructor
     * @param iNDArray               array handed to the superclass constructor (presumably the layer input - confirm)
     */
    public GravesLSTM(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported by this layer; use {@link #backpropGradient(INDArray)} to obtain gradients.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Backpropagates through every time step (no truncation).
     *
     * @param iNDArray errors arriving from the layer above
     * @return parameter gradients paired with the errors for the layer below
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        final boolean truncatedBptt = false;
        // The backward length is only read when truncation is enabled, so -1 is a placeholder.
        final int backwardLength = -1;
        return backpropGradientHelper(iNDArray, truncatedBptt, backwardLength);
    }

    /**
     * Truncated backpropagation through time: only the last {@code i} time steps are walked backwards.
     *
     * @param iNDArray errors arriving from the layer above
     * @param i        maximum number of time steps to backpropagate through
     * @return parameter gradients paired with the errors for the layer below
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray iNDArray, int i) {
        final boolean truncatedBptt = true;
        return backpropGradientHelper(iNDArray, truncatedBptt, i);
    }

    /**
     * Shared implementation behind {@link #backpropGradient(INDArray)} and
     * {@link #tbpttBackpropGradient(INDArray, int)}.
     *
     * <p>Re-runs the forward pass in backprop mode (so the per-time-step gate values are retained), then
     * walks the time steps from last to first, accumulating gradients for the input weights ("W"), the
     * recurrent weights ("RW", whose last three columns hold the element-wise cell-state weights) and the
     * biases ("b").
     *
     * <p>Several arrays cached by the forward pass (oa, fa, ia, ga, iz) are modified in place below, so the
     * statement order inside the loop must not be changed.
     *
     * @param iNDArray errors arriving from the layer above; an array of rank below 3 is treated as a single
     *                 time step, otherwise dimension 2 is the time axis
     * @param z        true for truncated BPTT: the forward pass is seeded from {@code stateMap}, its final
     *                 state is written to {@code tBpttStateMap}, and at most {@code i} trailing steps are visited
     * @param i        number of time steps to walk back when {@code z} is true; ignored otherwise
     * @return the parameter gradients paired with the errors for the layer below, an array of shape
     *         {@code [iNDArray.size(0), W.size(0), timeSeriesLength]}
     */
    private Pair<Gradient, INDArray> backpropGradientHelper(INDArray iNDArray, boolean z, int i) {
        // Readable aliases for the decompiler-generated parameter names.
        final INDArray epsilon = iNDArray;
        final boolean truncatedBptt = z;
        final int tbpttBackwardLength = i;

        // Forward pass first: the gate activations of every time step are needed below.
        final FwdPassReturn fwdPass;
        if (truncatedBptt) {
            fwdPass = activateHelper(true, this.stateMap.get(STATE_KEY_PREV_ACTIVATION), this.stateMap.get(STATE_KEY_PREV_MEMCELL), true);
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
            this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        } else {
            fwdPass = activateHelper(true, null, null, true);
        }

        final INDArray inputWeights = getParam("W");
        final int hiddenLayerSize = getParam("RW").size(0);
        final int prevLayerSize = inputWeights.size(0);
        final int miniBatchSize = epsilon.size(0);
        final boolean is2dInput = epsilon.rank() < 3;
        final int timeSeriesLength = is2dInput ? 1 : epsilon.size(2);

        // Parameter views from the forward pass. Suffixes: I = block input, F = fa gate, O = oa gate, G = ga gate.
        final INDArray inWeightsI = fwdPass.paramsMmulCompatible[0];
        final INDArray recWeightsI = fwdPass.paramsMmulCompatible[1];
        final INDArray inWeightsF = fwdPass.paramsMmulCompatible[2];
        final INDArray recWeightsF = fwdPass.paramsMmulCompatible[3];
        final INDArray inWeightsO = fwdPass.paramsMmulCompatible[5];
        final INDArray recWeightsO = fwdPass.paramsMmulCompatible[6];
        final INDArray inWeightsG = fwdPass.paramsMmulCompatible[8];
        final INDArray recWeightsG = fwdPass.paramsMmulCompatible[9];
        // Row vectors applied element-wise to the memory cell state.
        final INDArray cellWeightsF = fwdPass.paramsMmulCompatible[4];
        final INDArray cellWeightsO = fwdPass.paramsMmulCompatible[7];
        final INDArray cellWeightsG = fwdPass.paramsMmulCompatible[10];

        // Gradient accumulators, indexed 0 = I, 1 = F, 2 = O, 3 = G. The recurrent array carries three
        // extra row vectors (4 = F, 5 = O, 6 = G) for the element-wise cell-state weights.
        final INDArray[] biasGradients = new INDArray[4];
        final INDArray[] inputWeightGradients = new INDArray[4];
        final INDArray[] recurrentWeightGradients = new INDArray[7];
        for (int gate = 0; gate < 4; gate++) {
            biasGradients[gate] = Nd4j.create(new int[]{1, hiddenLayerSize});
            inputWeightGradients[gate] = Nd4j.create(new int[]{prevLayerSize, hiddenLayerSize}, 'f');
            recurrentWeightGradients[gate] = Nd4j.create(new int[]{hiddenLayerSize, hiddenLayerSize}, 'f');
        }
        for (int cellWeightIdx = 4; cellWeightIdx < 7; cellWeightIdx++) {
            recurrentWeightGradients[cellWeightIdx] = Nd4j.zeros(1, hiddenLayerSize);
        }

        // Errors to hand to the layer below; time steps skipped by truncation stay zero.
        final INDArray epsilonNext = Nd4j.zeros(new int[]{miniBatchSize, prevLayerSize, timeSeriesLength});

        // Values carried from time step t + 1 into time step t.
        INDArray nablaCellStateNext = null;
        INDArray deltaINext = null;
        INDArray deltaFNext = null;
        INDArray deltaONext = null;
        INDArray deltaGNext = null;

        final Level1 l1Blas = Nd4j.getBlasWrapper().level1();
        final int endIdx = truncatedBptt ? Math.max(0, timeSeriesLength - tbpttBackwardLength) : 0;

        for (int t = timeSeriesLength - 1; t >= endIdx; t--) {
            final INDArray prevMemCellState = t == 0 ? null : fwdPass.memCellState[t - 1];
            final INDArray prevHiddenUnitActivation = t == 0 ? null : fwdPass.fwdPassOutputAsArrays[t - 1];
            final INDArray currMemCellState = fwdPass.memCellState[t];

            // Error w.r.t. this step's output: the incoming error (copied, so epsilon itself is untouched)
            // plus what flows back through the recurrent weights from step t + 1.
            final INDArray nablaOut = Shape.toOffsetZeroCopy(
                    is2dInput ? epsilon : epsilon.tensorAlongDimension(t, new int[]{1, 0}), 'f');
            if (t != timeSeriesLength - 1) {
                Nd4j.gemm(deltaINext, recWeightsI, nablaOut, false, true, 1.0d, 1.0d);
                Nd4j.gemm(deltaFNext, recWeightsF, nablaOut, false, true, 1.0d, 1.0d);
                Nd4j.gemm(deltaONext, recWeightsO, nablaOut, false, true, 1.0d, 1.0d);
                Nd4j.gemm(deltaGNext, recWeightsG, nablaOut, false, true, 1.0d, 1.0d);
            }

            // oa gate delta. "timesoneminus" is run on a copy of the gate output; by its name it yields
            // x * (1 - x), i.e. the sigmoid derivative expressed through the gate output (confirm).
            final INDArray activatedCellState = fwdPass.memCellActivations[t];
            final INDArray outputGate = fwdPass.oa[t];
            final INDArray deltaO = nablaOut.dup('f').muli(activatedCellState).muli(
                    Nd4j.getExecutioner().execAndReturn(
                            Nd4j.getOpFactory().createTransform("timesoneminus", outputGate.dup('f'))));

            // Error w.r.t. the memory cell state. NOTE: overwrites fwdPass.oa[t] in place.
            final INDArray nablaCellState = outputGate.muli(nablaOut).muli(
                    Nd4j.getExecutioner().execAndReturn(
                            Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), currMemCellState.dup('f')).derivative()));
            l1Blas.axpy(nablaCellState.length(), 1.0d, deltaO.dup('f').muliRowVector(cellWeightsO), nablaCellState);
            if (t != timeSeriesLength - 1) {
                // NOTE: overwrites fwdPass.fa[t + 1] in place; it was already consumed at step t + 1.
                final INDArray nextForgetGate = fwdPass.fa[t + 1];
                final int length = nablaCellState.length();
                l1Blas.axpy(length, 1.0d, nextForgetGate.muli(nablaCellStateNext), nablaCellState);
                l1Blas.axpy(length, 1.0d, deltaFNext.dup('f').muliRowVector(cellWeightsF), nablaCellState);
                l1Blas.axpy(length, 1.0d, deltaGNext.dup('f').muliRowVector(cellWeightsG), nablaCellState);
            }
            nablaCellStateNext = nablaCellState;

            // fa gate delta; left null at t == 0, where no previous memory cell state is available here.
            final INDArray deltaF;
            if (t > 0) {
                deltaF = nablaCellState.dup('f').muli(prevMemCellState).muli(
                        Nd4j.getExecutioner().execAndReturn(
                                Nd4j.getOpFactory().createTransform("timesoneminus", fwdPass.fa[t].dup('f'))));
            } else {
                deltaF = null;
            }

            // ga gate delta. NOTE: overwrites fwdPass.ia[t] in place.
            final INDArray inputModGate = fwdPass.ga[t];
            final INDArray deltaG = fwdPass.ia[t].muli(nablaCellState).muli(
                    Nd4j.getExecutioner().execAndReturn(
                            Nd4j.getOpFactory().createTransform("timesoneminus", inputModGate.dup('f'))));

            // Block input delta. NOTE: overwrites fwdPass.ga[t] and fwdPass.iz[t] in place, which is why
            // it has to come after deltaG.
            final INDArray deltaI = inputModGate.muli(nablaCellState).muli(
                    Nd4j.getExecutioner().execAndReturn(
                            Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), fwdPass.iz[t]).derivative()));

            // Input weight gradients: (input slice)^T * delta.
            final INDArray prevLayerActivationSlice = Shape.toMmulCompatible(
                    is2dInput ? this.input : this.input.tensorAlongDimension(t, new int[]{1, 0}));
            Nd4j.gemm(prevLayerActivationSlice, deltaI, inputWeightGradients[0], true, false, 1.0d, 1.0d);
            if (t > 0) {
                Nd4j.gemm(prevLayerActivationSlice, deltaF, inputWeightGradients[1], true, false, 1.0d, 1.0d);
            }
            Nd4j.gemm(prevLayerActivationSlice, deltaO, inputWeightGradients[2], true, false, 1.0d, 1.0d);
            Nd4j.gemm(prevLayerActivationSlice, deltaG, inputWeightGradients[3], true, false, 1.0d, 1.0d);

            // Recurrent weight gradients need the previous step's output, so they start at t == 1.
            if (t > 0) {
                Nd4j.gemm(prevHiddenUnitActivation, deltaI, recurrentWeightGradients[0], true, false, 1.0d, 1.0d);
                Nd4j.gemm(prevHiddenUnitActivation, deltaF, recurrentWeightGradients[1], true, false, 1.0d, 1.0d);
                Nd4j.gemm(prevHiddenUnitActivation, deltaO, recurrentWeightGradients[2], true, false, 1.0d, 1.0d);
                Nd4j.gemm(prevHiddenUnitActivation, deltaG, recurrentWeightGradients[3], true, false, 1.0d, 1.0d);
                // Element-wise cell-state weights: delta * cell state, summed over dimension 0.
                l1Blas.axpy(recurrentWeightGradients[4].length(), 1.0d, deltaF.dup('f').muli(prevMemCellState).sum(new int[]{0}), recurrentWeightGradients[4]);
                l1Blas.axpy(recurrentWeightGradients[6].length(), 1.0d, deltaG.dup('f').muli(prevMemCellState).sum(new int[]{0}), recurrentWeightGradients[6]);
            }
            l1Blas.axpy(recurrentWeightGradients[5].length(), 1.0d, deltaO.dup('f').muli(currMemCellState).sum(new int[]{0}), recurrentWeightGradients[5]);

            // Bias gradients: deltas summed over dimension 0.
            l1Blas.axpy(biasGradients[0].length(), 1.0d, deltaI.sum(new int[]{0}), biasGradients[0]);
            if (t > 0) {
                l1Blas.axpy(biasGradients[1].length(), 1.0d, deltaF.sum(new int[]{0}), biasGradients[1]);
            }
            l1Blas.axpy(biasGradients[2].length(), 1.0d, deltaO.sum(new int[]{0}), biasGradients[2]);
            l1Blas.axpy(biasGradients[3].length(), 1.0d, deltaG.sum(new int[]{0}), biasGradients[3]);

            // Errors for the layer below at this time step: sum over gates of delta * (input weights)^T.
            final INDArray epsilonNextSlice = Nd4j.gemm(deltaI, inWeightsI, false, true);
            Nd4j.gemm(deltaO, inWeightsO, epsilonNextSlice, false, true, 1.0d, 1.0d);
            Nd4j.gemm(deltaG, inWeightsG, epsilonNextSlice, false, true, 1.0d, 1.0d);
            if (t > 0) {
                Nd4j.gemm(deltaF, inWeightsF, epsilonNextSlice, false, true, 1.0d, 1.0d);
            }
            epsilonNext.tensorAlongDimension(t, new int[]{1, 0}).assign(epsilonNextSlice);

            deltaINext = deltaI;
            deltaFNext = deltaF;
            deltaONext = deltaO;
            deltaGNext = deltaG;
        }

        // Pack the per-gate accumulators into arrays laid out like the "W", "RW" and "b" parameters.
        final INDArray weightGradient = Nd4j.zeros(prevLayerSize, 4 * hiddenLayerSize);
        final INDArray recurrentGradient = Nd4j.zeros(hiddenLayerSize, (4 * hiddenLayerSize) + 3);
        final INDArray biasGradient = Nd4j.hstack(biasGradients);
        for (int gate = 0; gate < 4; gate++) {
            weightGradient.put(
                    new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(gate * hiddenLayerSize, (gate + 1) * hiddenLayerSize)},
                    inputWeightGradients[gate]);
        }
        for (int gate = 0; gate < 4; gate++) {
            recurrentGradient.put(
                    new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(gate * hiddenLayerSize, (gate + 1) * hiddenLayerSize)},
                    recurrentWeightGradients[gate]);
        }
        for (int offset = 0; offset < 3; offset++) {
            recurrentGradient.put(
                    new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((4 * hiddenLayerSize) + offset)},
                    recurrentWeightGradients[4 + offset].transpose());
        }

        final DefaultGradient gradient = new DefaultGradient();
        gradient.gradientForVariable().put("W", weightGradient);
        gradient.gradientForVariable().put("RW", recurrentGradient);
        gradient.gradientForVariable().put("b", biasGradient);
        return new Pair<>(gradient, epsilonNext);
    }

    /**
     * Same as {@link #activate(INDArray, boolean)} with the training flag set to true.
     *
     * @param iNDArray layer input
     * @return the layer output
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        final boolean training = true;
        return activate(iNDArray, training);
    }

    /**
     * Delegates to {@link #activate(INDArray, boolean)}; this layer does not distinguish pre-output
     * from activations.
     *
     * @param iNDArray layer input
     * @param z        training flag forwarded to {@code activate}
     * @return the layer output
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        final INDArray activations = activate(iNDArray, z);
        return activations;
    }

    /**
     * Stores the given input on the layer and runs a forward pass that starts from an all-zero state.
     *
     * @param iNDArray layer input
     * @param z        training flag, forwarded both to {@code setInput} and to the forward pass
     * @return the output for all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z) {
        setInput(iNDArray, z);
        final FwdPassReturn forward = activateHelper(z, null, null, false);
        return forward.fwdPassOutput;
    }

    /**
     * Stores the given input on the layer and runs a forward pass in training mode, starting from an
     * all-zero state.
     *
     * @param iNDArray layer input
     * @return the output for all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        setInput(iNDArray);
        final boolean training = true;
        final FwdPassReturn forward = activateHelper(training, null, null, false);
        return forward.fwdPassOutput;
    }

    /**
     * Runs a forward pass over the input already stored on the layer, starting from an all-zero state.
     *
     * @param z training flag forwarded to the forward pass
     * @return the output for all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z) {
        final FwdPassReturn forward = activateHelper(z, null, null, false);
        return forward.fwdPassOutput;
    }

    /**
     * Runs a non-training forward pass over the input already stored on the layer, starting from an
     * all-zero state.
     *
     * @return the output for all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        final boolean training = false;
        final FwdPassReturn forward = activateHelper(training, null, null, false);
        return forward.fwdPassOutput;
    }

    /**
     * Runs the LSTM forward pass over every time step of the layer's current {@code input}.
     *
     * <p>With x the input slice, h the previous output, c the previous memory cell state, act the
     * configured activation function, n = RW.size(0), {@code *} a matrix product and {@code (.)} an
     * element-wise product, each time step computes:
     * <pre>
     *   a  = act(    x * W[:, 0:n]   + h * RW[:, 0:n]                           + b[0:n]   )
     *   f  = sigmoid(x * W[:, n:2n]  + h * RW[:, n:2n]  + c  (.) RW[:, 4n]      + b[n:2n]  )
     *   g  = sigmoid(x * W[:, 3n:4n] + h * RW[:, 3n:4n] + c  (.) RW[:, 4n + 2]  + b[3n:4n] )
     *   c' = f (.) c + g (.) a
     *   o  = sigmoid(x * W[:, 2n:3n] + h * RW[:, 2n:3n] + c' (.) RW[:, 4n + 1]  + b[2n:3n] )
     *   h' = act(c') (.) o
     * </pre>
     *
     * @param z         training flag; only consulted to decide whether drop-connect is applied to "W"
     * @param iNDArray  output of the step preceding the first one, or null to start from zeros
     * @param iNDArray2 memory cell state preceding the first step, or null to start from zeros
     * @param z2        true to keep the per-time-step intermediates (and the parameter views) needed by
     *                  backprop instead of assembling the single output array
     * @return the forward-pass results; {@code fwdPassOutput} is only set when {@code z2} is false, the
     *         per-step arrays only when it is true
     */
    private FwdPassReturn activateHelper(boolean z, INDArray iNDArray, INDArray iNDArray2, boolean z2) {
        // Readable aliases for the decompiler-generated parameter names.
        final boolean training = z;
        final boolean forBackprop = z2;
        INDArray prevOutputActivations = iNDArray;
        INDArray prevMemCellState = iNDArray2;

        final INDArray recurrentWeights = getParam("RW");
        INDArray inputWeights = getParam("W");
        final INDArray biases = getParam("b");

        // Input of rank below 3 is handled as a single time step.
        final boolean is2dInput = this.input.rank() < 3;
        final int timeSeriesLength = is2dInput ? 1 : this.input.size(2);
        final int hiddenLayerSize = recurrentWeights.size(0);
        final int miniBatchSize = this.input.size(0);

        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0d) {
            inputWeights = Dropout.applyDropConnect(this, "W");
        }

        // Column offsets of the four blocks inside W / RW / b, plus the three extra RW columns at the end.
        final int startF = hiddenLayerSize;
        final int startO = 2 * hiddenLayerSize;
        final int startG = 3 * hiddenLayerSize;
        final int startCellWeights = 4 * hiddenLayerSize;

        // Block input (I).
        INDArray inWeightsI = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, startF)});
        INDArray recWeightsI = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, startF)});
        INDArray biasI = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(0, startF)});
        // fa gate (F).
        INDArray inWeightsF = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startF, startO)});
        INDArray recWeightsF = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startF, startO)});
        INDArray cellWeightsF = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startCellWeights, startCellWeights + 1)}).transpose();
        INDArray biasF = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(startF, startO)});
        // oa gate (O).
        INDArray inWeightsO = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startO, startG)});
        INDArray recWeightsO = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startO, startG)});
        INDArray cellWeightsO = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startCellWeights + 1, startCellWeights + 2)}).transpose();
        INDArray biasO = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(startO, startG)});
        // ga gate (G).
        INDArray inWeightsG = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startG, startCellWeights)});
        INDArray recWeightsG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startG, startCellWeights)});
        INDArray cellWeightsG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startCellWeights + 2, startCellWeights + 3)}).transpose();
        INDArray biasG = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(startG, startCellWeights)});

        // The slices are converted once up front when they will be used repeatedly (several time steps)
        // or handed on to backprop.
        if (timeSeriesLength > 1 || forBackprop) {
            inWeightsI = Shape.toMmulCompatible(inWeightsI);
            recWeightsI = Shape.toMmulCompatible(recWeightsI);
            inWeightsF = Shape.toMmulCompatible(inWeightsF);
            recWeightsF = Shape.toMmulCompatible(recWeightsF);
            cellWeightsF = Shape.toMmulCompatible(cellWeightsF);
            inWeightsO = Shape.toMmulCompatible(inWeightsO);
            recWeightsO = Shape.toMmulCompatible(recWeightsO);
            cellWeightsO = Shape.toMmulCompatible(cellWeightsO);
            inWeightsG = Shape.toMmulCompatible(inWeightsG);
            recWeightsG = Shape.toMmulCompatible(recWeightsG);
            cellWeightsG = Shape.toMmulCompatible(cellWeightsG);
            biasI = Shape.toMmulCompatible(biasI);
            biasF = Shape.toMmulCompatible(biasF);
            biasO = Shape.toMmulCompatible(biasO);
            biasG = Shape.toMmulCompatible(biasG);
        }

        INDArray outputActivations = null;
        final FwdPassReturn result = new FwdPassReturn();
        if (forBackprop) {
            // The index order of this array is relied upon by the backward pass.
            result.paramsMmulCompatible = new INDArray[]{
                    inWeightsI, recWeightsI,
                    inWeightsF, recWeightsF, cellWeightsF,
                    inWeightsO, recWeightsO, cellWeightsO,
                    inWeightsG, recWeightsG, cellWeightsG};
            result.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
            result.memCellState = new INDArray[timeSeriesLength];
            result.memCellActivations = new INDArray[timeSeriesLength];
            result.iz = new INDArray[timeSeriesLength];
            result.ia = new INDArray[timeSeriesLength];
            result.fa = new INDArray[timeSeriesLength];
            result.oa = new INDArray[timeSeriesLength];
            result.ga = new INDArray[timeSeriesLength];
        } else {
            outputActivations = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize, timeSeriesLength});
            result.fwdPassOutput = outputActivations;
        }

        final Level1 l1Blas = Nd4j.getBlasWrapper().level1();

        // Without a supplied state the sequence starts from zeros.
        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize});
        }
        if (prevMemCellState == null) {
            prevMemCellState = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize});
        }

        for (int t = 0; t < timeSeriesLength; t++) {
            final INDArray inputSlice = Shape.toMmulCompatible(
                    is2dInput ? this.input : this.input.tensorAlongDimension(t, new int[]{1, 0}));

            // Block input; a copy of the pre-activation is kept in iz before the activation runs in place.
            final INDArray inputActivations = inputSlice.mmul(inWeightsI);
            Nd4j.gemm(prevOutputActivations, recWeightsI, inputActivations, false, false, 1.0d, 1.0d);
            inputActivations.addiRowVector(biasI);
            if (forBackprop) {
                result.iz[t] = inputActivations.dup('f');
            }
            Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), inputActivations));
            if (forBackprop) {
                result.ia[t] = inputActivations;
            }

            // fa gate: sees the previous memory cell state through its element-wise weights.
            final INDArray forgetGate = inputSlice.mmul(inWeightsF);
            Nd4j.gemm(prevOutputActivations, recWeightsF, forgetGate, false, false, 1.0d, 1.0d);
            final INDArray cellTermF = prevMemCellState.dup('f').muliRowVector(cellWeightsF);
            l1Blas.axpy(cellTermF.length(), 1.0d, cellTermF, forgetGate);
            forgetGate.addiRowVector(biasF);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", forgetGate));
            if (forBackprop) {
                result.fa[t] = forgetGate;
            }

            // ga gate: also sees the previous memory cell state.
            final INDArray inputModGate = inputSlice.mmul(inWeightsG);
            Nd4j.gemm(prevOutputActivations, recWeightsG, inputModGate, false, false, 1.0d, 1.0d);
            final INDArray cellTermG = prevMemCellState.dup('f').muliRowVector(cellWeightsG);
            l1Blas.axpy(cellTermG.length(), 1.0d, cellTermG, inputModGate);
            inputModGate.addiRowVector(biasG);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", inputModGate));
            if (forBackprop) {
                result.ga[t] = inputModGate;
            }

            // New memory cell state: fa (.) previous state + ga (.) block input.
            final INDArray currMemCellState = forgetGate.dup('f').muli(prevMemCellState);
            l1Blas.axpy(currMemCellState.length(), 1.0d, inputModGate.dup('f').muli(inputActivations), currMemCellState);

            // oa gate: unlike the other two gates it sees the NEW memory cell state.
            final INDArray outputGate = inputSlice.mmul(inWeightsO);
            Nd4j.gemm(prevOutputActivations, recWeightsO, outputGate, false, false, 1.0d, 1.0d);
            final INDArray cellTermO = currMemCellState.dup('f').muliRowVector(cellWeightsO);
            l1Blas.axpy(cellTermO.length(), 1.0d, cellTermO, outputGate);
            outputGate.addiRowVector(biasO);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", outputGate));
            if (forBackprop) {
                result.oa[t] = outputGate;
            }

            // Step output: activated memory cell state scaled by the oa gate.
            final INDArray activatedCellState = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), currMemCellState.dup('f')));
            final INDArray currOutput = activatedCellState.dup('f').muli(outputGate);

            if (forBackprop) {
                result.fwdPassOutputAsArrays[t] = currOutput;
                result.memCellState[t] = currMemCellState;
                result.memCellActivations[t] = activatedCellState;
            } else {
                outputActivations.tensorAlongDimension(t, new int[]{1, 0}).assign(currOutput);
            }

            // Carry the state into the next step and remember it as the most recent one.
            prevOutputActivations = currOutput;
            prevMemCellState = currMemCellState;
            result.lastAct = currOutput;
            result.lastMemCell = currMemCellState;
        }
        return result;
    }

    /**
     * Returns the result of {@link #activate()}; no averaging is performed by this layer.
     *
     * @return the output of a non-training forward pass over the stored input
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activationMean() {
        final INDArray activations = activate();
        return activations;
    }

    /**
     * Identifies this layer as a recurrent layer.
     *
     * @return always {@code Layer.Type.RECURRENT}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * L2 penalty over the "RW" and "W" parameters: {@code 0.5 * l2 * (sum(RW^2) + sum(W^2))}.
     * The bias parameter "b" is not included.
     *
     * @return the penalty, or 0 when regularization is switched off or the L2 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL2() {
        if (!this.conf.isUseRegularization()) {
            return 0.0d;
        }
        if (this.conf.getL2() <= 0.0d) {
            return 0.0d;
        }
        final double halfL2 = 0.5d * this.conf.getL2();
        // Reducing over Integer.MAX_VALUE presumably sums every element; the scalar is read at index 0.
        final double sumSqRecurrent = Transforms.pow(getParam("RW"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        final double sumSqInput = Transforms.pow(getParam("W"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return halfL2 * (sumSqRecurrent + sumSqInput);
    }

    /**
     * L1 penalty over the "RW" and "W" parameters: {@code l1 * (sum(|RW|) + sum(|W|))}.
     * The bias parameter "b" is not included.
     *
     * @return the penalty, or 0 when regularization is switched off or the L1 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL1() {
        if (!this.conf.isUseRegularization()) {
            return 0.0d;
        }
        if (this.conf.getL1() <= 0.0d) {
            return 0.0d;
        }
        final double l1 = this.conf.getL1();
        // Reducing over Integer.MAX_VALUE presumably sums every element; the scalar is read at index 0.
        final double sumAbsRecurrent = Transforms.abs(getParam("RW")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        final double sumAbsInput = Transforms.abs(getParam("W")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return l1 * (sumAbsRecurrent + sumAbsInput);
    }

    /**
     * Runs a non-training forward pass that continues from the state kept in {@code stateMap}, then
     * replaces that state with the state reached at the end of this call.
     *
     * @param iNDArray input for the next time step(s)
     * @return the output for the supplied time step(s)
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer
    public INDArray rnnTimeStep(INDArray iNDArray) {
        setInput(iNDArray);
        final FwdPassReturn forward = activateHelper(
                false,
                this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                this.stateMap.get(STATE_KEY_PREV_MEMCELL),
                false);
        final INDArray output = forward.fwdPassOutput;
        // Remember where this call ended so the next call can pick up from here.
        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, forward.lastAct);
        this.stateMap.put(STATE_KEY_PREV_MEMCELL, forward.lastMemCell);
        return output;
    }

    /**
     * Runs a forward pass that continues from the state kept in {@code stateMap}. Unlike
     * {@link #rnnTimeStep(INDArray)}, {@code stateMap} itself is left unchanged; the final state is
     * optionally written to {@code tBpttStateMap} instead.
     *
     * @param iNDArray layer input
     * @param z        training flag forwarded to the forward pass
     * @param z2       true to store the final output and memory cell state in {@code tBpttStateMap}
     * @return the output for all time steps
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer
    public INDArray rnnActivateUsingStoredState(INDArray iNDArray, boolean z, boolean z2) {
        setInput(iNDArray);
        final FwdPassReturn forward = activateHelper(
                z,
                this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                this.stateMap.get(STATE_KEY_PREV_MEMCELL),
                false);
        final INDArray output = forward.fwdPassOutput;
        if (!z2) {
            return output;
        }
        this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, forward.lastAct);
        this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, forward.lastMemCell);
        return output;
    }
}
