/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Recursive auto-encoder layer.
 *
 * <p>{@link #gradient()} walks the slices (rows) of the layer input. The first step
 * concatenates slices 0 and 1; every later step concatenates slice {@code i} with the encoding
 * produced by the previous step. Each combined input is encoded with {@link #activate(INDArray)},
 * reconstructed with {@link #decode(INDArray)}, and its reconstruction-error gradients are summed
 * over the whole pass. {@link #score()} reports the summed per-step score of the latest pass.
 */
public class RecursiveAutoEncoder
        extends BaseLayer<org.deeplearning4j.nn.conf.layers.RecursiveAutoEncoder> {
    /** Input of the current recursion step; after each step it holds that step's encoding. */
    private INDArray currInput = null;
    /** Combined input of the latest step; the reconstruction target used when scoring. */
    private INDArray allInput = null;
    /** Running gradient sums for the current {@link #gradient()} pass (reset on every call). */
    private INDArray visibleLoss = null;
    private INDArray hiddenLoss = null;
    private INDArray cLoss = null;
    private INDArray bLoss = null;
    /** Reconstruction produced by the latest recursion step. */
    private INDArray y = null;
    /** Sum of the per-step scores of the latest {@link #gradient()} pass. */
    double currScore = 0.0;

    public RecursiveAutoEncoder(NeuralNetConfiguration conf) {
        super(conf);
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURSIVE;
    }

    /** Returns the summed per-step score computed by the most recent {@link #gradient()} call. */
    @Override
    public double score() {
        return this.currScore;
    }

    /**
     * Scores the latest recursion step: {@code 0.5 * mean((y - allInput)^2)}.
     *
     * <p>The mean is taken with dimension {@code Integer.MAX_VALUE}, presumably ND4J's
     * "reduce over the whole array" sentinel (NOTE(review): confirm for the ND4J version in use).
     */
    @Deprecated
    private double scoreSnapShot() {
        INDArray diff = this.y.sub(this.allInput);
        return 0.5 * Transforms.pow(diff, 2).mean(Integer.MAX_VALUE).getDouble(0);
    }

    /**
     * Runs {@link #gradient()} and stores the score of its final recursion step in {@code score}.
     *
     * <p>NOTE(review): {@code score} receives only the last step's snapshot, whereas
     * {@link #score()} reports {@code currScore} (the sum over all steps). The {@link Gradient}
     * returned by {@link #gradient()} is discarded here. Confirm both are intended.
     */
    @Override
    public void computeGradientAndScore() {
        this.gradient();
        this.score = this.scoreSnapShot();
    }

    /**
     * Encodes {@code data} as {@code activation(data . W + c)}, with {@code c} added to every row.
     *
     * @param data input to encode (rows are combined with the row vector {@code c})
     * @return the encoded representation
     */
    @Override
    public INDArray activate(INDArray data) {
        INDArray w = this.getParam("W");
        // NOTE(review): the encoder bias is read from the key "U", which decode() also reads as
        // its weight matrix; confirm the parameter keys against the param initializer.
        INDArray c = this.getParam("U");
        INDArray preActivation = data.mmul(w).addiRowVector(c);
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), preActivation));
    }

    /**
     * Reconstructs an input from its encoding as {@code activation(input . (U + b))}, where the
     * row vector {@code b} is added to every row of {@code U}.
     *
     * <p>{@code addRowVector} (not {@code addiRowVector}) is used so the stored "U" parameter is
     * left untouched; the in-place form would add {@code b} into {@code U} on every call.
     * NOTE(review): a conventional affine decoder would be {@code activation(input . U + b)};
     * confirm which formula is intended.
     *
     * @param input encoding to decode
     * @return the reconstruction
     */
    public INDArray decode(INDArray input) {
        INDArray decoderWeights = this.getParam("U").addRowVector(this.getParam("b"));
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), input.mmul(decoderWeights)));
    }

    /** Intentionally a no-op for this layer. */
    @Override
    public void iterate(INDArray input) {
    }

    /**
     * Runs one recursive encode/decode pass over {@code input} and returns the summed gradients.
     *
     * <p>The first step reads {@code input.slice(0)} and {@code input.slice(1)}, so the input
     * presumably needs at least two slices. Every call starts from a clean state: accumulators and
     * the carried-over encoding from any previous call are discarded.
     *
     * @return gradient built from the summed hidden, visible, c and b contributions
     */
    @Override
    public Gradient gradient() {
        // Each call is an independent pass. Without this reset, gradients would be summed across
        // calls, and a stale currInput would make the first step skip row 1.
        this.currScore = 0.0;
        this.currInput = null;
        this.visibleLoss = null;
        this.hiddenLoss = null;
        this.cLoss = null;
        this.bLoss = null;

        INDArray w = this.getParam("W");
        for (int i = 0; i < this.input.rows(); ++i) {
            INDArray combined = this.currInput == null
                    ? Nd4j.concat(0, this.input.slice(i), this.input.slice(i + 1))
                    : Nd4j.concat(0, this.input.slice(i), this.currInput);
            if (i == 0) {
                // Slice 1 was consumed together with slice 0 in the first step.
                ++i;
            }
            this.currInput = combined;
            this.allInput = combined;

            INDArray encoded = this.activate(combined);
            this.y = this.decode(encoded);

            INDArray currVisibleLoss = combined.sub(this.y);
            // encoded * (1 - encoded) matches the sigmoid derivative.
            // NOTE(review): this assumes the configured activation is sigmoid.
            INDArray currHiddenLoss = currVisibleLoss.mmul(w).muli(encoded).muli(encoded.rsub(1));
            INDArray hiddenGradient = this.y.transpose().mmul(currHiddenLoss);
            INDArray visibleGradient = encoded.transpose().mmul(currVisibleLoss);

            this.visibleLoss = addToRunningSum(this.visibleLoss, visibleGradient);
            this.hiddenLoss = addToRunningSum(this.hiddenLoss, hiddenGradient);
            this.cLoss = addToRunningSum(this.cLoss, meanOverFirstDimension(currVisibleLoss));
            this.bLoss = addToRunningSum(this.bLoss, meanOverFirstDimension(currHiddenLoss));

            // The next step combines the following slice with this step's encoding.
            this.currInput = encoded;
            this.currScore += this.scoreSnapShot();
        }
        return this.createGradient(this.hiddenLoss, this.visibleLoss, this.cLoss, this.bLoss);
    }

    /** Returns {@code delta} if {@code sum} is null; otherwise adds {@code delta} into {@code sum} in place. */
    private static INDArray addToRunningSum(INDArray sum, INDArray delta) {
        if (sum == null) {
            return delta;
        }
        sum.addi(delta);
        return sum;
    }

    /** Mean over dimension 0 for matrices; non-matrix arrays are returned unchanged. */
    private static INDArray meanOverFirstDimension(INDArray loss) {
        return loss.isMatrix() ? loss.mean(0) : loss;
    }
}

