package org.deeplearning4j.util;

import cc.mallet.optimize.BackTrackLineSearch;
import cc.mallet.optimize.LineOptimizer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.optimize.OptimizerEvaluator;
import cc.mallet.types.MatrixOps;
import org.deeplearning4j.optimize.NeuralNetEpochListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/util/NonZeroStoppingConjugateGradient.class */
/**
 * Nonlinear conjugate-gradient optimizer implementing Mallet's {@link Optimizer} interface.
 *
 * <p>Follows the maximization convention of its line maximizer. Each iteration does the following:
 * <ul>
 *   <li>runs a line search along the current direction;</li>
 *   <li>re-reads the objective value and gradient from the wrapped
 *       {@link Optimizable.ByGradientValue};</li>
 *   <li>builds the next direction with the Polak-Ribiere coefficient, falling back to the plain
 *       gradient ("GA") when the combined direction is not uphill.</li>
 * </ul>
 *
 * <p>The search stops in any of these cases:
 * <ul>
 *   <li>the relative change in value is within the tolerance;</li>
 *   <li>the gradient's two-norm drops below the gradient tolerance;</li>
 *   <li>more than {@link #getMaxIterations()} iterations have run since the last
 *       (re)initialization.</li>
 * </ul>
 *
 * <p>Search state (gradient buffers, step size, iteration count) survives across calls to
 * {@link #optimize(int)} until {@link #reset()} is called. That state is mutable and
 * unsynchronized, so an instance should only be used by one thread at a time.
 */
public class NonZeroStoppingConjugateGradient implements Optimizer {
    private static final Logger logger = LoggerFactory.getLogger(NonZeroStoppingConjugateGradient.class);

    /** Set once a stopping criterion fires; {@link #optimize(int)} then returns immediately. */
    boolean converged = false;
    /** Objective being optimized; supplies the value, gradient and parameter count. */
    Optimizable.ByGradientValue optimizable;
    /** Line search run along each direction; a {@link BackTrackLineSearch} by default. */
    LineOptimizer.ByGradient lineMaximizer;
    /** Step size handed to the first line search after (re)initialization. */
    double initialStepSize = 1.0d;
    /** Relative value-change tolerance; a value {@code <= 0} disables that stopping test. */
    double tolerance = 1.0E-5d;
    /** The search stops once the gradient's two-norm falls below this value. */
    double gradientTolerance = 1.0E-5d;
    /** Cap on iterations counted since the last (re)initialization. */
    int maxIterations = 10000;
    /** Name of the thread running the current optimize call; used as a log prefix. */
    private String myName = "";
    private NeuralNetEpochListener listener;
    /** Added to the scale of the relative-change test so it stays meaningful near zero. */
    final double eps = 1.0E-10d;
    private OptimizerEvaluator.ByGradient eval;

    // Working state shared between iterations.
    double fp;    // objective value before the current line search
    double gg;    // squared norm of the previous gradient
    double gam;   // Polak-Ribiere coefficient
    double dgg;   // numerator of the Polak-Ribiere coefficient
    double step;  // current line-search step size
    double fret;  // objective value after the current line search
    double[] xi;  // current gradient, then overwritten with the next search direction
    double[] g;   // previous gradient
    double[] h;   // previous search direction
    /** Retained for compatibility only; loops now use a local index. */
    int j;
    int iterations; // iterations completed since the last (re)initialization

    /**
     * Creates an optimizer backed by a {@link BackTrackLineSearch}.
     *
     * @param byGradientValue objective to optimize
     * @param d initial step size for the first line search
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue byGradientValue, double d) {
        this.initialStepSize = d;
        this.optimizable = byGradientValue;
        BackTrackLineSearch lineSearch = new BackTrackLineSearch(byGradientValue);
        lineSearch.setAbsTolx(this.tolerance);
        this.lineMaximizer = lineSearch;
    }

    /**
     * Creates an optimizer with initial step size {@code 0.01} that reports progress to a listener.
     *
     * @param byGradientValue objective to optimize
     * @param neuralNetEpochListener notified once per iteration; may be {@code null}
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue byGradientValue, NeuralNetEpochListener neuralNetEpochListener) {
        this(byGradientValue, 0.01d);
        this.listener = neuralNetEpochListener;
    }

    /**
     * Creates an optimizer with the given initial step size that reports progress to a listener.
     *
     * @param byGradientValue objective to optimize
     * @param d initial step size for the first line search
     * @param neuralNetEpochListener notified once per iteration; may be {@code null}
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue byGradientValue, double d, NeuralNetEpochListener neuralNetEpochListener) {
        this(byGradientValue, d);
        this.listener = neuralNetEpochListener;
    }

    /**
     * Creates an optimizer with initial step size {@code 0.01} and no listener.
     *
     * @param byGradientValue objective to optimize
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue byGradientValue) {
        this(byGradientValue, 0.01d);
    }

    /** Returns the objective being optimized. */
    public Optimizable getOptimizable() {
        return this.optimizable;
    }

    /** Returns whether a stopping criterion has fired. */
    public boolean isConverged() {
        return this.converged;
    }

    /** Sets an evaluator invoked after every iteration that does not stop the search. */
    public void setEvaluator(OptimizerEvaluator.ByGradient byGradient) {
        this.eval = byGradient;
    }

    /** Replaces the line search used along each direction. */
    public void setLineMaximizer(LineOptimizer.ByGradient byGradient) {
        this.lineMaximizer = byGradient;
    }

    /** Sets the step size used for the first line search after (re)initialization. */
    public void setInitialStepSize(double d) {
        this.initialStepSize = d;
    }

    /** Returns the step size used for the first line search after (re)initialization. */
    public double getInitialStepSize() {
        return this.initialStepSize;
    }

    /** Returns the current line-search step size. */
    public double getStepSize() {
        return this.step;
    }

    /**
     * Runs up to {@link #getMaxIterations()} iterations.
     *
     * @return see {@link #optimize(int)}
     */
    public boolean optimize() {
        return optimize(this.maxIterations);
    }

    /**
     * Sets the relative value-change tolerance; a value {@code <= 0} disables that test.
     * Does not change the tolerance already given to the line search.
     */
    public void setTolerance(double d) {
        this.tolerance = d;
    }

    /**
     * Runs at most {@code numIterations} conjugate-gradient iterations.
     *
     * <p>On the first call, or the first after {@link #reset()}, this method:
     * <ul>
     *   <li>reads the value and gradient from the optimizable;</li>
     *   <li>sets the step size to the initial step size.</li>
     * </ul>
     * The listener, if any, is notified exactly once after each iteration performed, including
     * an iteration that stops the search.
     *
     * @param numIterations maximum number of iterations to run in this call
     * @return {@code true} if the search has stopped (converged, or exceeded
     *     {@link #getMaxIterations()}), now or in an earlier call; {@code false} if
     *     {@code numIterations} iterations ran without stopping
     */
    public boolean optimize(int numIterations) {
        this.myName = Thread.currentThread().getName();
        if (this.converged) {
            return true;
        }
        long lastLogTime = System.currentTimeMillis();
        if (this.xi == null) {
            initializeSearchState();
        }
        for (int epoch = 0; epoch < numIterations; epoch++) {
            long now = System.currentTimeMillis();
            logger.info(this.myName + " ConjugateGradient: At iteration " + this.iterations + ", cost = " + this.fp
                    + " -" + (now - lastLogTime));
            lastLogTime = now;

            boolean finished = runIteration();
            // Notify exactly once per iteration, after its work is done.
            if (this.listener != null) {
                this.listener.epochDone(epoch);
            }
            if (finished) {
                return true;
            }
        }
        return false;
    }

    /** Reads the starting value and gradient and seeds the direction buffers with the gradient. */
    private void initializeSearchState() {
        int numParameters = this.optimizable.getNumParameters();
        this.fp = this.optimizable.getValue();
        this.xi = new double[numParameters];
        this.optimizable.getValueGradient(this.xi);
        this.g = this.xi.clone();
        this.h = this.xi.clone();
        this.step = this.initialStepSize;
        this.iterations = 0;
    }

    /**
     * Performs one line search, applies the stopping tests, and, if not stopping, updates the
     * search direction.
     *
     * @return {@code true} if the search should stop; {@link #converged} is set in that case
     */
    private boolean runIteration() {
        lineSearch();
        this.fret = this.optimizable.getValue();
        this.optimizable.getValueGradient(this.xi);

        // Stop on a small relative change in the objective value.
        if (0.0d < this.tolerance
                && 2.0d * Math.abs(this.fret - this.fp) <= this.tolerance * (Math.abs(this.fret) + Math.abs(this.fp) + this.eps)) {
            logger.info("ConjugateGradient converged: old value= " + this.fp + " new value= " + this.fret
                    + " tolerance=" + this.tolerance);
            this.converged = true;
            return true;
        }
        this.fp = this.fret;

        // Stop on a vanishing gradient.
        double twoNorm = MatrixOps.twoNorm(this.xi);
        if (twoNorm < this.gradientTolerance) {
            logger.info("ConjugateGradient converged: gradient two norm " + twoNorm + ", less than "
                    + this.gradientTolerance);
            this.converged = true;
            return true;
        }

        updateSearchDirection();

        this.iterations++;
        if (this.iterations > this.maxIterations) {
            logger.info("Too many iterations in ConjugateGradient.java");
            this.converged = true;
            return true;
        }
        if (this.eval != null) {
            this.eval.evaluate(this.optimizable, this.iterations);
        }
        return false;
    }

    /**
     * Runs the line search along {@link #xi} on a best-effort basis. On failure the step size is
     * left unchanged, and the caller re-reads the value and gradient as usual.
     *
     * <p>AssertionError is tolerated alongside runtime exceptions so that running with
     * {@code -ea} keeps the previous catch-all behavior. Other Errors (e.g. OutOfMemoryError)
     * now propagate.
     */
    private void lineSearch() {
        try {
            this.step = this.lineMaximizer.optimize(this.xi, this.step);
        } catch (RuntimeException e) {
            logLineSearchFailure(e);
        } catch (AssertionError e) {
            logLineSearchFailure(e);
        }
    }

    /** Logs a failed line search with its stack trace. */
    private void logLineSearchFailure(Throwable failure) {
        logger.warn("Line search failed; keeping step size " + this.step, failure);
    }

    /**
     * Computes the Polak-Ribiere direction from the new gradient {@link #xi} and the previous
     * gradient and direction.
     *
     * <p>The new direction is stored in both {@link #h} and {@link #xi}. If that direction is not
     * uphill, the plain gradient is used instead.
     */
    private void updateSearchDirection() {
        this.gg = 0.0d;
        this.dgg = 0.0d;
        for (int k = 0; k < this.xi.length; k++) {
            this.gg += this.g[k] * this.g[k];
            this.dgg += this.xi[k] * (this.xi[k] - this.g[k]);
        }
        // A zero previous gradient would make the ratio 0/0 or x/0 and poison h with NaN/Infinity;
        // use gam = 0 (plain gradient direction) instead.
        this.gam = this.gg == 0.0d ? 0.0d : this.dgg / this.gg;
        for (int k = 0; k < this.xi.length; k++) {
            this.g[k] = this.xi[k];
            this.h[k] = this.xi[k] + this.gam * this.h[k];
        }
        assert !MatrixOps.isNaN(this.h);

        if (MatrixOps.dotProduct(this.xi, this.h) > 0.0d) {
            MatrixOps.set(this.xi, this.h);
        } else {
            logger.warn("Reverting back to GA");
            MatrixOps.set(this.h, this.xi);
        }
    }

    /**
     * Discards the search state and the converged flag. The next {@link #optimize(int)} call
     * therefore re-reads the value and gradient from the optimizable and starts a fresh search.
     */
    public void reset() {
        this.xi = null;
        this.converged = false;
    }

    /** Returns the cap on iterations since the last (re)initialization. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** Sets the cap on iterations since the last (re)initialization. */
    public void setMaxIterations(int i) {
        this.maxIterations = i;
    }
}
