package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/optimization/SGDWithAdaGradAndFOBOS.class */
/**
 * Stochastic gradient minimizer that keeps a per-coordinate running sum of squared
 * gradients to scale the step size ({@code initRate / (sqrt(sumGradSquare) + EPS)}) and
 * then applies a prior-specific shrinkage step to the updated weights.
 *
 * <p>Supported priors are listed in {@link Prior}. {@code LASSO} and {@code RIDGE} work on
 * individual coordinates; the remaining priors require the objective to implement
 * {@code HasFeatureGrouping}.
 *
 * <p>Instances hold mutable optimization state ({@link #x}, the best-so-far weights and the
 * evaluation counters), so a single instance must not be used by several threads at once.
 *
 * @param <T> type of the function being minimized
 */
public class SGDWithAdaGradAndFOBOS<T extends Function> implements Minimizer<T>, HasEvaluators {
    private static final int DEFAULT_NUM_PASSES = 50;
    private static final int DEFAULT_TUNING_SAMPLES = Integer.MAX_VALUE;
    private static final int DEFAULT_BATCH_SIZE = 1000;
    /** Added to the denominator of the per-coordinate learning rate to avoid division by zero. */
    private static final double EPS = 0.001d;
    private static final NumberFormat nf = new DecimalFormat("0.000E0");

    /** Current weight vector; allocated by {@link #minimize(Function, double, double[], int)}. */
    protected double[] x;
    /** Numerator of the per-coordinate learning rate. */
    protected double initRate;
    /** Regularization strength used by every prior. */
    protected double lambda;
    /** Mixing weight between the two shrinkage stages of {@link Prior#sgLASSO}. */
    protected double alpha = 1.0d;
    /** When true, {@link #say(String)} and {@link #sayln(String)} print nothing. */
    protected boolean quiet = false;
    /** Number of passes over the data. */
    protected final int numPasses;
    /** Mini-batch size; may be lowered by {@code minimize} if the data set is smaller. */
    protected int bSize = 1;
    protected final int tuningSamples;
    protected Random gen = new Random(1L);
    /** Time budget compared against {@code Timing.report()} after each pass. */
    protected long maxTime = Long.MAX_VALUE;

    /** Evaluate every this many passes; 0 disables evaluation. */
    private int evaluateIters = 0;
    private Evaluator[] evaluators;
    private Prior prior = Prior.LASSO;
    private boolean useEvalImprovement = false;
    private boolean suppressTestPrompt = false;
    private int terminateOnEvalImprovementNumOfEpoch = 1;
    private double bestEvalSoFar = Double.NEGATIVE_INFINITY;
    /** Copy of the weights that produced {@link #bestEvalSoFar}; null until the first evaluation. */
    private double[] xBest;
    private int noImproveItrCount = 0;

    /** Regularizers selectable through the constructor's prior-type string. */
    public enum Prior {
        LASSO,
        RIDGE,
        aeLASSO,
        gLASSO,
        sgLASSO
    }

    /**
     * Enables or disables early stopping based on evaluator scores. When enabled,
     * {@code minimize} returns the best-scoring weights seen rather than the last ones.
     */
    public void terminateOnEvalImprovement(boolean z) {
        this.useEvalImprovement = z;
    }

    /** Controls whether the "Evaluating: ..." line is suppressed before each evaluator runs. */
    public void suppressTestPrompt(boolean z) {
        this.suppressTestPrompt = z;
    }

    /** Sets how many consecutive non-improving evaluations are tolerated before stopping. */
    public void setTerminateOnEvalImprovementNumOfEpoch(int i) {
        this.terminateOnEvalImprovementNumOfEpoch = i;
    }

    /**
     * Records an evaluation score and decides whether optimization should go on.
     *
     * <p>A score that is not lower than the best one seen so far counts as an improvement
     * (ties included): it becomes the new best, the no-improvement counter is reset and a
     * copy of {@code dArr} is stored as the best weights.
     *
     * @param dArr weights that produced the score
     * @param d    evaluation score, larger is better
     * @return {@code false} once the number of consecutive non-improving scores exceeds the
     *         configured tolerance, {@code true} otherwise
     */
    public boolean toContinue(double[] dArr, double d) {
        if (d < this.bestEvalSoFar) {
            this.noImproveItrCount++;
            return this.noImproveItrCount <= this.terminateOnEvalImprovementNumOfEpoch;
        }
        this.bestEvalSoFar = d;
        this.noImproveItrCount = 0;
        if (this.xBest == null || this.xBest.length != dArr.length) {
            // Fresh copy: also covers a reused minimizer whose previous problem had another size.
            this.xBest = Arrays.copyOf(dArr, dArr.length);
        } else {
            System.arraycopy(dArr, 0, this.xBest, 0, dArr.length);
        }
        return true;
    }

    /**
     * Maps a prior-type string to its {@link Prior} constant.
     *
     * @throws IllegalArgumentException if {@code str} is null or not a supported name
     */
    private static Prior getPrior(String str) {
        if ("lasso".equals(str)) {
            return Prior.LASSO;
        }
        if ("ridge".equals(str)) {
            return Prior.RIDGE;
        }
        if ("ae-lasso".equals(str)) {
            return Prior.aeLASSO;
        }
        if ("g-lasso".equals(str)) {
            return Prior.gLASSO;
        }
        if ("sg-lasso".equals(str)) {
            return Prior.sgLASSO;
        }
        throw new IllegalArgumentException("prior type " + str + " not recognized; supported prior are:lasso, ridge, ae-lasso, g-lasso, sg-lasso");
    }

    /** Uses the default tuning sample size, batch size, "lasso" prior and alpha of 1. */
    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses) {
        this(initRate, lambda, numPasses, -1);
    }

    /** Uses the default batch size, "lasso" prior and alpha of 1. */
    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int tuningSamples) {
        this(initRate, lambda, numPasses, tuningSamples, DEFAULT_BATCH_SIZE);
    }

    /** Uses the "lasso" prior and alpha of 1. */
    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int tuningSamples, int batchSize) {
        this(initRate, lambda, numPasses, tuningSamples, batchSize, "lasso", 1.0d);
    }

    /**
     * Creates a fully configured minimizer.
     *
     * @param initRate      numerator of the per-coordinate learning rate
     * @param lambda        regularization strength
     * @param numPasses     passes over the data; a negative value selects the default of 50
     * @param tuningSamples tuning sample size; a non-positive value selects {@code Integer.MAX_VALUE}
     * @param batchSize     mini-batch size
     * @param priorType     one of "lasso", "ridge", "ae-lasso", "g-lasso", "sg-lasso"
     * @param alpha         mixing weight used by the "sg-lasso" prior
     * @throws IllegalArgumentException if {@code priorType} is not recognized
     */
    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int tuningSamples, int batchSize, String priorType, double alpha) {
        this.initRate = initRate;
        this.prior = getPrior(priorType);
        this.bSize = batchSize;
        this.lambda = lambda;
        this.alpha = alpha;
        if (numPasses >= 0) {
            this.numPasses = numPasses;
        } else {
            this.numPasses = DEFAULT_NUM_PASSES;
            sayln("  SGDWithAdaGradAndFOBOS: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
        }
        if (tuningSamples > 0) {
            this.tuningSamples = tuningSamples;
        } else {
            this.tuningSamples = DEFAULT_TUNING_SAMPLES;
            sayln("  SGDWithAdaGradAndFOBOS: tuneSampleSize=" + tuningSamples + ", defaulting to " + this.tuningSamples);
        }
    }

    /** Silences all progress output produced through {@link #say(String)} and {@link #sayln(String)}. */
    public void shutUp() {
        this.quiet = true;
    }

    /** Returns a label built from the batch size, lambda and alpha of this minimizer. */
    protected String getName() {
        return "SGDWithAdaGradAndFOBOS" + this.bSize + "_lambda" + nf.format(this.lambda) + "_alpha" + nf.format(this.alpha);
    }

    /**
     * Registers the evaluators to run during optimization.
     *
     * @param i            run the evaluators every {@code i} passes; 0 or less disables them
     * @param evaluatorArr evaluators to run
     */
    @Override
    public void setEvaluators(int i, Evaluator[] evaluatorArr) {
        this.evaluateIters = i;
        this.evaluators = evaluatorArr;
    }

    /** Returns the Euclidean norm of {@code dArr}. */
    private static double getNorm(double[] dArr) {
        double squaredSum = 0.0d;
        for (double value : dArr) {
            squaredSum += value * value;
        }
        return Math.sqrt(squaredSum);
    }

    /**
     * Runs every registered evaluator on the given weights.
     *
     * @return the score of the last evaluator that returned something other than negative
     *         infinity, or negative infinity if there is no such evaluator
     */
    private double doEvaluation(double[] dArr) {
        if (this.evaluators == null) {
            return Double.NEGATIVE_INFINITY;
        }
        double score = Double.NEGATIVE_INFINITY;
        for (Evaluator evaluator : this.evaluators) {
            if (!this.suppressTestPrompt) {
                sayln("  Evaluating: " + evaluator.toString());
            }
            double current = evaluator.evaluate(dArr);
            if (current != Double.NEGATIVE_INFINITY) {
                score = current;
            }
        }
        return score;
    }

    /** Returns {@code d} if it is positive and 0 otherwise (NaN also yields 0). */
    private static double pospart(double d) {
        return d > 0.0d ? d : 0.0d;
    }

    /** Shrinks {@code value} towards zero by {@code threshold}, clamping at zero. */
    private static double softThreshold(double value, double threshold) {
        return Math.signum(value) * pospart(Math.abs(value) - threshold);
    }

    /**
     * Computes the gradient at the current weights: a mini-batch gradient for stochastic
     * objectives, the full derivative for caching differentiable objectives.
     *
     * @throws UnsupportedOperationException if the function is of neither kind
     */
    private double[] computeGradient(Function function) {
        if (function instanceof AbstractStochasticCachingDiffUpdateFunction) {
            AbstractStochasticCachingDiffUpdateFunction stochastic = (AbstractStochasticCachingDiffUpdateFunction) function;
            stochastic.calculateStochasticGradient(this.x, this.bSize);
            return stochastic.getDerivative();
        }
        if (function instanceof AbstractCachingDiffFunction) {
            return ((AbstractCachingDiffFunction) function).derivativeAt(this.x);
        }
        throw new UnsupportedOperationException("Cannot compute a gradient for function of type " + function.getClass().getName());
    }

    /** Minimizes without an explicit iteration cap; see the four-argument overload. */
    @Override
    public double[] minimize(Function function, double d, double[] dArr) {
        return minimize(function, d, dArr, -1);
    }

    /**
     * Runs up to {@link #numPasses} passes of mini-batch updates starting from {@code dArr}.
     *
     * <p>Each batch accumulates the squared gradient per coordinate, takes a step scaled by
     * {@code initRate / (sqrt(sumGradSquare) + EPS)} and then shrinks the result according to
     * the configured prior. Optimization stops early when the iteration cap or time budget is
     * reached, when the evaluators stop improving (if enabled), or when the weights become
     * non-finite, in which case they are filled with NaN.
     *
     * @param function objective to minimize
     * @param d        function tolerance; not used by this minimizer
     * @param dArr     initial weights; not modified
     * @param i        maximum number of passes requested by the caller; the effective
     *                 iteration cap is {@code max(i, numPasses)} times the batches per pass
     * @return the best-scoring weights if evaluation-based termination is enabled and at
     *         least one evaluation ran, otherwise the final weights
     * @throws UnsupportedOperationException if no positive iteration limit is available, if a
     *         grouped prior is used with a function lacking feature grouping, or if no
     *         gradient can be computed for the function
     */
    @Override
    public double[] minimize(Function function, double d, double[] dArr, int i) {
        sayln("Using lambda=" + this.lambda);

        final boolean stochastic = function instanceof AbstractStochasticCachingDiffUpdateFunction;
        int totalSamples = 0;
        if (stochastic) {
            AbstractStochasticCachingDiffUpdateFunction stochasticFunction = (AbstractStochasticCachingDiffUpdateFunction) function;
            stochasticFunction.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Shuffled;
            totalSamples = stochasticFunction.dataDimension();
            int tuneSampleSize = Math.min(totalSamples, this.tuningSamples);
            if (tuneSampleSize < this.tuningSamples) {
                System.err.println("WARNING: Total number of samples=" + totalSamples + " is smaller than requested tuning sample size=" + this.tuningSamples + "!!!");
            }
            sayln("Using sample size=" + tuneSampleSize);
            if (this.bSize > totalSamples) {
                System.err.println("WARNING: Total number of samples=" + totalSamples + " is smaller than requested batch size=" + this.bSize + "!!!");
                this.bSize = totalSamples;
                sayln("Using batch size=" + this.bSize);
            }
        }

        final int dimension = dArr.length;
        final boolean elementwisePrior = this.prior == Prior.LASSO || this.prior == Prior.RIDGE;

        this.x = new double[dimension];
        final double[] sumGradSquare = new double[dimension];
        // Scratch buffers holding the unregularized update and its learning rate per coordinate.
        double[] testUpdateCache = null;
        double[] currentRateCache = null;
        if (this.prior != Prior.LASSO) {
            testUpdateCache = new double[dimension];
            currentRateCache = new double[dimension];
        }
        int[][] featureGrouping = null;
        if (!elementwisePrior) {
            if (!(function instanceof HasFeatureGrouping)) {
                throw new UnsupportedOperationException("prior is specified to be ae-lasso, g-lasso or sg-lasso, but function does not support feature grouping");
            }
            featureGrouping = ((HasFeatureGrouping) function).getFeatureGrouping();
        }
        final double[] bCache = this.prior == Prior.sgLASSO ? new double[dimension] : null;
        System.arraycopy(dArr, 0, this.x, 0, dimension);

        final int batchesPerPass = stochastic ? totalSamples / this.bSize : 1;
        if (i <= 0 && this.numPasses <= 0) {
            throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
        }
        final int maxIterations = Math.max(i, this.numPasses) * batchesPerPass;

        sayln("       Batch size of: " + this.bSize);
        sayln("       Data dimension of: " + totalSamples);
        sayln("       Batches per pass through data:  " + batchesPerPass);
        sayln("       Number of passes is = " + this.numPasses);
        sayln("       Max iterations is = " + maxIterations);

        final Timing total = new Timing();
        total.start();

        // Built lazily and reused: the set of all indices never changes between batches.
        Set<Integer> allIndices = null;
        int iterations = 0;

        for (int pass = 0; pass < this.numPasses; pass++) {
            double evalScore = Double.NEGATIVE_INFINITY;
            if (pass > 0 && this.evaluateIters > 0 && pass % this.evaluateIters == 0) {
                evalScore = doEvaluation(this.x);
                if (this.useEvalImprovement && !toContinue(this.x, evalScore)) {
                    break;
                }
            }

            say("Iter: " + iterations + " pass " + pass + " batch 1 ... ");
            int nonZeroFeatures = 0;
            int nonZeroGroups = 0;

            for (int batch = 0; batch < batchesPerPass; batch++) {
                iterations++;
                final double[] gradients = computeGradient(function);

                if (elementwisePrior) {
                    Set<Integer> paramRange;
                    if (function instanceof HasRegularizerParamRange) {
                        paramRange = ((HasRegularizerParamRange) function).getRegularizerParamRange(this.x);
                    } else {
                        if (allIndices == null) {
                            allIndices = new HashSet<Integer>();
                            for (int index = 0; index < dimension; index++) {
                                allIndices.add(Integer.valueOf(index));
                            }
                        }
                        paramRange = allIndices;
                    }

                    double testUpdateSquaredSum = 0.0d;
                    for (Integer boxedIndex : paramRange) {
                        final int index = boxedIndex.intValue();
                        final double gradient = gradients[index];
                        sumGradSquare[index] = sumGradSquare[index] + (gradient * gradient);
                        final double previous = this.x[index];
                        final double rate = this.initRate / (Math.sqrt(sumGradSquare[index]) + EPS);
                        final double testUpdate = previous - (rate * gradient);
                        if (this.prior == Prior.LASSO) {
                            final double updated = softThreshold(testUpdate, rate * this.lambda);
                            this.x[index] = updated;
                            if (updated != 0.0d) {
                                nonZeroFeatures++;
                            }
                        } else {
                            // RIDGE: the shrink factor needs the norm of the whole update, so defer.
                            testUpdateSquaredSum += testUpdate * testUpdate;
                            testUpdateCache[index] = testUpdate;
                            currentRateCache[index] = rate;
                        }
                    }

                    if (this.prior == Prior.RIDGE) {
                        final double testUpdateNorm = Math.sqrt(testUpdateSquaredSum);
                        for (int index = 0; index < testUpdateCache.length; index++) {
                            final double updated = testUpdateCache[index] * pospart(1.0d - currentRateCache[index] * this.lambda / testUpdateNorm);
                            this.x[index] = updated;
                            if (updated != 0.0d) {
                                nonZeroFeatures++;
                            }
                        }
                    }
                } else {
                    for (int[] group : featureGrouping) {
                        double squaredSum = 0.0d;
                        double absSum = 0.0d;
                        final double groupSize = group.length;
                        final double logGroupSize = Math.log(groupSize);

                        // First the unregularized step for every member of the group.
                        for (int index : group) {
                            final double gradient = gradients[index];
                            sumGradSquare[index] = sumGradSquare[index] + (gradient * gradient);
                            final double previous = this.x[index];
                            final double rate = this.initRate / (Math.sqrt(sumGradSquare[index]) + EPS);
                            final double testUpdate = previous - (rate * gradient);
                            squaredSum += testUpdate * testUpdate;
                            absSum += Math.abs(testUpdate);
                            testUpdateCache[index] = testUpdate;
                            currentRateCache[index] = rate;
                        }

                        boolean groupHasNonZero = false;
                        if (this.prior == Prior.gLASSO) {
                            final double groupNorm = Math.sqrt(squaredSum);
                            for (int index : group) {
                                final double updated = testUpdateCache[index] * pospart(1.0d - currentRateCache[index] * this.lambda * logGroupSize / groupNorm);
                                this.x[index] = updated;
                                if (updated != 0.0d) {
                                    nonZeroFeatures++;
                                    groupHasNonZero = true;
                                }
                            }
                        } else if (this.prior == Prior.aeLASSO) {
                            for (int index : group) {
                                final double rate = currentRateCache[index];
                                final double threshold = rate * this.lambda / (1.0d + rate * this.lambda * groupSize) * absSum;
                                final double updated = softThreshold(testUpdateCache[index], threshold);
                                this.x[index] = updated;
                                if (updated != 0.0d) {
                                    nonZeroFeatures++;
                                    groupHasNonZero = true;
                                }
                            }
                        } else if (this.prior == Prior.sgLASSO) {
                            // Stage one: coordinate-wise shrinkage weighted by alpha.
                            double shrunkSquaredSum = 0.0d;
                            for (int index : group) {
                                final double shrunk = softThreshold(testUpdateCache[index], currentRateCache[index] * this.alpha * this.lambda);
                                bCache[index] = shrunk;
                                shrunkSquaredSum += shrunk * shrunk;
                            }
                            // Stage two: group-wise shrinkage weighted by (1 - alpha).
                            final double shrunkNorm = Math.sqrt(shrunkSquaredSum);
                            for (int index : group) {
                                final double updated = bCache[index] * pospart(1.0d - currentRateCache[index] * (1.0d - this.alpha) * this.lambda * logGroupSize / shrunkNorm);
                                this.x[index] = updated;
                                if (updated != 0.0d) {
                                    nonZeroFeatures++;
                                    groupHasNonZero = true;
                                }
                            }
                        }
                        if (groupHasNonZero) {
                            nonZeroGroups++;
                        }
                    }
                }
            }

            try {
                ArrayMath.assertFinite(this.x, "x");
            } catch (ArrayMath.InvalidElementException e) {
                System.err.println(e.toString());
                Arrays.fill(this.x, Double.NaN);
                // Stop here: repeating the pass with NaN weights would fail the same way forever.
                break;
            }

            sayln(batchesPerPass + ", n0-fCount:" + nonZeroFeatures + (elementwisePrior ? "" : ", n0-gCount:" + nonZeroGroups) + (evalScore != Double.NEGATIVE_INFINITY ? ", evalScore:" + evalScore : ""));

            if (iterations >= maxIterations) {
                sayln("Online Optimization complete.  Stopped after max iterations");
                break;
            }
            if (total.report() >= this.maxTime) {
                sayln("Online Optimization complete.  Stopped after max time");
                break;
            }
        }

        // xBest stays null when no evaluation ran during training; fall back to the final weights.
        final double[] result = (this.useEvalImprovement && this.xBest != null) ? this.xBest : this.x;
        if (this.evaluateIters > 0) {
            sayln("final evalScore is: " + doEvaluation(result));
        }
        sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");
        return result;
    }

    /** Prints {@code str} followed by a newline to standard error unless {@link #quiet} is set. */
    protected void sayln(String str) {
        if (!this.quiet) {
            System.err.println(str);
        }
    }

    /** Prints {@code str} to standard error unless {@link #quiet} is set. */
    protected void say(String str) {
        if (!this.quiet) {
            System.err.print(str);
        }
    }
}
