package org.deeplearning4j.nn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.gradient.NeuralNetworkGradientListener;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.nn.learning.AdaGrad;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Abstract base for single-layer networks made of a visible layer of {@code nVisible} units, a
 * hidden layer of {@code nHidden} units, a weight matrix {@code W} ({@code nVisible x nHidden} when
 * generated by {@link #initWeights()}), and hidden/visible bias vectors.
 *
 * <p>Subclasses supply {@link #reconstruct(DoubleMatrix)}, {@link #lossFunction(Object[])} and
 * {@link #train(DoubleMatrix, double, Object[])}. Instances are persisted with Java serialization via
 * {@link #write(OutputStream)} / {@link #load(InputStream)}; {@code dist} and {@code optimizer} are
 * transient and therefore not persisted.
 */
public abstract class BaseNeuralNetwork implements NeuralNetwork, Persistable {
    private static final long serialVersionUID = -7074102204433996574L;
    /** Number of visible units (rows of a generated {@code W}). */
    public int nVisible;
    /** Number of hidden units (columns of a generated {@code W}). */
    protected int nHidden;
    protected DoubleMatrix W;
    protected DoubleMatrix hBias;
    protected DoubleMatrix vBias;
    protected RandomGenerator rng;
    protected DoubleMatrix input;
    protected double sparsity;
    protected double momentum;
    /** Distribution that initial / jostled weights are sampled from. */
    protected transient RealDistribution dist;
    protected double l2;
    protected transient NeuralNetworkOptimizer optimizer;
    protected int renderWeightsEveryNumEpochs;
    /** Explicit fan-in; a negative value means "use 1 / nVisible" (see {@link #fanIn()}). */
    protected double fanIn;
    protected boolean useRegularization;
    protected boolean useAdaGrad;
    protected boolean firstTimeThrough;
    protected boolean normalizeByInputRows;
    protected boolean applySparsity;
    protected List<NeuralNetworkGradientListener> gradientListeners;
    protected double dropOut;
    /** Drop-out mask produced by {@link #applyDropOutIfNecessary(DoubleMatrix)}. */
    protected DoubleMatrix doMask;
    protected NeuralNetwork.OptimizationAlgorithm optimizationAlgo;
    protected NeuralNetwork.LossFunction lossFunction;
    protected AdaGrad wAdaGrad;
    protected AdaGrad hBiasAdaGrad;
    protected AdaGrad vBiasAdaGrad;

    /**
     * Fluent builder that instantiates a concrete {@link BaseNeuralNetwork} subclass reflectively and
     * then copies the configured hyper-parameters onto it.
     *
     * @param <E> concrete network type produced by {@link #build()}
     */
    public static class Builder<E extends BaseNeuralNetwork> {
        private DoubleMatrix W;
        protected Class<? extends NeuralNetwork> clazz;
        private DoubleMatrix vBias;
        private DoubleMatrix hBias;
        private int numVisible;
        private int numHidden;
        private DoubleMatrix input;
        private RealDistribution dist;
        private E ret = null;
        private RandomGenerator gen = new MersenneTwister(123);
        private double sparsity = 0.01d;
        private double l2 = 0.01d;
        private double momentum = 0.5d;
        private int renderWeightsEveryNumEpochs = -1;
        private double fanIn = 0.1d;
        private boolean useRegularization = false;
        private boolean useAdaGrad = false;
        private boolean normalizeByInputRows = false;
        private double dropOut = 0.0d;
        private NeuralNetwork.LossFunction lossFunction = NeuralNetwork.LossFunction.RECONSTRUCTION_CROSSENTROPY;
        private NeuralNetwork.OptimizationAlgorithm optimizationAlgo = NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT;

        public Builder<E> withOptmizationAlgo(NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm) {
            this.optimizationAlgo = optimizationAlgorithm;
            return this;
        }

        public Builder<E> withLossFunction(NeuralNetwork.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this;
        }

        public Builder<E> withDropOut(double dropOut) {
            this.dropOut = dropOut;
            return this;
        }

        public Builder<E> normalizeByInputRows(boolean normalizeByInputRows) {
            this.normalizeByInputRows = normalizeByInputRows;
            return this;
        }

        public Builder<E> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder<E> withDistribution(RealDistribution dist) {
            this.dist = dist;
            return this;
        }

        public Builder<E> useRegularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        public Builder<E> fanIn(double fanIn) {
            this.fanIn = fanIn;
            return this;
        }

        public Builder<E> withL2(double l2) {
            this.l2 = l2;
            return this;
        }

        public Builder<E> renderWeights(int numEpochs) {
            this.renderWeightsEveryNumEpochs = numEpochs;
            return this;
        }

        /**
         * Instantiates {@code clazz} through its accessible no-arg constructor without applying any
         * of the builder's settings.
         *
         * @throws RuntimeException wrapping the reflective failure if instantiation is not possible
         */
        @SuppressWarnings("unchecked") // clazz is set via asType/withClazz to a subtype of E's bound
        public E buildEmpty() {
            try {
                return (E) this.clazz.newInstance();
            } catch (IllegalAccessException | InstantiationException e) {
                throw new RuntimeException(e);
            }
        }

        public Builder<E> withClazz(Class<? extends BaseNeuralNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        public Builder<E> withSparsity(double sparsity) {
            this.sparsity = sparsity;
            return this;
        }

        public Builder<E> withMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder<E> withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        public Builder<E> asType(Class<E> type) {
            this.clazz = type;
            return this;
        }

        public Builder<E> withWeights(DoubleMatrix weights) {
            this.W = weights;
            return this;
        }

        public Builder<E> withVisibleBias(DoubleMatrix visibleBias) {
            this.vBias = visibleBias;
            return this;
        }

        public Builder<E> withHBias(DoubleMatrix hiddenBias) {
            this.hBias = hiddenBias;
            return this;
        }

        public Builder<E> numberOfVisible(int numVisible) {
            this.numVisible = numVisible;
            return this;
        }

        public Builder<E> numHidden(int numHidden) {
            this.numHidden = numHidden;
            return this;
        }

        public Builder<E> withRandom(RandomGenerator gen) {
            this.gen = gen;
            return this;
        }

        /**
         * Builds the network.
         *
         * @return the configured network, or {@code null} when {@code clazz} declares no constructor
         *     whose first parameter accepts a {@link DoubleMatrix}
         */
        public E build() {
            return buildWithInput();
        }

        @SuppressWarnings("unchecked") // the constructor belongs to clazz, a subtype of E's bound
        private E buildWithInput() {
            for (Constructor<?> constructor : this.clazz.getDeclaredConstructors()) {
                constructor.setAccessible(true);
                Class<?>[] parameterTypes = constructor.getParameterTypes();
                // Pick the first constructor whose leading parameter accepts the input matrix and call it
                // with the (input, nVisible, nHidden, W, hBias, vBias, rng, fanIn, dist) argument layout.
                if (parameterTypes.length > 0 && parameterTypes[0].isAssignableFrom(DoubleMatrix.class)) {
                    try {
                        this.ret = (E) constructor.newInstance(this.input, this.numVisible, this.numHidden,
                                this.W, this.hBias, this.vBias, this.gen, this.fanIn, this.dist);
                        this.ret.sparsity = this.sparsity;
                        this.ret.normalizeByInputRows = this.normalizeByInputRows;
                        this.ret.renderWeightsEveryNumEpochs = this.renderWeightsEveryNumEpochs;
                        this.ret.l2 = this.l2;
                        this.ret.momentum = this.momentum;
                        this.ret.useRegularization = this.useRegularization;
                        this.ret.useAdaGrad = this.useAdaGrad;
                        this.ret.dropOut = this.dropOut;
                        this.ret.optimizationAlgo = this.optimizationAlgo;
                        this.ret.lossFunction = this.lossFunction;
                        return this.ret;
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            }
            return this.ret;
        }
    }

    /**
     * Creates an uninitialised network holding only default hyper-parameters. The weights and biases
     * are left null; callers (e.g. {@link #transpose()} / {@code m11clone()}) set them afterwards.
     */
    public BaseNeuralNetwork() {
        this.sparsity = 0.0d;
        this.momentum = 0.5d;
        // Give the default distribution a real generator: building it from the still-null rng field
        // produced a distribution whose sample() would throw.
        this.rng = new MersenneTwister(1234);
        this.dist = new NormalDistribution(this.rng, 0.0d, 0.01d, 1.0E-9d);
        this.l2 = 0.1d;
        this.renderWeightsEveryNumEpochs = -1;
        this.fanIn = -1.0d;
        this.useRegularization = false;
        this.useAdaGrad = false;
        this.firstTimeThrough = false;
        this.normalizeByInputRows = false;
        this.applySparsity = true;
        this.dropOut = 0.0d;
    }

    /** Same as the full constructor with a {@code null} input matrix. */
    public BaseNeuralNetwork(int nVisible, int nHidden, DoubleMatrix weights, DoubleMatrix hiddenBias,
            DoubleMatrix visibleBias, RandomGenerator randomGenerator, double fanIn, RealDistribution distribution) {
        this(null, nVisible, nHidden, weights, hiddenBias, visibleBias, randomGenerator, fanIn, distribution);
    }

    /**
     * Creates a network and fills in any missing parameters via {@link #initWeights()}.
     *
     * @param input training input; may be null
     * @param nVisible number of visible units (must be at least 1)
     * @param nHidden number of hidden units (must be at least 1)
     * @param weights weight matrix, or null to sample one from the distribution
     * @param hiddenBias hidden bias, or null for zeros
     * @param visibleBias visible bias, or null for zeros
     * @param randomGenerator generator to use; null means {@code new MersenneTwister(1234)}
     * @param fanIn explicit fan-in; negative means {@code 1 / nVisible}
     * @param distribution weight distribution; null means normal(0, 0.01) driven by the resolved generator
     * @throws IllegalStateException if {@code nVisible} or {@code nHidden} is less than 1
     */
    public BaseNeuralNetwork(DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix weights,
            DoubleMatrix hiddenBias, DoubleMatrix visibleBias, RandomGenerator randomGenerator, double fanIn,
            RealDistribution distribution) {
        this();
        this.input = input;
        this.nVisible = nVisible;
        this.nHidden = nHidden;
        this.fanIn = fanIn;
        // Resolve the generator BEFORE the distribution so the default distribution never wraps a
        // null generator (which made initWeights() throw when it had to sample W).
        this.rng = randomGenerator != null ? randomGenerator : new MersenneTwister(1234);
        this.dist = distribution != null
                ? distribution
                : new NormalDistribution(this.rng, 0.0d, 0.01d, 1.0E-9d);
        this.W = weights;
        this.hBias = hiddenBias;
        this.vBias = visibleBias;
        // initWeights() (re)creates all three AdaGrad instances from the final parameter shapes.
        initWeights();
    }

    /** Returns the L2 penalty {@code (sum(W^2) / 2) * l2}. */
    @Override
    public double l2RegularizedCoefficient() {
        return (MatrixFunctions.pow(getW(), 2.0d).sum() / 2.0d) * this.l2;
    }

    /**
     * Validates the layer sizes and fills in missing parameters: {@code W} is sampled row by row from
     * {@code dist} as an {@code nVisible x nHidden} matrix, and missing biases become zero vectors.
     * AdaGrad state is recreated for {@code W}, {@code hBias} and {@code vBias}.
     *
     * @throws IllegalStateException if {@code nVisible} or {@code nHidden} is less than 1
     */
    protected void initWeights() {
        if (this.nVisible < 1) {
            throw new IllegalStateException("Number of visible can not be less than 1");
        }
        if (this.nHidden < 1) {
            throw new IllegalStateException("Number of hidden can not be less than 1");
        }
        if (this.dist == null) {
            this.dist = new NormalDistribution(this.rng, 0.0d, 0.01d, 1.0E-9d);
        }
        if (this.W == null) {
            this.W = sampleWeightMatrix();
        }
        this.wAdaGrad = new AdaGrad(this.W.rows, this.W.columns);
        if (this.hBias == null) {
            this.hBias = DoubleMatrix.zeros(this.nHidden);
        }
        this.hBiasAdaGrad = new AdaGrad(this.hBias.rows, this.hBias.columns);
        if (this.vBias == null) {
            this.vBias = DoubleMatrix.zeros(this.nVisible);
        }
        this.vBiasAdaGrad = new AdaGrad(this.vBias.rows, this.vBias.columns);
    }

    /** Returns a new {@code nVisible x nHidden} matrix whose rows are sampled from {@code dist}. */
    private DoubleMatrix sampleWeightMatrix() {
        DoubleMatrix sampled = DoubleMatrix.zeros(this.nVisible, this.nHidden);
        for (int row = 0; row < sampled.rows; row++) {
            sampled.putRow(row, new DoubleMatrix(this.dist.sample(sampled.columns)));
        }
        return sampled;
    }

    /**
     * Replaces the weight AdaGrad with {@code new AdaGrad(W.rows, W.columns, lr)} unless
     * {@code firstTimeThrough} is set.
     *
     * <p>NOTE(review): nothing in this class ever sets {@code firstTimeThrough} to true, so unless a
     * subclass does, every call resets. Confirm whether a one-time reset was intended.
     *
     * @param lr value passed as the third {@link AdaGrad} constructor argument
     */
    @Override
    public void resetAdaGrad(double lr) {
        if (!this.firstTimeThrough) {
            this.wAdaGrad = new AdaGrad(getW().rows, getW().columns, lr);
        }
    }

    @Override
    public List<NeuralNetworkGradientListener> getGradientListeners() {
        return this.gradientListeners;
    }

    @Override
    public synchronized void setGradientListeners(List<NeuralNetworkGradientListener> gradientListeners) {
        this.gradientListeners = gradientListeners;
    }

    @Override
    public void setRenderEpochs(int renderEpochs) {
        this.renderWeightsEveryNumEpochs = renderEpochs;
    }

    @Override
    public int getRenderEpochs() {
        return this.renderWeightsEveryNumEpochs;
    }

    /** Returns the configured fan-in, or {@code 1.0 / nVisible} when none was set (negative value). */
    @Override
    public double fanIn() {
        // Floating-point division: the integer form 1 / nVisible was 0 for every nVisible > 1.
        return this.fanIn < 0.0d ? 1.0d / this.nVisible : this.fanIn;
    }

    @Override
    public void setFanIn(double fanIn) {
        this.fanIn = fanIn;
    }

    /**
     * Re-samples the weight matrix from {@code dist}, replacing {@code W} with a fresh
     * {@code nVisible x nHidden} matrix. Previously the sampled matrix was built and then discarded.
     */
    public void jostleWeighMatrix() {
        this.W = sampleWeightMatrix();
    }

    /**
     * Shrinks a hidden-bias gradient in place by {@code sparsity^2} times a step size: the AdaGrad
     * rates (when enabled) or {@code learningRate}.
     *
     * <p>NOTE(review): the AdaGrad branch derives its rates from {@code hBias} itself, not from the
     * gradient. Kept as-is; confirm intent.
     */
    protected void applySparsity(DoubleMatrix hBiasGradient, double learningRate) {
        if (this.useAdaGrad) {
            hBiasGradient.addi(this.hBiasAdaGrad.getLearningRates(this.hBias).neg().mul(this.sparsity)
                    .mul(hBiasGradient.mul(this.sparsity)));
        } else {
            hBiasGradient.addi(hBiasGradient.mul(this.sparsity).mul((-learningRate) * this.sparsity));
        }
    }

    /**
     * Applies step size, regularization, momentum, sparsity and optional row normalization to the
     * gradient's matrices IN PLACE, so the caller's {@code gradient} object carries the result.
     *
     * @param gradient gradient whose W / hBias / vBias matrices are modified in place
     * @param learningRate step size used when AdaGrad is disabled
     */
    public void updateGradientAccordingToParams(NeuralNetworkGradient gradient, double learningRate) {
        DoubleMatrix wGradient = gradient.getwGradient();
        DoubleMatrix hBiasGradient = gradient.gethBiasGradient();
        DoubleMatrix vBiasGradient = gradient.getvBiasGradient();

        if (this.useAdaGrad) {
            wGradient.muli(this.wAdaGrad.getLearningRates(wGradient));
        } else {
            wGradient.muli(learningRate);
        }
        if (this.useRegularization) {
            // mul (not muli): the penalty term must not rescale the live weight matrix.
            wGradient.subi(this.W.mul(this.l2));
        }
        if (this.momentum != 0.0d) {
            // NOTE(review): g*m + g*(1-m) == g, so this adds the gradient to itself (doubles it)
            // for any non-zero momentum. True momentum needs the previous update; kept as-is.
            wGradient.addi(wGradient.mul(this.momentum).add(wGradient.mul(1.0d - this.momentum)));
        }

        // The bias results used to land in discarded temporaries; scale the gradient's own matrices.
        scaleBiasGradient(hBiasGradient, this.hBiasAdaGrad, learningRate);
        scaleBiasGradient(vBiasGradient, this.vBiasAdaGrad, learningRate);
        if (this.applySparsity) {
            applySparsity(hBiasGradient, learningRate);
        }
        if (this.normalizeByInputRows) {
            wGradient.divi(this.input.rows);
            vBiasGradient.divi(this.input.rows);
            hBiasGradient.divi(this.input.rows);
        }
    }

    /** Rewrites {@code biasGradient} in place as {@code g .* rate + g * momentum}. */
    private void scaleBiasGradient(DoubleMatrix biasGradient, AdaGrad adaGrad, double learningRate) {
        if (this.useAdaGrad) {
            DoubleMatrix rates = adaGrad.getLearningRates(biasGradient);
            DoubleMatrix momentumTerm = biasGradient.mul(this.momentum);
            biasGradient.muli(rates).addi(momentumTerm);
        } else {
            DoubleMatrix momentumTerm = biasGradient.mul(this.momentum);
            biasGradient.muli(learningRate).addi(momentumTerm);
        }
    }

    /** Passes {@code gradient} to every registered listener; a null listener list is ignored. */
    public void triggerGradientEvents(NeuralNetworkGradient gradient) {
        if (this.gradientListeners == null) {
            return;
        }
        for (NeuralNetworkGradientListener listener : this.gradientListeners) {
            listener.onGradient(gradient);
        }
    }

    @Override
    public void setDropOut(double dropOut) {
        this.dropOut = dropOut;
    }

    @Override
    public double dropOut() {
        return this.dropOut;
    }

    @Override
    public AdaGrad getAdaGrad() {
        return this.wAdaGrad;
    }

    @Override
    public void setAdaGrad(AdaGrad adaGrad) {
        this.wAdaGrad = adaGrad;
    }

    /**
     * Returns a new network of the same class with visible and hidden layers swapped: {@code W} is
     * transposed and the bias vectors exchange roles so they match the transposed dimensions.
     */
    @Override
    public NeuralNetwork transpose() {
        BaseNeuralNetwork transposed = newEmptyInstance();
        // The old visible layer becomes the new hidden layer, so hBias must have W^T.columns entries.
        transposed.sethBias(this.vBias.dup());
        transposed.setvBias(this.hBias.dup());
        transposed.setnHidden(getnVisible());
        transposed.setnVisible(getnHidden());
        transposed.setW(this.W.transpose());
        transposed.setRng(getRng());
        // NOTE(review): this AdaGrad was sized for W, not W^T, and is shared rather than copied.
        transposed.setAdaGrad(this.wAdaGrad);
        transposed.setDist(getDist());
        return transposed;
    }

    /**
     * Copies this network: parameters are duplicated and hyper-parameters copied. The AdaGrad
     * instances, generator and distribution are shared, not copied. (Decompiler-renamed
     * {@code clone()}; the name is kept for compatibility.)
     */
    @Override
    public NeuralNetwork m11clone() {
        BaseNeuralNetwork copy = newEmptyInstance();
        copy.setHbiasAdaGrad(this.hBiasAdaGrad);
        copy.setVBiasAdaGrad(this.vBiasAdaGrad);
        copy.sethBias(this.hBias.dup());
        copy.setvBias(this.vBias.dup());
        copy.setnHidden(getnHidden());
        copy.setnVisible(getnVisible());
        copy.setW(this.W.dup());
        copy.setL2(this.l2);
        copy.setMomentum(this.momentum);
        copy.setRenderEpochs(getRenderEpochs());
        copy.setSparsity(this.sparsity);
        copy.setRng(getRng());
        copy.setDist(getDist());
        copy.setAdaGrad(this.wAdaGrad);
        copy.setLossFunction(this.lossFunction);
        copy.setOptimizationAlgorithm(this.optimizationAlgo);
        // Settings that were previously dropped, silently changing how the copy trains.
        copy.setFanIn(this.fanIn);
        copy.setDropOut(this.dropOut);
        copy.useAdaGrad = this.useAdaGrad;
        copy.useRegularization = this.useRegularization;
        copy.normalizeByInputRows = this.normalizeByInputRows;
        copy.applySparsity = this.applySparsity;
        return copy;
    }

    /**
     * Instantiates this object's runtime class through its declared no-arg constructor. The original
     * {@code getDeclaredConstructors()[0]} relied on an unspecified ordering.
     *
     * @throws RuntimeException wrapping the reflective failure (e.g. no no-arg constructor)
     */
    private BaseNeuralNetwork newEmptyInstance() {
        try {
            Constructor<? extends BaseNeuralNetwork> noArg = getClass().getDeclaredConstructor();
            noArg.setAccessible(true);
            return noArg.newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public NeuralNetwork.LossFunction getLossFunction() {
        return this.lossFunction;
    }

    @Override
    public void setLossFunction(NeuralNetwork.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    @Override
    public NeuralNetwork.OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.optimizationAlgo;
    }

    @Override
    public void setOptimizationAlgorithm(NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm) {
        this.optimizationAlgo = optimizationAlgorithm;
    }

    @Override
    public RealDistribution getDist() {
        return this.dist;
    }

    @Override
    public void setDist(RealDistribution dist) {
        this.dist = dist;
    }

    /**
     * Moves each of this network's parameters {@code 1/divisor} of the way toward {@code other}'s:
     * {@code p += (p_other - p) / divisor}. {@code other} is left untouched.
     */
    @Override
    public void merge(NeuralNetwork other, int divisor) {
        this.W.addi(other.getW().sub(this.W).divi(divisor));
        this.hBias.addi(other.gethBias().sub(this.hBias).divi(divisor));
        // sub (not subi): subi mutated, and then divided, the other network's visible bias in place.
        this.vBias.addi(other.getvBias().sub(this.vBias).divi(divisor));
    }

    /**
     * Copies parameters (by reference) and persisted hyper-parameters from {@code source} into this
     * network. The transient {@code dist}/{@code optimizer} and the input are not copied.
     */
    public void update(BaseNeuralNetwork source) {
        this.W = source.W;
        this.normalizeByInputRows = source.normalizeByInputRows;
        this.hBias = source.hBias;
        this.vBias = source.vBias;
        this.l2 = source.l2;
        this.useRegularization = source.useRegularization;
        this.momentum = source.momentum;
        this.nHidden = source.nHidden;
        this.nVisible = source.nVisible;
        this.rng = source.rng;
        this.sparsity = source.sparsity;
        this.wAdaGrad = source.wAdaGrad;
        this.hBiasAdaGrad = source.hBiasAdaGrad;
        this.vBiasAdaGrad = source.vBiasAdaGrad;
        this.optimizationAlgo = source.optimizationAlgo;
        this.lossFunction = source.lossFunction;
        // Previously dropped, so a loaded model silently lost these training settings.
        this.useAdaGrad = source.useAdaGrad;
        this.dropOut = source.dropOut;
        this.fanIn = source.fanIn;
        this.renderWeightsEveryNumEpochs = source.renderWeightsEveryNumEpochs;
        this.applySparsity = source.applySparsity;
    }

    /**
     * Reads a serialized {@link BaseNeuralNetwork} from {@code inputStream} and copies its state into
     * this instance via {@link #update(BaseNeuralNetwork)}. The stream is not closed.
     *
     * <p>SECURITY: Java deserialization can run arbitrary code; only load trusted streams.
     */
    @Override
    public void load(InputStream inputStream) {
        try {
            update((BaseNeuralNetwork) new ObjectInputStream(inputStream).readObject());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reconstruction cross-entropy of {@code input} (negated column-sum mean), plus the L2 penalty
     * when regularization is on, optionally divided by the number of input rows.
     */
    @Override
    public double negativeLogLikelihood() {
        double nll = reconstructionNll(this.input);
        if (this.normalizeByInputRows) {
            nll /= this.input.rows;
        }
        return nll;
    }

    /** Same as {@link #negativeLogLikelihood()} for {@code data}, without row normalization. */
    public double negativeLoglikelihood(DoubleMatrix data) {
        return reconstructionNll(data);
    }

    /**
     * Computes {@code -mean(colSums(x .* log z + (1 - x) .* log(1 - z)))} for {@code z = reconstruct(x)}.
     * When regularization is on it adds {@link #l2RegularizedCoefficient()}, i.e. {@code (l2/2)·ΣW²}.
     * The previous {@code (2/l2)·ΣW²} inverted the coefficient.
     */
    private double reconstructionNll(DoubleMatrix data) {
        DoubleMatrix z = reconstruct(data);
        double nll = -data.mul(MatrixUtil.log(z))
                .add(MatrixUtil.oneMinus(data).mul(MatrixUtil.log(MatrixUtil.oneMinus(z))))
                .columnSums().mean();
        if (this.useRegularization) {
            nll += l2RegularizedCoefficient();
        }
        return nll;
    }

    /**
     * Cross-entropy between {@code input} and its sigmoid reconstruction through {@code W}/{@code W^T},
     * as a negated row-sum mean, optionally divided by the number of input rows.
     *
     * <p>NOTE(review): with regularization the value is DIVIDED by {@code (elementCount + penalty)}
     * instead of having the penalty added, unlike {@link #negativeLogLikelihood()}. Kept; confirm intent.
     */
    @Override
    public double getReConstructionCrossEntropy() {
        DoubleMatrix hidden = MatrixUtil.sigmoid(this.input.mmul(this.W).addRowVector(this.hBias));
        DoubleMatrix z = MatrixUtil.sigmoid(hidden.mmul(this.W.transpose()).addRowVector(this.vBias));
        DoubleMatrix logLikelihood = this.input.mul(MatrixUtil.log(z))
                .add(MatrixUtil.oneMinus(this.input).mul(MatrixUtil.log(MatrixUtil.oneMinus(z))));
        double crossEntropy;
        if (this.useRegularization) {
            double elementCount = logLikelihood.length;
            crossEntropy = (-logLikelihood.rowSums().mean()) / (elementCount + l2RegularizedCoefficient());
        } else {
            crossEntropy = -logLikelihood.rowSums().mean();
        }
        if (this.normalizeByInputRows) {
            crossEntropy /= this.input.rows;
        }
        return crossEntropy;
    }

    @Override
    public boolean normalizeByInputRows() {
        return this.normalizeByInputRows;
    }

    @Override
    public int getnVisible() {
        return this.nVisible;
    }

    @Override
    public void setnVisible(int nVisible) {
        this.nVisible = nVisible;
    }

    @Override
    public int getnHidden() {
        return this.nHidden;
    }

    @Override
    public void setnHidden(int nHidden) {
        this.nHidden = nHidden;
    }

    @Override
    public DoubleMatrix getW() {
        return this.W;
    }

    @Override
    public void setW(DoubleMatrix w) {
        this.W = w;
    }

    @Override
    public DoubleMatrix gethBias() {
        return this.hBias;
    }

    @Override
    public void sethBias(DoubleMatrix hBias) {
        this.hBias = hBias;
    }

    @Override
    public DoubleMatrix getvBias() {
        return this.vBias;
    }

    @Override
    public void setvBias(DoubleMatrix vBias) {
        this.vBias = vBias;
    }

    @Override
    public RandomGenerator getRng() {
        return this.rng;
    }

    @Override
    public void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    @Override
    public DoubleMatrix getInput() {
        return this.input;
    }

    @Override
    public void setInput(DoubleMatrix input) {
        this.input = input;
    }

    @Override
    public double getSparsity() {
        return this.sparsity;
    }

    @Override
    public void setSparsity(double sparsity) {
        this.sparsity = sparsity;
    }

    @Override
    public double getMomentum() {
        return this.momentum;
    }

    @Override
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    @Override
    public double getL2() {
        return this.l2;
    }

    @Override
    public void setL2(double l2) {
        this.l2 = l2;
    }

    @Override
    public AdaGrad gethBiasAdaGrad() {
        return this.hBiasAdaGrad;
    }

    @Override
    public void setHbiasAdaGrad(AdaGrad adaGrad) {
        this.hBiasAdaGrad = adaGrad;
    }

    @Override
    public AdaGrad getVBiasAdaGrad() {
        return this.vBiasAdaGrad;
    }

    @Override
    public void setVBiasAdaGrad(AdaGrad adaGrad) {
        this.vBiasAdaGrad = adaGrad;
    }

    /**
     * Serializes this network to {@code outputStream} with Java serialization and flushes it. The
     * stream is not closed because the caller owns it.
     */
    @Override
    public void write(OutputStream outputStream) {
        try {
            ObjectOutputStream out = new ObjectOutputStream(outputStream);
            out.writeObject(this);
            out.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public abstract DoubleMatrix reconstruct(DoubleMatrix input);

    public abstract double lossFunction(Object[] params);

    /** Convenience overload: {@code lossFunction(null)}. */
    public double lossFunction() {
        return lossFunction(null);
    }

    /**
     * Sets {@code doMask} to a {@code rows x nHidden} 0/1 matrix: with drop-out enabled each entry is
     * 1 where a uniform draw exceeds {@code dropOut}; otherwise it is all ones.
     */
    public void applyDropOutIfNecessary(DoubleMatrix input) {
        if (this.dropOut > 0.0d) {
            this.doMask = DoubleMatrix.rand(input.rows, this.nHidden).gt(this.dropOut);
        } else {
            this.doMask = DoubleMatrix.ones(input.rows, this.nHidden);
        }
    }

    @Override
    public abstract void train(DoubleMatrix input, double learningRate, Object[] params);

    /**
     * Returns the NEGATED mean squared reconstruction error over input rows, plus
     * {@code 0.5 * l2 * ΣW²} (before negation) when regularization is on.
     */
    @Override
    public double squaredLoss() {
        double loss = MatrixFunctions.powi(reconstruct(this.input).sub(this.input), 2.0d).sum() / this.input.rows;
        if (this.useRegularization) {
            loss += 0.5d * this.l2 * MatrixFunctions.pow(this.W, 2.0d).sum();
        }
        return -loss;
    }

    /** Returns {@code input * W} with {@code hBias} added to every row. */
    @Override
    public DoubleMatrix hBiasMean() {
        return getInput().mmul(getW()).addRowVector(gethBias());
    }

    /**
     * Plots the network gradient every {@code getRenderEpochs()} epochs (epoch 0 included);
     * rendering is disabled when that value is not positive.
     */
    @Override
    public void epochDone(int epoch) {
        int renderEpochs = getRenderEpochs();
        if (renderEpochs <= 0) {
            return;
        }
        if (epoch % renderEpochs == 0) {
            new NeuralNetPlotter().plotNetworkGradient(this, getGradient(new Object[]{1, 0.001d, 1000}));
        }
    }
}
