package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.util.Dropout;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.class */
/**
 * Output layer that accepts and produces rank-3 (time-series) arrays.
 *
 * <p>Internally every rank-3 array is flattened to rank 2 so that the dense output-layer
 * logic inherited from {@link BaseOutputLayer} can be reused. Flattening folds dimensions 0
 * and 2 into the rows of the 2d array, and dimension 1 becomes its columns. The inverse
 * helper undoes this. (Presumably the 3d layout is [minibatch, size, timeSteps] — confirm
 * against the network's conventions.)
 *
 * <p>Several methods temporarily replace {@code this.input} with its 2d form while delegating
 * to the superclass. The original 3d input is always restored, even if the delegate throws.
 */
public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.RnnOutputLayer> {

    public RnnOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public RnnOutputLayer(NeuralNetConfiguration conf, INDArray inputArray) {
        super(conf, inputArray);
    }

    /**
     * Flattens a rank-3 array of shape {@code [d0, d1, d2]} into a rank-2 array with
     * {@code d0 * d2} rows and {@code d1} columns.
     *
     * @param in rank-3 array to flatten
     * @return the flattened array; it may be a view that shares data with {@code in}
     * @throws IllegalArgumentException if {@code in} is not rank 3
     */
    private INDArray reshape3dTo2d(INDArray in) {
        if (in.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3");
        }
        int[] shape = in.shape();
        if (shape[0] == 1) {
            // d0 == 1: take the single [d1, d2] slice and transpose it to [d2, d1].
            return in.tensorAlongDimension(0, new int[]{1, 2}).permutei(new int[]{1, 0});
        }
        if (shape[2] == 1) {
            // d2 == 1: a single slice already holds all rows.
            // NOTE(review): relies on ND4J returning this slice as [d0, d1]; confirm.
            return in.tensorAlongDimension(0, new int[]{1, 0});
        }
        // Place dimension 2 next to dimension 0 so that the 'f'-order reshape folds them together.
        return in.permute(new int[]{0, 2, 1}).reshape('f', shape[0] * shape[2], shape[1]);
    }

    /**
     * Inverse of {@link #reshape3dTo2d}: turns a {@code [dim0 * d2, d1]} array back into
     * {@code [dim0, d1, d2]}.
     *
     * @param in   rank-2 array to expand
     * @param dim0 size of dimension 0 of the resulting 3d array; must evenly divide the row count
     * @return the rank-3 array
     * @throws IllegalArgumentException if {@code in} is not rank 2, or if {@code dim0} is not a
     *                                  positive divisor of {@code in}'s row count
     */
    private INDArray reshape2dTo3d(INDArray in, int dim0) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        int[] shape = in.shape();
        // Without this check, 0 would throw ArithmeticException and a non-divisor would fail
        // inside reshape with an unhelpful message.
        if (dim0 <= 0 || shape[0] % dim0 != 0) {
            throw new IllegalArgumentException("Cannot reshape 2d array with " + shape[0]
                    + " rows into a 3d array with size(0) = " + dim0);
        }
        INDArray fOrdered = in;
        if (fOrdered.ordering() != 'f') {
            // The 'f'-order reshape below must see 'f'-ordered data, matching reshape3dTo2d.
            fOrdered = Shape.toOffsetZeroCopy(fOrdered, 'f');
        }
        return fOrdered.reshape('f', new int[]{dim0, shape[0] / dim0, shape[1]})
                .permute(new int[]{0, 2, 1});
    }

    /**
     * Runs backpropagation on the 2d view of the 3d input and reshapes the returned epsilon
     * back to 3d.
     *
     * @param epsilon passed unchanged to the superclass implementation
     * @return the gradient and the 3d epsilon for the layer below
     * @throws UnsupportedOperationException if the current input is not rank 3
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input is not rank 3");
        }
        INDArray input3d = this.input;
        Pair<Gradient, INDArray> gradAndEpsilon;
        this.input = reshape3dTo2d(input3d);
        try {
            gradAndEpsilon = super.backpropGradient(epsilon);
        } finally {
            // Restore the 3d input even if the superclass throws.
            this.input = input3d;
        }
        INDArray epsilon3d = reshape2dTo3d(gradAndEpsilon.getSecond(), input3d.size(0));
        return new Pair<>(gradAndEpsilon.getFirst(), epsilon3d);
    }

    /**
     * Computes the F1 score. Any rank-3 argument is flattened to rank 2 first.
     * (Presumably the arguments are examples and labels, as in {@link BaseOutputLayer}.)
     */
    @Override
    public double f1Score(INDArray data, INDArray labelData) {
        INDArray data2d = data.rank() == 3 ? reshape3dTo2d(data) : data;
        INDArray labels2d = labelData.rank() == 3 ? reshape3dTo2d(labelData) : labelData;
        return super.f1Score(data2d, labels2d);
    }

    /** Returns the current input array as stored on this layer. */
    @Override
    public INDArray getInput() {
        return this.input;
    }

    /** Reports this layer as {@link Layer.Type#RECURRENT}. */
    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Sets {@code inputArray} as the layer input and returns the pre-activation output as a
     * 3d array.
     *
     * @param inputArray new layer input
     * @param training   training-mode flag, passed to {@link #preOutput2d(boolean)}
     * @return the pre-activation output, reshaped to 3d using the input's {@code size(0)}
     */
    @Override
    public INDArray preOutput(INDArray inputArray, boolean training) {
        setInput(inputArray);
        return reshape2dTo3d(preOutput2d(training), this.input.size(0));
    }

    /**
     * Computes the pre-activation output in 2d form.
     *
     * <p>A rank-3 input is flattened for the duration of the superclass call and restored
     * afterwards. Any other rank is passed through unchanged.
     *
     * @param training training-mode flag, passed to the superclass
     * @return the 2d pre-activation output
     */
    @Override
    protected INDArray preOutput2d(boolean training) {
        if (this.input.rank() != 3) {
            return super.preOutput(this.input, training);
        }
        INDArray input3d = this.input;
        this.input = reshape3dTo2d(input3d);
        try {
            return super.preOutput(this.input, training);
        } finally {
            this.input = input3d;
        }
    }

    /** Returns the labels in 2d form, flattening them if they are rank 3. */
    @Override
    protected INDArray getLabels2d() {
        return this.labels.rank() == 3 ? reshape3dTo2d(this.labels) : this.labels;
    }

    /**
     * Sets {@code inputArray} as the layer input and computes the inference-mode output.
     *
     * @param inputArray rank-3 input
     * @return the 3d output
     * @throws IllegalArgumentException if {@code inputArray} is not rank 3
     */
    @Override
    public INDArray output(INDArray inputArray) {
        if (inputArray.rank() != 3) {
            throw new IllegalArgumentException("Input must be rank 3 (is: " + inputArray.rank() + ")");
        }
        setInput(inputArray);
        return output(false);
    }

    /**
     * Computes the layer output for the current 3d input.
     *
     * <p>Softmax activations are applied directly to the 2d pre-output. Any other activation
     * goes through the superclass {@code activate}. In both paths the mask (if set) is
     * multiplied into each row before the result is reshaped back to 3d.
     *
     * @param training training-mode flag; it is forwarded to the superclass activation rather
     *                 than hard-coded, so inference does not take the training path
     * @return the 3d output
     * @throws IllegalArgumentException if the current input is not rank 3
     */
    @Override
    public INDArray output(boolean training) {
        if (this.input.rank() != 3) {
            throw new IllegalArgumentException("input must be rank 3");
        }
        INDArray preOut2d = preOutput2d(training);
        if (this.conf.getLayer().getActivationFn() instanceof ActivationSoftmax) {
            INDArray softmax2d = Nd4j.getExecutioner().execAndReturn(new SoftMax(preOut2d));
            if (this.maskArray != null) {
                softmax2d.muliColumnVector(this.maskArray);
            }
            return reshape2dTo3d(softmax2d, this.input.size(0));
        }
        if (training) {
            applyDropOutIfNecessary(training);
        }
        INDArray input3d = this.input;
        INDArray out2d;
        this.input = reshape3dTo2d(input3d);
        try {
            out2d = super.activate(training);
        } finally {
            this.input = input3d;
        }
        if (this.maskArray != null) {
            out2d.muliColumnVector(this.maskArray);
        }
        return reshape2dTo3d(out2d, input3d.size(0));
    }

    /**
     * Computes {@code activationFn(input2d * W + b)} on the flattened input and returns the
     * result in 3d form.
     *
     * <p>When drop-connect is enabled and {@code training} is true, the weights are passed
     * through {@link Dropout#applyDropConnect} first.
     *
     * @param training training-mode flag
     * @return the 3d activations, multiplied row-wise by the mask if one is set
     * @throws UnsupportedOperationException if the current input is not rank 3
     */
    @Override
    public INDArray activate(boolean training) {
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input must be rank 3");
        }
        INDArray bias = getParam("b");
        INDArray weights = getParam("W");
        if (this.conf.isUseDropConnect() && training) {
            weights = Dropout.applyDropConnect(this, "W");
        }
        INDArray preActivation = reshape3dTo2d(this.input).mmul(weights).addiRowVector(bias);
        INDArray activation2d = this.conf.getLayer().getActivationFn().getActivation(preActivation, training);
        if (this.maskArray != null) {
            activation2d.muliColumnVector(this.maskArray);
        }
        return reshape2dTo3d(activation2d, this.input.size(0));
    }

    /**
     * Stores the mask array.
     *
     * <p>A mask whose {@code size(1)} is not 1 is first converted to a column vector with
     * {@link TimeSeriesUtils#reshapeTimeSeriesMaskToVector}. That makes it usable with
     * {@code muliColumnVector} on the 2d activations. (Presumably such masks are
     * [minibatch, timeSteps] — confirm.)
     *
     * @param mask mask array, or {@code null} to clear the mask
     */
    @Override
    public void setMaskArray(INDArray mask) {
        INDArray stored = mask;
        if (stored != null && stored.size(1) != 1) {
            stored = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(stored);
        }
        this.maskArray = stored;
    }
}
