package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/MultiLayerNetworkOptimizer.class */
/**
 * Trains the output (logistic regression) layer of a {@link BaseMultiLayerNetwork} and exposes
 * that layer's parameters to gradient-based optimizers.
 *
 * <p>Both the MALLET {@link Optimizable.ByGradientValue} array API and the
 * {@link OptimizableByGradientValueMatrix} matrix API are supported.
 *
 * <p>The flattened parameter vector is laid out as every entry of the log layer's weight matrix
 * {@code W} (in {@link DoubleMatrix} linear-index order), followed by every entry of its bias
 * {@code b}. All index-based accessors in this class use that same layout.
 */
public class MultiLayerNetworkOptimizer implements Optimizable.ByGradientValue, Serializable, OptimizableByGradientValueMatrix {
    private static final long serialVersionUID = -3012638773299331828L;
    /** Network whose output layer is being optimized. */
    protected BaseMultiLayerNetwork network;
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetworkOptimizer.class);
    /** Learning rate passed to the log layer when computing its gradient. */
    private double lr;

    /**
     * @param baseMultiLayerNetwork network whose output layer will be optimized
     * @param d learning rate used by {@link #getValueGradient(double[])}
     */
    public MultiLayerNetworkOptimizer(BaseMultiLayerNetwork baseMultiLayerNetwork, double d) {
        this.network = baseMultiLayerNetwork;
        this.lr = d;
    }

    /**
     * Trains the network's output layer against {@code labels}, then optionally runs backprop.
     *
     * <p>The labels are first installed on the log layer. The log layer's input is the sample
     * taken from the last hidden layer (see {@link #sampleHiddenGivenVisible()}).
     *
     * <ul>
     *   <li>When the network does not force a fixed number of epochs, the log layer's
     *       {@code trainTillConvergence} is called once.</li>
     *   <li>Otherwise, {@code train} is called exactly {@code epochs} times.</li>
     * </ul>
     *
     * <p>In both cases {@code network.backProp(learningRate, epochs)} runs afterwards when
     * {@code network.isShouldBackProp()} is true.
     *
     * @param labels target labels for the output layer
     * @param learningRate learning rate forwarded to the log layer and to backprop
     * @param epochs epoch count forwarded to convergence training / backprop, or the exact number
     *     of {@code train} passes when the epoch count is forced
     */
    public void optimize(DoubleMatrix labels, double learningRate, int epochs) {
        this.network.getLogLayer().setLabels(labels);
        DoubleMatrix hiddenSample = sampleHiddenGivenVisible();

        if (this.network.isForceNumEpochs()) {
            log.info("Training for {} epochs", epochs);
            for (int epoch = 0; epoch < epochs; epoch++) {
                this.network.getLogLayer().train(hiddenSample, labels, learningRate);
            }
        } else {
            this.network.getLogLayer().trainTillConvergence(hiddenSample, labels, learningRate, epochs);
        }

        if (this.network.isShouldBackProp()) {
            this.network.backProp(learningRate, epochs);
        }
    }

    /**
     * Samples the activations of the network's last hidden layer.
     *
     * <p>The source of the sample depends on the network's configuration:
     * <ul>
     *   <li>If {@code isUseHiddenActivationsForwardProp()} is true, the sample comes from the last
     *       sigmoid layer.</li>
     *   <li>Otherwise, it is the second element of the pair returned by the last pretrain layer's
     *       {@code sampleHiddenGivenVisible(input)}, evaluated on that layer's own input.</li>
     * </ul>
     */
    private DoubleMatrix sampleHiddenGivenVisible() {
        int lastLayer = this.network.getnLayers() - 1;
        if (this.network.isUseHiddenActivationsForwardProp()) {
            return this.network.getSigmoidLayers()[lastLayer].sampleHiddenGivenVisible();
        }
        return this.network.getLayers()[lastLayer]
                .sampleHiddenGivenVisible(this.network.getLayers()[lastLayer].getInput())
                .getSecond();
    }

    /** Returns the number of weight entries plus the number of bias entries of the log layer. */
    @Override // org.deeplearning4j.optimize.OptimizableByGradientValueMatrix
    public int getNumParameters() {
        return this.network.getLogLayer().getW().length + this.network.getLogLayer().getB().length;
    }

    /**
     * Writes the flattened parameters (weights, then biases) into {@code buffer}.
     *
     * @param buffer destination array; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void getParameters(double[] buffer) {
        int next = copyInto(this.network.getLogLayer().getW(), buffer, 0);
        copyInto(this.network.getLogLayer().getB(), buffer, next);
    }

    /**
     * Returns the flattened parameter at {@code index}.
     *
     * <p>Indices below the weight count address weights. Larger indices address biases.
     */
    @Override // org.deeplearning4j.optimize.OptimizableByGradientValueMatrix
    public double getParameter(int index) {
        DoubleMatrix weights = this.network.getLogLayer().getW();
        if (index < weights.length) {
            return weights.get(index);
        }
        // Biases follow the weights in the flattened layout, so offset by the weight count.
        return this.network.getLogLayer().getB().get(index - weights.length);
    }

    /**
     * Overwrites the log layer's weights and then biases from {@code params}, in place.
     *
     * @param params source array; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void setParameters(double[] params) {
        int next = copyFrom(params, 0, this.network.getLogLayer().getW());
        copyFrom(params, next, this.network.getLogLayer().getB());
    }

    /**
     * Sets the flattened parameter at {@code index} to {@code value}.
     *
     * <p>This uses the same weights-then-biases layout as {@link #getParameter(int)}.
     */
    @Override // org.deeplearning4j.optimize.OptimizableByGradientValueMatrix
    public void setParameter(int index, double value) {
        DoubleMatrix weights = this.network.getLogLayer().getW();
        if (index < weights.length) {
            weights.put(index, value);
        } else {
            // Biases follow the weights in the flattened layout, so offset by the weight count.
            this.network.getLogLayer().getB().put(index - weights.length, value);
        }
    }

    /**
     * Writes the log layer's gradient into {@code buffer}.
     *
     * <p>The gradient is computed with this optimizer's learning rate. The weight gradient is
     * written first, then the bias gradient, matching the parameter layout.
     *
     * @param buffer destination array; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void getValueGradient(double[] buffer) {
        LogisticRegressionGradient gradient = this.network.getLogLayer().getGradient(this.lr);
        DoubleMatrix weightGradient = gradient.getwGradient();
        DoubleMatrix biasGradient = gradient.getbGradient();
        int next = copyInto(weightGradient, buffer, 0);
        copyInto(biasGradient, buffer, next);
    }

    /**
     * Returns the network's {@code negativeLogLikelihood()}.
     *
     * <p>NOTE(review): MALLET optimizers maximize {@code getValue()}. Confirm that the sign
     * convention of {@code negativeLogLikelihood()} matches the gradient returned by
     * {@link #getValueGradient(double[])}.
     */
    @Override // org.deeplearning4j.optimize.OptimizableByGradientValueMatrix
    public double getValue() {
        return this.network.negativeLogLikelihood();
    }

    /** Returns a fresh column vector holding the flattened parameters (weights, then biases). */
    @Override // org.deeplearning4j.optimize.OptimizableByGradientValueMatrix
    public DoubleMatrix getParameters() {
        double[] buffer = new double[getNumParameters()];
        getParameters(buffer);
        return new DoubleMatrix(buffer);
    }

    /** Overwrites the parameters from the entries of {@code doubleMatrix}, in linear-index order. */
    @Override // org.deeplearning4j.optimize.OptimizableByGradientValueMatrix
    public void setParameters(DoubleMatrix doubleMatrix) {
        setParameters(doubleMatrix.toArray());
    }

    /** Returns a fresh column vector holding the flattened gradient (weights, then biases). */
    @Override // org.deeplearning4j.optimize.OptimizableByGradientValueMatrix
    public DoubleMatrix getValueGradient() {
        double[] buffer = new double[getNumParameters()];
        getValueGradient(buffer);
        return new DoubleMatrix(buffer);
    }

    /**
     * Copies every entry of {@code src} into {@code dest}, starting at {@code offset}.
     *
     * @return the index just past the last entry written
     */
    private static int copyInto(DoubleMatrix src, double[] dest, int offset) {
        for (int k = 0; k < src.length; k++) {
            dest[offset + k] = src.get(k);
        }
        return offset + src.length;
    }

    /**
     * Overwrites every entry of {@code dest} with values read from {@code src}, starting at
     * {@code offset}.
     *
     * @return the index just past the last entry read
     */
    private static int copyFrom(double[] src, int offset, DoubleMatrix dest) {
        for (int k = 0; k < dest.length; k++) {
            dest.put(k, src[offset + k]);
        }
        return offset + dest.length;
    }
}
