/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Abstract base for layers that can be pretrained on their own input.
 *
 * <p>Parameters are addressed by the keys {@code "W"} (weights), {@code "b"} (hidden bias)
 * and {@code "vb"} (visible bias). Subclasses supply the two sampling directions
 * (visible to hidden, hidden to visible) and the gradient via {@code getGradient()}.
 */
public abstract class BasePretrainNetwork extends BaseLayer {
    private static final long serialVersionUID = -7074102204433996574L;
    private static final Logger log = LoggerFactory.getLogger(BasePretrainNetwork.class);

    /**
     * Mask most recently produced by {@link #applyDropOutIfNecessary(INDArray)}.
     * It is a 1/0 mask when dropout is enabled and all ones otherwise.
     */
    protected INDArray doMask;

    /**
     * Creates the layer from its configuration.
     *
     * @param conf layer configuration
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from its configuration with an initial input.
     *
     * @param conf  layer configuration
     * @param input initial input handed to the superclass
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Applies the configured sparsity term to a hidden-bias gradient, in place.
     *
     * <p>Computes {@code change = g * sparsity * (-lr * sparsity)} and adds it to {@code g}.
     * Algebraically, each entry is scaled by {@code 1 - lr * sparsity^2}.
     * NOTE(review): this does not look at actual hidden activations. Confirm this is the
     * intended sparsity penalty rather than a target-activation term.
     *
     * @param hBiasGradient hidden-bias gradient; modified in place
     */
    protected void applySparsity(INDArray hBiasGradient) {
        INDArray change = hBiasGradient
                .mul((Number)this.conf.getSparsity())
                .mul((Number)(-this.conf.getLr() * this.conf.getSparsity()));
        hBiasGradient.addi(change);
    }

    /**
     * Returns the negated reconstruction loss on the current input.
     *
     * <p>For {@code RECONSTRUCTION_CROSSENTROPY}, the cross-entropy is computed from the
     * parameters {@code b}, {@code vb}, {@code W} and the configured activation function.
     * For every other loss function, the input is compared with {@code transform(input)}
     * under the configured loss, L2 coefficient and regularization flag.
     *
     * @return the negated loss value
     */
    @Override
    public double score() {
        if (this.conf.getLossFunction() == LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) {
            return -LossFunctions.reconEntropy(this.input,
                    (INDArray)this.getParam("b"),
                    (INDArray)this.getParam("vb"),
                    (INDArray)this.getParam("W"),
                    (ActivationFunction)this.conf.getActivationFunction());
        }
        return -LossFunctions.score(this.input,
                (LossFunctions.LossFunction)this.conf.getLossFunction(),
                (INDArray)this.transform(this.input),
                (double)this.conf.getL2(),
                (boolean)this.conf.isUseRegularization());
    }

    /**
     * Adds the flattened gradient to the flattened parameters and stores the result.
     *
     * <p>NOTE(review): {@code addi} mutates whatever {@code params()} returns. Whether that
     * is a view of the live parameters or a copy is not visible here; confirm.
     *
     * @param gradient gradient whose {@code gradient()} is added element-wise
     */
    @Override
    public void update(Gradient gradient) {
        this.setParams(this.params().addi(gradient.gradient()));
    }

    /**
     * Performs one training step on the given input.
     *
     * <p>Replaces the current input, computes the gradient and applies it.
     *
     * @param input new input; stored on this layer
     */
    @Override
    public void iterate(INDArray input) {
        this.input = input;
        Gradient gradient = this.getGradient();
        this.update(gradient);
    }

    /**
     * Packages the three parameter gradients under the parameter keys used by this layer.
     *
     * @param wGradient     gradient for {@code "W"}
     * @param vBiasGradient gradient for {@code "vb"}
     * @param hBiasGradient gradient for {@code "b"}
     * @return a gradient holding the three entries
     */
    protected Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient ret = new DefaultGradient();
        ret.gradientLookupTable().put("vb", vBiasGradient);
        ret.gradientLookupTable().put("b", hBiasGradient);
        ret.gradientLookupTable().put("W", wGradient);
        return ret;
    }

    /**
     * Builds a dropout mask, stores it in {@link #doMask}, and multiplies {@code input} by it
     * in place.
     *
     * <p>When the configured dropOut is positive, an entry is kept if its draw from
     * {@code Nd4j.rand} is greater than dropOut. NOTE(review): this presumes a uniform [0, 1)
     * draw; confirm. Otherwise the mask is all ones.
     *
     * @param input matrix to mask; modified in place
     */
    @Override
    protected void applyDropOutIfNecessary(INDArray input) {
        if (this.conf.getDropOut() > 0.0) {
            this.doMask = Nd4j.rand((int)input.rows(), (int)input.columns())
                    .gt((Number)this.conf.getDropOut());
        } else {
            this.doMask = Nd4j.ones((int)input.rows(), (int)input.columns());
        }
        input.muli(this.doMask);
    }

    /**
     * Trains this layer with a {@link Solver} built from its own configuration and listeners.
     */
    @Override
    public void fit() {
        Solver solver = new Solver.Builder()
                .model(this)
                .configure(this.conf())
                .listeners(this.conf.getListeners())
                .build();
        solver.optimize();
    }

    /**
     * Optionally appends a bias column of ones to the input.
     *
     * @param input input matrix
     * @return {@code input} with an extra column of ones when {@code concatBiases} is enabled,
     *         otherwise {@code input} itself
     */
    protected INDArray preProcessInput(INDArray input) {
        if (!this.conf.isConcatBiases()) {
            return input;
        }
        return Nd4j.hstack(input, Nd4j.ones((int)input.rows(), 1));
    }

    /**
     * Samples the hidden layer given a visible configuration.
     *
     * @param visible visible-unit values
     * @return a pair of arrays. NOTE(review): presumably (mean, sample); confirm in subclasses
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible);

    /**
     * Samples the visible layer given a hidden configuration.
     *
     * @param hidden hidden-unit values
     * @return a pair of arrays. NOTE(review): presumably (mean, sample); confirm in subclasses
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden);
}

