package org.deeplearning4j.nn.layers.recurrent;

import java.util.List;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.class */
/**
 * Output layer for recurrent networks that applies the configured activation function and loss
 * function directly to its rank-3 input.
 *
 * <p>Activations, labels and (optionally) the mask are reshaped from 3d time-series form to 2d via
 * {@link TimeSeriesUtils} before being handed to the configured loss function. This layer
 * contributes no parameter gradients (backprop returns an empty {@link DefaultGradient}) and its
 * own L1/L2 contributions are always zero.
 */
public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.RnnLossLayer> implements IOutputLayer {
    /** Labels for the current minibatch, set via {@link #setLabels(INDArray)}; may be null. */
    protected INDArray labels;

    public RnnLossLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Computes the gradient of the loss with respect to this layer's input.
     *
     * @param epsilon ignored; this layer is the source of the error signal
     * @return an empty gradient paired with the 3d epsilon for the previous layer
     * @throws IllegalStateException if input or labels are not set
     * @throws UnsupportedOperationException if the input is not rank 3
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (this.input == null) {
            throw new IllegalStateException("Cannot perform backprop: input is not set (null) " + layerId());
        }
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input is not rank 3. Got input with rank " + this.input.rank() + " " + layerId());
        }
        if (this.labels == null) {
            throw new IllegalStateException("Labels are not set (null)");
        }
        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(this.input);
        INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(this.labels);
        // A copy is passed, presumably so the loss/activation cannot modify this layer's input in place.
        INDArray delta2d = layerConf().getLossFn().computeGradient(labels2d, input2d.dup(input2d.ordering()),
                        layerConf().getActivationFn(), reshapeMaskForLoss());
        INDArray delta3d = TimeSeriesUtils.reshape2dTo3d(delta2d, this.input.size(0));
        return new Pair<>(new DefaultGradient(), delta3d);
    }

    /** This layer adds no L2 regularization term; always returns 0. */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        return 0.0;
    }

    /** This layer adds no L1 regularization term; always returns 0. */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        return 0.0;
    }

    /** Not implemented for {@link DataSet} input; always returns 0. */
    @Override
    public double f1Score(DataSet dataSet) {
        return 0.0;
    }

    /**
     * Runs inference on {@code examples} and returns the F1 score of the time-series predictions
     * against {@code exampleLabels}, honoring the current mask array if one is set.
     */
    @Override
    public double f1Score(INDArray examples, INDArray exampleLabels) {
        INDArray predictions = activate(examples, false);
        Evaluation evaluation = new Evaluation();
        evaluation.evalTimeSeries(exampleLabels, predictions, this.maskArray);
        return evaluation.f1();
    }

    /** Returns {@code labels.size(1)}; requires labels to have been set. */
    @Override
    public int numLabels() {
        return this.labels.size(1);
    }

    @Override
    public void fit(DataSetIterator dataSetIterator) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int[] predict(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public List<String> predict(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public INDArray labelProbabilities(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, INDArray exampleLabels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, int[] exampleLabels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /** The pre-output of this layer is its input, returned unchanged. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return x;
    }

    /**
     * Applies the configured activation function to a copy of the current input.
     *
     * @throws IllegalStateException if no input is set
     * @throws UnsupportedOperationException if the input is not rank 3
     */
    @Override
    public INDArray activate(boolean training) {
        if (this.input == null) {
            throw new IllegalStateException("Cannot perform forward pass: input is not set (null) " + layerId());
        }
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input must be rank 3. Got input with rank " + this.input.rank() + " " + layerId());
        }
        return layerConf().getActivationFn().getActivation(this.input.dup(this.input.ordering()), training);
    }

    @Override
    public void setMaskArray(INDArray maskArray) {
        this.maskArray = maskArray;
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * Stores the incoming time-series mask, flattened to a column vector, for use in scoring and
     * backprop. A null mask clears any stored mask.
     *
     * @return always null: no mask is propagated beyond this output layer
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        this.maskArray = maskArray == null ? null : TimeSeriesUtils.reshapeTimeSeriesMaskToVector(maskArray);
        this.maskState = currentMaskState;
        return null;
    }

    /**
     * Computes the score for the current minibatch, adds the network-wide L1/L2 terms and divides
     * by the minibatch size. The result is also stored in {@code this.score}.
     *
     * @throws IllegalStateException if input or labels are not set
     */
    @Override
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training) {
        requireInputAndLabelsForScore();
        INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(this.labels);
        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(this.input).dup();
        // NOTE(review): the trailing `false` presumably disables averaging inside the loss function,
        // since the division by minibatch size happens below -- confirm against ILossFunction.
        double lossScore = layerConf().getLossFn().computeScore(labels2d, input2d, layerConf().getActivationFn(),
                        reshapeMaskForLoss(), false);
        double score = (lossScore + (fullNetworkL1 + fullNetworkL2)) / getInputMiniBatchSize();
        this.score = score;
        return score;
    }

    /**
     * Computes one score per example by summing the per-time-step loss scores, then adding the
     * network-wide L1/L2 terms (when non-zero) to every example.
     *
     * @throws IllegalStateException if input or labels are not set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2) {
        requireInputAndLabelsForScore();
        INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(this.labels);
        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(this.input);
        INDArray scoreArray = layerConf().getLossFn().computeScoreArray(labels2d, input2d,
                        layerConf().getActivationFn(), reshapeMaskForLoss());
        // Regroup the flat scores into one row per example (input.size(0) rows), then sum each row.
        INDArray perExample = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, this.input.size(0))
                        .sum(new int[] {1});
        double l1l2 = fullNetworkL1 + fullNetworkL2;
        if (l1l2 != 0.0) {
            perExample.addi(l1l2);
        }
        return perExample;
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    @Override
    public INDArray getLabels() {
        return this.labels;
    }

    /**
     * Reshapes the current mask array into the 2d form passed to the loss function: a rank-3 mask
     * goes through the per-output reshape, any other mask is flattened to a column vector.
     *
     * @return the reshaped mask, or null when no mask is set
     */
    private INDArray reshapeMaskForLoss() {
        if (this.maskArray == null) {
            return null;
        }
        if (this.maskArray.rank() == 3) {
            return TimeSeriesUtils.reshapePerOutputTimeSeriesMaskTo2d(this.maskArray);
        }
        return TimeSeriesUtils.reshapeTimeSeriesMaskToVector(this.maskArray);
    }

    /** Throws if either the input or the labels needed for scoring are missing. */
    private void requireInputAndLabelsForScore() {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
    }
}
