/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;
import org.deeplearning4j.optimize.OptimizableByGradientValueMatrix;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adapts the output (logistic regression) layer of a {@link BaseMultiLayerNetwork} to the
 * MALLET {@code Optimizable.ByGradientValue} and {@link OptimizableByGradientValueMatrix}
 * interfaces.
 *
 * <p>The optimizable parameter vector is the output layer's weight matrix {@code W} followed by
 * its bias {@code b}, each flattened via linear indexing:
 * {@code [W(0) .. W(|W|-1), b(0) .. b(|b|-1)]}. Every accessor and mutator in this class uses
 * that same layout.
 */
public class MultiLayerNetworkOptimizer
implements Optimizable.ByGradientValue,
Serializable,
OptimizableByGradientValueMatrix {
    private static final long serialVersionUID = -3012638773299331828L;
    /** Network whose logistic regression layer is being optimized. */
    protected BaseMultiLayerNetwork network;
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetworkOptimizer.class);
    /** Learning rate passed to the output layer when computing {@link #getValueGradient(double[])}. */
    private double lr;

    /**
     * @param network the network whose output layer is optimized
     * @param lr learning rate used when computing the value gradient
     */
    public MultiLayerNetworkOptimizer(BaseMultiLayerNetwork network, double lr) {
        this.network = network;
        this.lr = lr;
    }

    /**
     * Trains the output layer on the last hidden layer's sampled output, then optionally
     * back-propagates through the whole network.
     *
     * <p>If the network forces a fixed number of epochs, the output layer is trained exactly
     * {@code epochs} times. Otherwise it is trained via {@code trainTillConvergence}, with
     * {@code epochs} passed through. In both cases {@code backProp(lr, epochs)} runs afterwards
     * when the network's {@code isShouldBackProp()} flag is set.
     *
     * @param labels training labels; also installed on the output layer
     * @param lr learning rate for this training run
     * @param epochs number of epochs (an exact count or a limit, depending on the mode)
     */
    public void optimize(DoubleMatrix labels, double lr, int epochs) {
        this.network.getLogLayer().setLabels(labels);
        DoubleMatrix train = this.sampleHiddenGivenVisible();
        if (this.network.isForceNumEpochs()) {
            log.info("Training for {} epochs", epochs);
            for (int i = 0; i < epochs; ++i) {
                this.network.getLogLayer().train(train, labels, lr);
            }
        } else {
            this.network.getLogLayer().trainTillConvergence(train, labels, lr, epochs);
        }
        // Back-propagation is identical for both training modes, so it runs once here.
        if (this.network.isShouldBackProp()) {
            this.network.backProp(lr, epochs);
        }
    }

    /**
     * Produces the input for the output layer from the topmost hidden layer.
     *
     * <p>When the network uses hidden activations for forward propagation, the top sigmoid
     * layer's sample is used. Otherwise the top layer is sampled from its own current input,
     * and the second element of the returned pair is used.
     */
    private DoubleMatrix sampleHiddenGivenVisible() {
        int top = this.network.getnLayers() - 1;
        if (this.network.isUseHiddenActivationsForwardProp()) {
            return this.network.getSigmoidLayers()[top].sampleHiddenGivenVisible();
        }
        return this.network.getLayers()[top].sampleHiddenGivenVisible(this.network.getLayers()[top].getInput()).getSecond();
    }

    /** @return {@code |W| + |b|} of the output layer */
    @Override
    public int getNumParameters() {
        return this.network.getLogLayer().getW().length + this.network.getLogLayer().getB().length;
    }

    /**
     * Copies the output layer's parameters into {@code buffer}: {@code W} first, then {@code b}.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} entries
     */
    public void getParameters(double[] buffer) {
        DoubleMatrix w = this.network.getLogLayer().getW();
        DoubleMatrix b = this.network.getLogLayer().getB();
        int idx = 0;
        for (int i = 0; i < w.length; ++i) {
            buffer[idx++] = w.get(i);
        }
        for (int i = 0; i < b.length; ++i) {
            buffer[idx++] = b.get(i);
        }
    }

    /**
     * Returns one element of the flattened parameter vector.
     *
     * @param index position in {@code [W..., b...]}
     * @return {@code W.get(index)} if {@code index < |W|}, otherwise {@code b.get(index - |W|)}
     */
    @Override
    public double getParameter(int index) {
        DoubleMatrix w = this.network.getLogLayer().getW();
        if (index >= w.length) {
            // Bias entries follow all weight entries, so offset by the weight count.
            return this.network.getLogLayer().getB().get(index - w.length);
        }
        return w.get(index);
    }

    /**
     * Overwrites the output layer's parameters from {@code params}: {@code W} first, then
     * {@code b}.
     *
     * @param params source; must hold at least {@link #getNumParameters()} entries
     */
    public void setParameters(double[] params) {
        DoubleMatrix w = this.network.getLogLayer().getW();
        DoubleMatrix b = this.network.getLogLayer().getB();
        int idx = 0;
        for (int i = 0; i < w.length; ++i) {
            w.put(i, params[idx++]);
        }
        for (int i = 0; i < b.length; ++i) {
            b.put(i, params[idx++]);
        }
    }

    /**
     * Sets one element of the flattened parameter vector.
     *
     * @param index position in {@code [W..., b...]}
     * @param value new value
     */
    @Override
    public void setParameter(int index, double value) {
        DoubleMatrix w = this.network.getLogLayer().getW();
        if (index >= w.length) {
            // Bias entries follow all weight entries, so offset by the weight count.
            this.network.getLogLayer().getB().put(index - w.length, value);
        } else {
            w.put(index, value);
        }
    }

    /**
     * Writes the output layer's gradient into {@code buffer}, laid out like the parameters
     * (weight gradient first, then bias gradient).
     *
     * <p>The gradient is requested with the learning rate given to the constructor, not the one
     * passed to {@link #optimize(DoubleMatrix, double, int)}.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} entries
     */
    public void getValueGradient(double[] buffer) {
        LogisticRegressionGradient gradient = this.network.getLogLayer().getGradient(this.lr);
        DoubleMatrix weightGradient = gradient.getwGradient();
        DoubleMatrix biasGradient = gradient.getbGradient();
        int idx = 0;
        for (int i = 0; i < weightGradient.length; ++i) {
            buffer[idx++] = weightGradient.get(i);
        }
        for (int i = 0; i < biasGradient.length; ++i) {
            buffer[idx++] = biasGradient.get(i);
        }
    }

    /** @return the network's {@code negativeLogLikelihood()} */
    @Override
    public double getValue() {
        return this.network.negativeLogLikelihood();
    }

    /** @return a new matrix holding a copy of the flattened parameter vector */
    @Override
    public DoubleMatrix getParameters() {
        double[] d = new double[this.getNumParameters()];
        this.getParameters(d);
        return new DoubleMatrix(d);
    }

    /** Overwrites the parameters from {@code params}, read in linear index order. */
    @Override
    public void setParameters(DoubleMatrix params) {
        this.setParameters(params.toArray());
    }

    /** @return a new matrix holding the flattened value gradient */
    @Override
    public DoubleMatrix getValueGradient() {
        double[] buffer = new double[this.getNumParameters()];
        this.getValueGradient(buffer);
        return new DoubleMatrix(buffer);
    }
}

