package org.deeplearning4j.models.featuredetectors.autoencoder;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/models/featuredetectors/autoencoder/AutoEncoder.class */
/**
 * Autoencoder layer built on {@link BasePretrainNetwork}.
 *
 * <p>Encoding applies the configured activation to {@code input * W + b} (hidden bias). Decoding
 * applies it to {@code hidden * W^T + vb} (visible bias). When the configured corruption level is
 * positive, {@link #getGradient()} trains on a randomly masked copy of the input, in the style of a
 * denoising autoencoder.
 */
public class AutoEncoder extends BasePretrainNetwork {
    private static final long serialVersionUID = -6445530486350763837L;

    /**
     * Creates an autoencoder from the given configuration.
     *
     * @param neuralNetConfiguration layer configuration (activation, corruption, sparsity, rng, ...)
     */
    public AutoEncoder(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates an autoencoder from the given configuration and initial input.
     *
     * @param neuralNetConfiguration layer configuration
     * @param iNDArray initial input passed through to {@link BasePretrainNetwork}
     */
    public AutoEncoder(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Returns a masked copy of {@code iNDArray}.
     *
     * <p>One mask value is drawn per entry with {@code MathUtils.binomial(rng, 1, 1 - d)}.
     * Presumably this is a single Bernoulli trial (TODO confirm), so each entry is kept with
     * probability {@code 1 - d} and zeroed otherwise. The input itself is not modified.
     *
     * @param iNDArray matrix to corrupt
     * @param d corruption level (expected probability of zeroing an entry)
     * @return a new matrix equal to {@code mask * iNDArray} element-wise
     */
    public INDArray getCorruptedInput(INDArray iNDArray, double d) {
        final int rows = iNDArray.rows();
        final int columns = iNDArray.columns();
        final double keepProbability = 1.0d - d;
        INDArray mask = Nd4j.zeros(rows, columns);
        for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
                mask.put(row, column,
                        Integer.valueOf(MathUtils.binomial(this.conf.getRng(), 1, keepProbability)));
            }
        }
        // The mask is a fresh array, so multiplying it in place cannot affect the caller's input.
        mask.muli(iNDArray);
        return mask;
    }

    /**
     * Autoencoders are deterministic. Both the "mean" and the "sample" are the encoding.
     *
     * @param iNDArray visible input
     * @return pair of (encoding, encoding)
     */
    @Override // org.deeplearning4j.nn.layers.BasePretrainNetwork
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray iNDArray) {
        INDArray hidden = encode(iNDArray);
        return new Pair<>(hidden, hidden);
    }

    /**
     * Autoencoders are deterministic. Both the "mean" and the "sample" are the reconstruction.
     *
     * @param iNDArray hidden representation
     * @return pair of (reconstruction, reconstruction)
     */
    @Override // org.deeplearning4j.nn.layers.BasePretrainNetwork
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray iNDArray) {
        INDArray visible = decode(iNDArray);
        return new Pair<>(visible, visible);
    }

    /**
     * Maps visible input to the hidden representation.
     *
     * @param iNDArray visible input, one example per row
     * @return {@code activation(input * W + b)}, or the concatenated-bias variant when
     *     {@code conf.isConcatBiases()} is set
     */
    public INDArray encode(INDArray iNDArray) {
        INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
        // "b" is the hidden-bias parameter key.
        INDArray hiddenBias = getParam("b");
        INDArray preActivation;
        if (this.conf.isConcatBiases()) {
            // NOTE(review): hstack(W, b^T) only lines up when b^T has as many rows as W.
            // Confirm this path against the actual parameter shapes.
            preActivation = iNDArray.mmul(Nd4j.hstack(new INDArray[]{weights, hiddenBias.transpose()}));
        } else {
            // mmul returns a new array, so adding the bias in place is safe.
            preActivation = iNDArray.mmul(weights).addiRowVector(hiddenBias);
        }
        return (INDArray) this.conf.getActivationFunction().apply(preActivation);
    }

    /**
     * Maps a hidden representation back to visible space.
     *
     * @param iNDArray hidden representation, one example per row
     * @return {@code activation(hidden * W^T + vb)}. When {@code conf.isConcatBiases()} is set,
     *     the result is instead {@code activation([hidden * W^T, 1])}, and the visible bias is not
     *     used.
     */
    public INDArray decode(INDArray iNDArray) {
        INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
        INDArray visibleBias = getParam(PretrainParamInitializer.VISIBLE_BIAS_KEY);
        INDArray projected = iNDArray.mmul(weights.transpose());
        if (this.conf.isConcatBiases()) {
            // NOTE(review): this appends a column of ones, so the output has one more column than
            // the visible input. getGradient() subtracts it from the input; confirm this is intended.
            return (INDArray) this.conf.getActivationFunction().apply(
                    Nd4j.hstack(new INDArray[]{projected, Nd4j.ones(projected.rows(), 1)}));
        }
        projected.addiRowVector(visibleBias);
        return (INDArray) this.conf.getActivationFunction().apply(projected);
    }

    /**
     * Reconstructs the input by encoding then decoding it.
     *
     * @param iNDArray visible input
     * @return {@code decode(encode(iNDArray))}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray transform(INDArray iNDArray) {
        return decode(encode(iNDArray));
    }

    /**
     * Computes the gradient for the current {@code input}.
     *
     * <p>If the corruption level is positive, the encoder sees a masked copy of the input. The
     * reconstruction error is always measured against the uncorrupted input.
     *
     * @return gradient built from (weight gradient, visible-bias gradient, hidden-bias gradient)
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Gradient getGradient() {
        INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
        double corruptionLevel = this.conf.getCorruptionLevel();
        INDArray encoderInput =
                corruptionLevel > 0.0d ? getCorruptedInput(this.input, corruptionLevel) : this.input;

        INDArray hidden = encode(encoderInput);
        INDArray visibleLoss = this.input.sub(decode(hidden));

        double sparsity = this.conf.getSparsity();
        // Both branches use NON-mutating ops (rsub/add). The previous in-place addi overwrote
        // `hidden`, which corrupted the weight gradient computed below.
        INDArray hiddenFactor =
                sparsity == 0.0d ? hidden.rsub(1) : hidden.add(Double.valueOf(-sparsity));
        // mmul returns a new array, so the in-place multiplies only touch this temporary.
        INDArray hiddenLoss = visibleLoss.mmul(weights).muli(hidden).muli(hiddenFactor);

        INDArray weightGradient = encoderInput.transpose().mmul(hiddenLoss)
                .addi(visibleLoss.transpose().mmul(hidden));
        return createGradient(weightGradient, visibleLoss.mean(0), hiddenLoss.mean(0));
    }
}
