package opennlp.tools.ml.maxent.quasinewton;

import opennlp.tools.ml.maxent.quasinewton.LineSearch;

/**
 * Limited-memory quasi-Newton (L-BFGS) minimizer with optional L1 and L2 regularization.
 *
 * <p>The L2 penalty {@code l2Cost * ||x||^2} is folded into the objective through
 * {@link L2RegFunction}. When {@code l1Cost > 0} the search works with the pseudo-gradient of the
 * L1-penalized objective, masks the search direction against it and runs a constrained line
 * search (the OWL-QN scheme); otherwise the plain gradient and an ordinary line search are used.
 *
 * <p>{@link #minimize(Function)} keeps per-run state in instance fields, so a single instance must
 * not be used for concurrent minimizations.
 */
public class QNMinimizer {
    /** Stop when the line search reports a function change rate below this value. */
    public static final double CONVERGE_TOLERANCE = 1.0E-4d;
    /** Stop when {@code ||gradient|| / max(1, ||x||)} falls below this value. */
    public static final double REL_GRAD_NORM_TOL = 1.0E-4d;
    /** Step size handed to the line search from the second iteration on. */
    public static final double INITIAL_STEP_SIZE = 1.0d;
    /** Stop when the accepted step size falls below this value. */
    public static final double MIN_STEP_SIZE = 1.0E-10d;
    public static final double L1COST_DEFAULT = 0.0d;
    public static final double L2COST_DEFAULT = 0.0d;
    public static final int NUM_ITERATIONS_DEFAULT = 100;
    /** Default number of correction pairs kept in the L-BFGS history. */
    public static final int M_DEFAULT = 15;
    public static final int MAX_FCT_EVAL_DEFAULT = 30000;

    private final double l1Cost;
    private final double l2Cost;
    private final int iterations;
    private final int m;
    private final int maxFctEval;
    private final boolean verbose;
    /** Dimension of the problem being minimized; set at the start of {@link #minimize(Function)}. */
    private int dimension;
    /** L-BFGS history of the current run; cleared when {@link #minimize(Function)} returns. */
    private UpdateInfo updateInfo;
    private Evaluator evaluator;

    /** Optional callback that scores intermediate solutions for verbose progress output. */
    public interface Evaluator {
        /**
         * Scores a parameter vector.
         *
         * @param parameters the current solution
         * @return a score that is printed next to the objective value in verbose mode
         */
        double evaluate(double[] parameters);
    }

    /** Wraps a function {@code f} and adds the penalty {@code l2Cost * ||x||^2} to it. */
    public static class L2RegFunction implements Function {
        private final Function f;
        private final double l2Cost;

        /**
         * @param f the unregularized function
         * @param l2Cost weight of the squared L2 norm; values {@code <= 0} disable the penalty
         */
        public L2RegFunction(Function f, double l2Cost) {
            this.f = f;
            this.l2Cost = l2Cost;
        }

        @Override // opennlp.tools.ml.maxent.quasinewton.Function
        public int getDimension() {
            return f.getDimension();
        }

        /**
         * @throws IllegalArgumentException if {@code x} does not match the function's dimension
         */
        @Override // opennlp.tools.ml.maxent.quasinewton.Function
        public double valueAt(double[] x) {
            checkDimension(x);
            double value = f.valueAt(x);
            if (l2Cost > 0.0) {
                value += l2Cost * ArrayMath.innerProduct(x, x);
            }
            return value;
        }

        /**
         * Returns the gradient of the wrapped function plus {@code 2 * l2Cost * x}. The penalty term
         * is added in place to the array returned by the wrapped function, and that array is returned.
         *
         * @throws IllegalArgumentException if {@code x} does not match the function's dimension
         */
        @Override // opennlp.tools.ml.maxent.quasinewton.Function
        public double[] gradientAt(double[] x) {
            checkDimension(x);
            double[] gradient = f.gradientAt(x);
            if (l2Cost > 0.0) {
                for (int i = 0; i < x.length; i++) {
                    gradient[i] += 2.0 * l2Cost * x[i];
                }
            }
            return gradient;
        }

        private void checkDimension(double[] x) {
            if (x.length != getDimension()) {
                throw new IllegalArgumentException("x's dimension is not the same as function's dimension");
            }
        }
    }

    /**
     * History of the most recent {@code m} correction pairs {@code s = x_{k+1} - x_k},
     * {@code y = g_{k+1} - g_k} and their {@code rho = 1 / (s . y)}, ordered oldest (index 0) to
     * newest (index {@code kCounter - 1}).
     */
    public class UpdateInfo {
        private final double[][] S;
        private final double[][] Y;
        private final double[] rho;
        /** Scratch space for the two-loop recursion, kept here to avoid reallocation. */
        private final double[] alpha;
        private final int m;
        private final int dim;
        /** Number of pairs currently stored (saturates at {@code m}). */
        private int kCounter = 0;

        UpdateInfo(int m, int dimension) {
            this.m = m;
            this.dim = dimension;
            this.S = new double[m][dimension];
            this.Y = new double[m][dimension];
            this.rho = new double[m];
            this.alpha = new double[m];
        }

        /**
         * Records the correction pair produced by the last line search, discarding the oldest pair
         * when the history is full.
         *
         * @param lineSearchResult result holding the current/next points and their gradients
         */
        public void update(LineSearch.LineSearchResult lineSearchResult) {
            double[] currPoint = lineSearchResult.getCurrPoint();
            double[] gradAtCurr = lineSearchResult.getGradAtCurr();
            double[] nextPoint = lineSearchResult.getNextPoint();
            double[] gradAtNext = lineSearchResult.getGradAtNext();

            int slot;
            if (kCounter < m) {
                slot = kCounter;
            } else {
                // Drop the oldest pair. Rotate the row arrays instead of shifting references only:
                // a plain shift leaves S[m - 2] and S[m - 1] aliasing one array, so writing the new
                // pair would overwrite the second-newest one.
                double[] recycledS = S[0];
                double[] recycledY = Y[0];
                System.arraycopy(S, 1, S, 0, m - 1);
                System.arraycopy(Y, 1, Y, 0, m - 1);
                System.arraycopy(rho, 1, rho, 0, m - 1);
                S[m - 1] = recycledS;
                Y[m - 1] = recycledY;
                slot = m - 1;
            }

            double[] s = S[slot];
            double[] y = Y[slot];
            double sy = 0.0;
            for (int j = 0; j < dim; j++) {
                s[j] = nextPoint[j] - currPoint[j];
                y[j] = gradAtNext[j] - gradAtCurr[j];
                sy += s[j] * y[j];
            }
            rho[slot] = 1.0 / sy;

            if (kCounter < m) {
                kCounter++;
            }
        }
    }

    /** Creates a minimizer without regularization and with default settings. */
    public QNMinimizer() {
        this(L1COST_DEFAULT, L2COST_DEFAULT);
    }

    /** Creates a minimizer with the given regularization weights and default settings. */
    public QNMinimizer(double l1Cost, double l2Cost) {
        this(l1Cost, l2Cost, NUM_ITERATIONS_DEFAULT);
    }

    /** Creates a minimizer with the given regularization weights and iteration limit. */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations) {
        this(l1Cost, l2Cost, iterations, M_DEFAULT, MAX_FCT_EVAL_DEFAULT);
    }

    /** Creates a verbose minimizer with the given settings. */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations, int m, int maxFctEval) {
        this(l1Cost, l2Cost, iterations, m, maxFctEval, true);
    }

    /**
     * Creates a minimizer.
     *
     * @param l1Cost weight of the L1 penalty; {@code 0} disables it
     * @param l2Cost weight of the squared L2 penalty; {@code 0} disables it
     * @param iterations maximum number of iterations
     * @param m number of correction pairs kept for the Hessian approximation
     * @param maxFctEval evaluation budget; stop once the line search reports more evaluations
     * @param verbose whether to print progress to {@code System.out}
     * @throws IllegalArgumentException if a cost is negative or NaN, or a count is not positive
     */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations, int m, int maxFctEval, boolean verbose) {
        // Written as !(x >= 0) so that NaN is rejected too instead of silently acting like 0.
        if (!(l1Cost >= 0.0) || !(l2Cost >= 0.0)) {
            throw new IllegalArgumentException("L1-cost and L2-cost must not be less than zero");
        }
        if (iterations <= 0) {
            throw new IllegalArgumentException("Number of iterations must be larger than zero");
        }
        if (m <= 0) {
            throw new IllegalArgumentException("Number of Hessian updates must be larger than zero");
        }
        if (maxFctEval <= 0) {
            throw new IllegalArgumentException("Maximum number of function evaluations must be larger than zero");
        }
        this.l1Cost = l1Cost;
        this.l2Cost = l2Cost;
        this.iterations = iterations;
        this.m = m;
        this.maxFctEval = maxFctEval;
        this.verbose = verbose;
    }

    public Evaluator getEvaluator() {
        return evaluator;
    }

    public void setEvaluator(Evaluator evaluator) {
        this.evaluator = evaluator;
    }

    /**
     * Minimizes {@code function} plus the configured penalties, starting from the origin.
     *
     * @param function the objective
     * @return a fresh copy of the point reached when iteration stops
     */
    public double[] minimize(Function function) {
        L2RegFunction l2RegFunction = new L2RegFunction(function, l2Cost);
        dimension = l2RegFunction.getDimension();
        updateInfo = new UpdateInfo(m, dimension);

        double[] startPoint = new double[dimension];
        double startValue = l2RegFunction.valueAt(startPoint);
        // Copy defensively in case the function reuses its gradient array between calls.
        double[] startGrad = new double[dimension];
        System.arraycopy(l2RegFunction.gradientAt(startPoint), 0, startGrad, 0, dimension);

        LineSearch.LineSearchResult lsr;
        if (l1Cost > 0.0) {
            startValue += l1Cost * ArrayMath.l1norm(startPoint);
            double[] startPseudoGrad = new double[dimension];
            computePseudoGrad(startPoint, startGrad, startPseudoGrad);
            lsr = LineSearch.LineSearchResult.getInitialObjectForL1(startValue, startGrad, startPseudoGrad, startPoint);
        } else {
            lsr = LineSearch.LineSearchResult.getInitialObject(startValue, startGrad, startPoint);
        }

        if (verbose) {
            display("\nSolving convex optimization problem.");
            display("\nObjective function has " + dimension + " variable(s).");
            display("\n\nPerforming " + iterations + " iterations with L1Cost=" + l1Cost + " and L2Cost=" + l2Cost + "\n");
        }

        double[] direction = new double[dimension];
        long startTime = System.currentTimeMillis();
        // First line search starts from 1/||g0||; later ones from INITIAL_STEP_SIZE.
        double initialStepSize = ArrayMath.invL2norm(searchGradient(lsr));

        for (int iter = 1; iter <= iterations; iter++) {
            System.arraycopy(searchGradient(lsr), 0, direction, 0, direction.length);
            computeDirection(direction);

            if (l1Cost > 0.0) {
                // Fetched before the line search and refilled afterwards, as the original did.
                double[] pseudoGrad = lsr.getPseudoGradAtNext();
                // Zero direction components whose sign agrees with the pseudo-gradient (not downhill).
                for (int i = 0; i < dimension; i++) {
                    if (direction[i] * pseudoGrad[i] >= 0.0) {
                        direction[i] = 0.0;
                    }
                }
                LineSearch.doConstrainedLineSearch(l2RegFunction, direction, lsr, l1Cost, initialStepSize);
                computePseudoGrad(lsr.getNextPoint(), lsr.getGradAtNext(), pseudoGrad);
                lsr.setPseudoGradAtNext(pseudoGrad);
            } else {
                LineSearch.doLineSearch(l2RegFunction, direction, lsr, initialStepSize);
            }

            updateInfo.update(lsr);

            if (verbose) {
                displayIteration(iter, lsr);
            }
            if (isConverged(lsr)) {
                break;
            }
            initialStepSize = INITIAL_STEP_SIZE;
        }

        // With both penalties active, rescale the solution by sqrt(1 + l2Cost).
        // NOTE(review): presumably an elastic-net shrinkage correction — confirm the intended factor.
        if (l1Cost > 0.0 && l2Cost > 0.0) {
            double[] x = lsr.getNextPoint();
            double scale = Math.sqrt(1.0 + l2Cost);
            for (int i = 0; i < dimension; i++) {
                x[i] = scale * x[i];
            }
        }

        if (verbose) {
            display("Running time: " + ((System.currentTimeMillis() - startTime) / 1000.0) + "s\n");
        }
        // Release the history arrays; the JVM reclaims them without an explicit System.gc().
        updateInfo = null;

        double[] result = new double[dimension];
        System.arraycopy(lsr.getNextPoint(), 0, result, 0, dimension);
        return result;
    }

    /** Returns the pseudo-gradient when L1 is active, otherwise the gradient, at the next point. */
    private double[] searchGradient(LineSearch.LineSearchResult lsr) {
        return l1Cost > 0.0 ? lsr.getPseudoGradAtNext() : lsr.getGradAtNext();
    }

    /** Prints one progress line: right-aligned iteration number, objective, change rate, score. */
    private void displayIteration(int iter, LineSearch.LineSearchResult lsr) {
        if (iter < 10) {
            display("  " + iter + ":  ");
        } else if (iter < 100) {
            display(" " + iter + ":  ");
        } else {
            display(iter + ":  ");
        }
        if (evaluator != null) {
            display("\t" + lsr.getValueAtNext() + "\t" + lsr.getFuncChangeRate() + "\t" + evaluator.evaluate(lsr.getNextPoint()) + "\n");
        } else {
            display("\t " + lsr.getValueAtNext() + "\t" + lsr.getFuncChangeRate() + "\n");
        }
    }

    /**
     * Writes into {@code pseudoGrad} the pseudo-gradient of {@code f(x) + l1Cost * ||x||_1}, given
     * the gradient {@code grad} of {@code f} at {@code x}. At coordinates where {@code x} is zero,
     * the one-sided derivative that allows descent is used, or 0 if neither side allows it.
     */
    private void computePseudoGrad(double[] x, double[] grad, double[] pseudoGrad) {
        for (int i = 0; i < dimension; i++) {
            if (x[i] < 0.0) {
                pseudoGrad[i] = grad[i] - l1Cost;
            } else if (x[i] > 0.0) {
                pseudoGrad[i] = grad[i] + l1Cost;
            } else if (grad[i] < -l1Cost) {
                // Right partial derivative is negative.
                pseudoGrad[i] = grad[i] + l1Cost;
            } else if (grad[i] > l1Cost) {
                // Left partial derivative is positive.
                pseudoGrad[i] = grad[i] - l1Cost;
            } else {
                pseudoGrad[i] = 0.0;
            }
        }
    }

    /**
     * Turns {@code direction} (on entry a copy of the gradient or pseudo-gradient) into the L-BFGS
     * search direction via the two-loop recursion, using the identity as initial Hessian.
     */
    private void computeDirection(double[] direction) {
        int k = updateInfo.kCounter;
        double[] rho = updateInfo.rho;
        double[] alpha = updateInfo.alpha;
        double[][] s = updateInfo.S;
        double[][] y = updateInfo.Y;

        // First loop: newest pair to oldest.
        for (int i = k - 1; i >= 0; i--) {
            alpha[i] = rho[i] * ArrayMath.innerProduct(s[i], direction);
            for (int j = 0; j < dimension; j++) {
                direction[j] -= alpha[i] * y[i][j];
            }
        }
        // Second loop: oldest pair to newest.
        for (int i = 0; i < k; i++) {
            double beta = rho[i] * ArrayMath.innerProduct(y[i], direction);
            for (int j = 0; j < dimension; j++) {
                direction[j] += s[i][j] * (alpha[i] - beta);
            }
        }
        // Negate so the result points downhill.
        for (int j = 0; j < dimension; j++) {
            direction[j] = -direction[j];
        }
    }

    /** Applies the stopping tests in order, printing the reason in verbose mode. */
    private boolean isConverged(LineSearch.LineSearchResult lsr) {
        if (lsr.getFuncChangeRate() < CONVERGE_TOLERANCE) {
            if (verbose) {
                display("Function change rate is smaller than the threshold " + CONVERGE_TOLERANCE + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        double relGradNorm = ArrayMath.l2norm(searchGradient(lsr)) / Math.max(1.0, ArrayMath.l2norm(lsr.getNextPoint()));
        if (relGradNorm < REL_GRAD_NORM_TOL) {
            if (verbose) {
                display("Relative L2-norm of the gradient is smaller than the threshold " + REL_GRAD_NORM_TOL + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        if (lsr.getStepSize() < MIN_STEP_SIZE) {
            if (verbose) {
                display("Step size is smaller than the minimum step size " + MIN_STEP_SIZE + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        if (lsr.getFctEvalCount() > maxFctEval) {
            if (verbose) {
                display("Maximum number of function evaluations has exceeded the threshold " + maxFctEval + ".\nTraining will stop.\n\n");
            }
            return true;
        }
        return false;
    }

    /** Shorthand for {@code System.out.print}. */
    private void display(String str) {
        System.out.print(str);
    }
}
