package org.deeplearning4j.da;

import java.io.Serializable;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.sda.DenoisingAutoEncoderOptimizer;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:org/deeplearning4j/da/DenoisingAutoEncoder.class */
/**
 * Single-layer denoising autoencoder built on {@link BaseNeuralNetwork}.
 *
 * <p>Encoding multiplies the input by {@code W}, adds {@code hBias} and applies a sigmoid;
 * decoding multiplies by the transpose of {@code W}, adds {@code vBias} and applies a sigmoid,
 * so the same weight matrix is shared by both directions. During training the input is first
 * multiplied element-wise by a randomly drawn mask (see {@link #getCorruptedInput}).
 */
public class DenoisingAutoEncoder extends BaseNeuralNetwork implements Serializable {
    private static final long serialVersionUID = -6445530486350763837L;

    /* loaded from: input_file:org/deeplearning4j/da/DenoisingAutoEncoder$Builder.class */
    /**
     * Builder that targets {@link DenoisingAutoEncoder}; all configuration behaviour comes from
     * {@link BaseNeuralNetwork.Builder}.
     */
    public static class Builder extends BaseNeuralNetwork.Builder<DenoisingAutoEncoder> {
        /** Registers {@code DenoisingAutoEncoder} as the class the inherited builder works with. */
        public Builder() {
            this.clazz = DenoisingAutoEncoder.class;
        }
    }

    /** Creates an instance without passing any arguments to the superclass. */
    public DenoisingAutoEncoder() {
    }

    /**
     * Forwards every argument, unchanged and in the same order, to the matching
     * {@link BaseNeuralNetwork} constructor. The meaning of each argument is defined there.
     */
    public DenoisingAutoEncoder(DoubleMatrix doubleMatrix, int i, int i2, DoubleMatrix doubleMatrix2, DoubleMatrix doubleMatrix3, DoubleMatrix doubleMatrix4, RandomGenerator randomGenerator, double d, RealDistribution realDistribution) {
        super(doubleMatrix, i, i2, doubleMatrix2, doubleMatrix3, doubleMatrix4, randomGenerator, d, realDistribution);
    }

    /**
     * Produces a noisy copy of the given matrix by multiplying it element-wise with a freshly
     * drawn mask of the same dimensions.
     *
     * @param doubleMatrix the matrix to corrupt; it is not modified
     * @param d the corruption level; each mask cell is drawn with
     *     {@code MathUtils.binomial(rng, 1, 1 - d)} (presumably a single trial that yields 1 with
     *     probability {@code 1 - d} - confirm against MathUtils)
     * @return a new matrix holding {@code mask .* doubleMatrix}
     */
    public DoubleMatrix getCorruptedInput(DoubleMatrix doubleMatrix, double d) {
        final int rowCount = doubleMatrix.rows;
        final int columnCount = doubleMatrix.columns;
        final double keepProbability = 1.0d - d;

        DoubleMatrix mask = DoubleMatrix.zeros(rowCount, columnCount);
        // Fill row by row: the visiting order decides which draw from rng lands in which cell.
        for (int row = 0; row < rowCount; row++) {
            for (int column = 0; column < columnCount; column++) {
                mask.put(row, column, MathUtils.binomial(this.rng, 1, keepProbability));
            }
        }
        return mask.mul(doubleMatrix);
    }

    /**
     * Computes the hidden activations for the given visible values.
     *
     * @param doubleMatrix the visible values
     * @return a pair whose two components are the same hidden-activation matrix instance
     */
    @Override // org.deeplearning4j.nn.NeuralNetwork
    public Pair<DoubleMatrix, DoubleMatrix> sampleHiddenGivenVisible(DoubleMatrix doubleMatrix) {
        DoubleMatrix activations = getHiddenValues(doubleMatrix);
        // No separate sampling step is performed: both slots carry the activations themselves.
        return new Pair<>(activations, activations);
    }

    /**
     * Computes the reconstruction for the given hidden values.
     *
     * @param doubleMatrix the hidden values
     * @return a pair whose two components are the same reconstruction matrix instance
     */
    @Override // org.deeplearning4j.nn.NeuralNetwork
    public Pair<DoubleMatrix, DoubleMatrix> sampleVisibleGivenHidden(DoubleMatrix doubleMatrix) {
        DoubleMatrix reconstruction = getReconstructedInput(doubleMatrix);
        return new Pair<>(reconstruction, reconstruction);
    }

    /**
     * Encodes the input: {@code sigmoid(doubleMatrix * W + hBias)}, then hands the result to the
     * inherited {@code applyDropOutIfNecessary} hook.
     *
     * @param doubleMatrix the visible values to encode
     * @return the hidden activations
     */
    public DoubleMatrix getHiddenValues(DoubleMatrix doubleMatrix) {
        DoubleMatrix preActivation = doubleMatrix.mmul(this.W).addRowVector(this.hBias);
        DoubleMatrix activations = MatrixUtil.sigmoid(preActivation);
        // Inherited hook; any effect it has must be in place, since nothing is returned here.
        applyDropOutIfNecessary(activations);
        return activations;
    }

    /**
     * Decodes hidden values: {@code sigmoid(doubleMatrix * W^T + vBias)}.
     *
     * @param doubleMatrix the hidden values to decode
     * @return the reconstructed visible values
     */
    public DoubleMatrix getReconstructedInput(DoubleMatrix doubleMatrix) {
        DoubleMatrix decoderWeights = this.W.transpose();
        DoubleMatrix preActivation = doubleMatrix.mmul(decoderWeights).addRowVector(this.vBias);
        return MatrixUtil.sigmoid(preActivation);
    }

    /**
     * Trains through a newly created {@link DenoisingAutoEncoderOptimizer}.
     *
     * @param doubleMatrix the training input; when {@code null} the current input is kept
     * @param d value passed to the optimizer as its second constructor argument
     *     (presumably the learning rate - confirm against the optimizer)
     * @param d2 value wrapped in a one-element parameter array for the optimizer
     *     (presumably the corruption level - confirm against the optimizer)
     */
    public void trainTillConvergence(DoubleMatrix doubleMatrix, double d, double d2) {
        runDenoisingOptimizer(doubleMatrix, d, new Object[]{d2});
    }

    /**
     * Performs one gradient step on the given input.
     *
     * @param doubleMatrix the training input; becomes the current input
     * @param d the step size handed to {@link #getGradient(Object[])}
     * @param d2 the corruption level handed to {@link #getGradient(Object[])}
     */
    public void train(DoubleMatrix doubleMatrix, double d, double d2) {
        applyDenoisingGradientStep(doubleMatrix, d, d2);
    }

    /**
     * Encodes and then decodes the given input, without applying any corruption.
     *
     * @param doubleMatrix the visible values
     * @return the reconstruction of {@code doubleMatrix}
     */
    @Override // org.deeplearning4j.nn.BaseNeuralNetwork
    public DoubleMatrix reconstruct(DoubleMatrix doubleMatrix) {
        DoubleMatrix encoded = getHiddenValues(doubleMatrix);
        return getReconstructedInput(encoded);
    }

    /**
     * Trains through a newly created {@link DenoisingAutoEncoderOptimizer}.
     *
     * @param doubleMatrix the training input; when {@code null} the current input is kept
     * @param d value passed to the optimizer as its second constructor argument
     * @param objArr extra parameters passed to the optimizer as-is
     */
    @Override // org.deeplearning4j.nn.NeuralNetwork
    public void trainTillConvergence(DoubleMatrix doubleMatrix, double d, Object[] objArr) {
        runDenoisingOptimizer(doubleMatrix, d, objArr);
    }

    /**
     * Returns the inherited negative log likelihood; {@code objArr} is not consulted.
     *
     * @param objArr ignored
     * @return the value of {@code negativeLogLikelihood()}
     */
    @Override // org.deeplearning4j.nn.BaseNeuralNetwork
    public double lossFunction(Object[] objArr) {
        return negativeLogLikelihood();
    }

    /**
     * Performs one gradient step on the given input.
     *
     * @param doubleMatrix the training input; becomes the current input
     * @param d the step size handed to {@link #getGradient(Object[])}
     * @param objArr parameters whose first element must be a {@code Double} corruption level
     */
    @Override // org.deeplearning4j.nn.BaseNeuralNetwork, org.deeplearning4j.nn.NeuralNetwork
    public void train(DoubleMatrix doubleMatrix, double d, Object[] objArr) {
        // The argument is unpacked before the helper runs, so a malformed array fails
        // before the current input is replaced.
        applyDenoisingGradientStep(doubleMatrix, d, ((Double) objArr[0]).doubleValue());
    }

    /**
     * Computes the gradient for the current input.
     *
     * <p>The input is corrupted, encoded and decoded; the difference between the uncorrupted
     * input and the reconstruction drives all three gradient components.
     *
     * @param objArr {@code objArr[0]} is the corruption level and {@code objArr[1]} the step
     *     size, both as {@code Double}
     * @return the gradient after the inherited event and parameter-update hooks have been applied
     */
    @Override // org.deeplearning4j.nn.NeuralNetwork
    public synchronized NeuralNetworkGradient getGradient(Object[] objArr) {
        final double corruptionLevel = ((Double) objArr[0]).doubleValue();
        final double stepSize = ((Double) objArr[1]).doubleValue();

        // Each AdaGrad instance is optional; only the ones present receive the step size.
        if (this.wAdaGrad != null) {
            this.wAdaGrad.setMasterStepSize(stepSize);
        }
        if (this.hBiasAdaGrad != null) {
            this.hBiasAdaGrad.setMasterStepSize(stepSize);
        }
        if (this.vBiasAdaGrad != null) {
            this.vBiasAdaGrad.setMasterStepSize(stepSize);
        }

        DoubleMatrix corrupted = getCorruptedInput(this.input, corruptionLevel);
        DoubleMatrix hidden = getHiddenValues(corrupted);
        // The error is measured against the clean input, not the corrupted one.
        DoubleMatrix visibleError = this.input.sub(getReconstructedInput(hidden));

        DoubleMatrix backProjected = visibleError.mmul(this.W).mul(hidden);
        DoubleMatrix hiddenDelta;
        if (this.sparsity == 0.0d) {
            hiddenDelta = backProjected.mul(MatrixUtil.oneMinus(hidden));
        } else {
            hiddenDelta = backProjected.mul(hidden.add(-this.sparsity));
        }

        DoubleMatrix encoderTerm = corrupted.transpose().mmul(hiddenDelta);
        DoubleMatrix decoderTerm = visibleError.transpose().mmul(hidden);
        DoubleMatrix combinedWeightTerm = encoderTerm.add(decoderTerm);
        DoubleMatrix visibleErrorMeans = visibleError.columnMeans();
        DoubleMatrix hiddenDeltaMeans = hiddenDelta.columnMeans();

        // Argument order presumably (weights, visible bias, hidden bias) - the
        // NeuralNetworkGradient constructor is not visible here.
        NeuralNetworkGradient gradient =
                new NeuralNetworkGradient(combinedWeightTerm, visibleErrorMeans, hiddenDeltaMeans);
        triggerGradientEvents(gradient);
        updateGradientAccordingToParams(gradient, stepSize);
        return gradient;
    }

    /** Shared body of both {@code trainTillConvergence} overloads. */
    private void runDenoisingOptimizer(DoubleMatrix trainingInput, double optimizerRate, Object[] optimizerParams) {
        if (trainingInput != null) {
            this.input = trainingInput;
        }
        this.optimizer = new DenoisingAutoEncoderOptimizer(this, optimizerRate, optimizerParams, this.optimizationAlgo, this.lossFunction);
        // NOTE(review): a null trainingInput is still forwarded here even though the current
        // input was kept above - confirm the optimizer tolerates that.
        this.optimizer.train(trainingInput);
    }

    /** Shared body of both {@code train} overloads: one gradient computation, added in place. */
    private void applyDenoisingGradientStep(DoubleMatrix trainingInput, double stepSize, double corruptionLevel) {
        this.input = trainingInput;
        NeuralNetworkGradient gradient = getGradient(new Object[]{corruptionLevel, stepSize});
        this.vBias.addi(gradient.getvBiasGradient());
        this.W.addi(gradient.getwGradient());
        this.hBias.addi(gradient.gethBiasGradient());
    }
}
