/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.util.Dropout;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;

public class RnnOutputLayer
extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.RnnOutputLayer> {
    public RnnOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public RnnOutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    private INDArray reshape3dTo2d(INDArray in) {
        if (in.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3");
        }
        int[] shape = in.shape();
        if (shape[0] == 1) {
            return in.tensorAlongDimension(0, new int[]{1, 2});
        }
        if (shape[2] == 1) {
            return in.tensorAlongDimension(0, new int[]{1, 0});
        }
        INDArray permuted = in.permute(new int[]{0, 2, 1});
        return permuted.reshape(shape[0] * shape[2], shape[1]);
    }

    private INDArray reshape2dTo3d(INDArray in, int miniBatchSize) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        int[] shape = in.shape();
        if (in.ordering() == 'f') {
            in = Shape.toOffsetZeroCopy((INDArray)in, (char)'c');
        }
        INDArray reshaped = in.reshape(new int[]{miniBatchSize, shape[0] / miniBatchSize, shape[1]});
        return reshaped.permute(new int[]{0, 2, 1});
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input is not rank 3");
        }
        INDArray inputTemp = this.input;
        this.input = this.reshape3dTo2d(this.input);
        Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon);
        this.input = inputTemp;
        INDArray epsilon2d = gradAndEpsilonNext.getSecond();
        INDArray epsilon3d = this.reshape2dTo3d(epsilon2d, this.input.size(0));
        return new Pair<Gradient, INDArray>(gradAndEpsilonNext.getFirst(), epsilon3d);
    }

    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        if (examples.rank() == 3) {
            examples = this.reshape3dTo2d(examples);
        }
        if (labels.rank() == 3) {
            labels = this.reshape3dTo2d(labels);
        }
        return super.f1Score(examples, labels);
    }

    @Override
    public INDArray getInput() {
        return this.input;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        this.setInput(x);
        return this.reshape2dTo3d(this.preOutput2d(training), this.input.size(0));
    }

    @Override
    protected INDArray preOutput2d(boolean training) {
        if (this.input.rank() == 3) {
            INDArray inputTemp = this.input;
            this.input = this.reshape3dTo2d(this.input);
            INDArray out = super.preOutput(this.input, training);
            this.input = inputTemp;
            return out;
        }
        INDArray out = super.preOutput(this.input, training);
        return out;
    }

    @Override
    protected INDArray output2d(INDArray input) {
        return this.reshape3dTo2d(this.output(input));
    }

    @Override
    protected INDArray getLabels2d() {
        if (this.labels.rank() == 3) {
            return this.reshape3dTo2d(this.labels);
        }
        return this.labels;
    }

    @Override
    public INDArray output(INDArray input) {
        if (input.rank() != 3) {
            throw new IllegalArgumentException("Input must be rank 3 (is: " + input.rank());
        }
        this.setInput(input);
        return this.output(false);
    }

    @Override
    public INDArray output(boolean training) {
        if (this.input.rank() != 3) {
            throw new IllegalArgumentException("input must be rank 3");
        }
        INDArray preOutput2d = this.preOutput2d(training);
        if (this.conf.getLayer().getActivationFunction().equals("softmax")) {
            SoftMax softMax = new SoftMax(preOutput2d);
            softMax.exec(new int[]{1});
            INDArray out2d = softMax.z();
            if (this.maskArray != null) {
                out2d.muliColumnVector(this.maskArray);
            }
            return this.reshape2dTo3d(out2d, this.input.size(0));
        }
        if (training) {
            this.applyDropOutIfNecessary(training);
        }
        INDArray origInput = this.input;
        this.input = this.reshape3dTo2d(this.input);
        INDArray out = super.activate(true);
        this.input = origInput;
        return this.reshape2dTo3d(out, this.input.size(0));
    }

    @Override
    public INDArray activate(boolean training) {
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input must be rank 3");
        }
        INDArray b = this.getParam("b");
        INDArray W = this.getParam("W");
        if (this.conf.isUseDropConnect() && training) {
            W = Dropout.applyDropConnect(this, "W");
        }
        INDArray input2d = this.reshape3dTo2d(this.input);
        INDArray act2d = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), input2d.mmul(W).addiRowVector(b)));
        return this.reshape2dTo3d(act2d, this.input.size(0));
    }

    @Override
    public void setMaskArray(INDArray maskArray) {
        if (maskArray != null && maskArray.size(1) != 1) {
            maskArray = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(maskArray);
        }
        this.maskArray = maskArray;
    }
}

