package org.deeplearning4j.nn;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.nn.activation.Sigmoid;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/* loaded from: input_file:org/deeplearning4j/nn/RectifiedLinearHiddenLayer.class */
/**
 * Hidden layer whose stochastic sampling uses noisy rectified units.
 *
 * <p>The layer's activation is perturbed with noise and then clamped at zero; see
 * {@link #sampleHGivenV(DoubleMatrix)}. Everything else is inherited from {@link HiddenLayer}.
 */
public class RectifiedLinearHiddenLayer extends HiddenLayer {
    private static final long serialVersionUID = 2266162281744170946L;

    /**
     * Fluent builder for {@link RectifiedLinearHiddenLayer}.
     *
     * <p>Fields that are never set are handed to the layer constructor as {@code 0} / {@code null}.
     * The activation function defaults to {@link Sigmoid}.
     */
    public static class Builder {
        protected int nIn;
        protected int nOut;
        protected DoubleMatrix W;
        protected DoubleMatrix b;
        protected RandomGenerator rng;
        protected DoubleMatrix input;
        protected ActivationFunction activationFunction = new Sigmoid();
        protected RealDistribution dist;

        /** Sets the distribution handed to the layer. */
        public Builder dist(RealDistribution realDistribution) {
            this.dist = realDistribution;
            return this;
        }

        /** Sets the number of inputs to the layer. */
        public Builder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** Sets the number of outputs (hidden units) of the layer. */
        public Builder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /** Sets an explicit weight matrix for the layer. */
        public Builder withWeights(DoubleMatrix weights) {
            this.W = weights;
            return this;
        }

        /** Sets the random generator handed to the layer. */
        public Builder withRng(RandomGenerator randomGenerator) {
            this.rng = randomGenerator;
            return this;
        }

        /** Sets the activation function; replaces the {@link Sigmoid} default. */
        public Builder withActivation(ActivationFunction activationFunction) {
            this.activationFunction = activationFunction;
            return this;
        }

        /** Sets an explicit bias vector for the layer. */
        public Builder withBias(DoubleMatrix bias) {
            this.b = bias;
            return this;
        }

        /** Sets the initial input matrix for the layer. */
        public Builder withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        /**
         * Builds the layer from the configured values.
         *
         * @return a new layer configured from this builder
         */
        public RectifiedLinearHiddenLayer build() {
            // Hand activation and distribution to the constructor so HiddenLayer's constructor
            // sees them during construction. Previously the layer was constructed with both set
            // to null and they were assigned only afterwards. Anything the superclass constructor
            // derives from them (presumably weight initialisation when W is null -- TODO confirm
            // against HiddenLayer) therefore ignored the builder's values.
            RectifiedLinearHiddenLayer layer = new RectifiedLinearHiddenLayer(
                    nIn, nOut, W, b, rng, input, activationFunction, dist);
            // Keep the builder's activation authoritative, exactly as before.
            layer.activationFunction = activationFunction;
            // Only override when a distribution was supplied. An unset builder field must not
            // overwrite, with null, whatever the superclass constructor stored.
            if (dist != null) {
                layer.dist = dist;
            }
            return layer;
        }
    }

    /**
     * Creates a layer without running any initialisation in this class.
     * NOTE(review): presumably used for serialization/reflection -- confirm.
     */
    public RectifiedLinearHiddenLayer() {
    }

    /**
     * Creates a layer with an explicit activation function and no distribution.
     *
     * @param nIn                number of inputs
     * @param nOut               number of outputs
     * @param W                  weight matrix, may be {@code null}
     * @param b                  bias, may be {@code null}
     * @param rng                random generator, may be {@code null}
     * @param input              initial input, may be {@code null}
     * @param activationFunction activation function
     */
    public RectifiedLinearHiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng, DoubleMatrix input, ActivationFunction activationFunction) {
        this(nIn, nOut, W, b, rng, input, activationFunction, null);
    }

    /**
     * Creates a layer with neither an activation function nor a distribution.
     * Both are passed to {@link HiddenLayer} as {@code null}.
     *
     * @param nIn   number of inputs
     * @param nOut  number of outputs
     * @param W     weight matrix, may be {@code null}
     * @param b     bias, may be {@code null}
     * @param rng   random generator, may be {@code null}
     * @param input initial input, may be {@code null}
     */
    public RectifiedLinearHiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng, DoubleMatrix input) {
        this(nIn, nOut, W, b, rng, input, null, null);
    }

    /**
     * Creates a layer, forwarding every argument to the corresponding {@link HiddenLayer}
     * constructor.
     *
     * @param nIn                number of inputs
     * @param nOut               number of outputs
     * @param W                  weight matrix, may be {@code null}
     * @param b                  bias, may be {@code null}
     * @param rng                random generator, may be {@code null}
     * @param input              initial input, may be {@code null}
     * @param activationFunction activation function, may be {@code null}
     * @param dist               distribution, may be {@code null}
     */
    public RectifiedLinearHiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng, DoubleMatrix input, ActivationFunction activationFunction, RealDistribution dist) {
        super(nIn, nOut, W, b, rng, input, activationFunction, dist);
    }

    /**
     * Creates a layer with an explicit distribution, forwarding to the corresponding
     * {@link HiddenLayer} constructor.
     *
     * @param nIn   number of inputs
     * @param nOut  number of outputs
     * @param W     weight matrix, may be {@code null}
     * @param b     bias, may be {@code null}
     * @param rng   random generator, may be {@code null}
     * @param input initial input, may be {@code null}
     * @param dist  distribution
     */
    public RectifiedLinearHiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng, DoubleMatrix input, RealDistribution dist) {
        super(nIn, nOut, W, b, rng, input, dist);
    }

    /**
     * Samples hidden units for the given visible input using noisy rectification:
     * {@code max(0, a + noise)}.
     *
     * <p>Here {@code a} is this layer's {@link #activate()} output. The noise is
     * {@code MatrixUtil.normal(getRng(), a, 1.0)} scaled elementwise by {@code sqrt(sigmoid(a))}.
     * Side effect: {@code input} becomes this layer's current input.
     *
     * @param input visible input; replaces this layer's current input
     * @return the noisy activation, clamped elementwise at zero
     */
    @Override
    public DoubleMatrix sampleHGivenV(DoubleMatrix input) {
        this.input = input;
        DoubleMatrix activation = activate();
        // Derive the noise from the un-noised activation before it is mutated in place below.
        DoubleMatrix noise = MatrixUtil.normal(getRng(), activation, 1.0d)
                .mul(MatrixFunctions.sqrt(MatrixUtil.sigmoid(activation)));
        DoubleMatrix noisy = activation.addi(noise);
        // Rectify in place with jblas. The old code called MatrixUtil.max(0, m) and discarded
        // its return value, so it only clamped if that helper happens to mutate its argument.
        return noisy.maxi(0.0d);
    }

    /**
     * Samples hidden units for this layer's current input.
     *
     * @return the result of {@link #sampleHGivenV(DoubleMatrix)} on the current input
     */
    @Override
    public DoubleMatrix sampleHiddenGivenVisible() {
        return sampleHGivenV(this.input);
    }
}
