package edu.stanford.nlp.classify;

import edu.stanford.nlp.math.ArrayMath;
import java.io.Serializable;

/**
 * A penalty term ("log prior") over a weight vector, together with its gradient.
 *
 * <p>{@link #compute(double[], double[])} returns the penalty value for a weight vector and
 * <em>adds</em> the penalty's gradient into a caller-supplied array. The shape of the penalty is
 * selected by {@link LogPriorType}. Most shapes are scaled by a width parameter sigma (stored as
 * sigma squared); HUBER additionally uses a threshold epsilon.
 *
 * <p>Instances are mutable (see the setters), and {@link #computeStochastic} temporarily mutates
 * them, so an instance is not safe to share between threads without external synchronization.
 */
public class LogPrior implements Serializable {
    private static final long serialVersionUID = 7826853908892790965L;

    /** Offsets subtracted from the weights before an ADAPT prior delegates; null otherwise. */
    private double[] means;
    /** The prior an ADAPT prior delegates to; null otherwise. */
    private LogPrior otherPrior;
    /** Which penalty this prior computes. */
    public final LogPriorType type;
    /** Per-coordinate sigma squared, used by MULTIPLE_QUADRATIC. */
    private double[] sigmaSqM;
    /** Element-wise square of {@link #sigmaSqM}. */
    private double[] sigmaQuM;
    /** Sigma squared: the shared width used by QUADRATIC, HUBER and COSH. */
    private double sigmaSq;
    /** Sigma to the fourth power, used by QUARTIC. */
    private double sigmaQu;
    /** Threshold between the quadratic and linear regions of the HUBER penalty. */
    private double epsilon;

    /**
     * The penalty shapes a {@link LogPrior} can compute. Declaration order is significant: the
     * integer constructors select a type by its ordinal.
     */
    public enum LogPriorType {
        /** No penalty: the value is 0 and the gradient is left untouched. */
        NULL,
        /** {@code sum_i x_i^2 / (2 sigma^2)}. */
        QUADRATIC,
        /**
         * Per-coordinate Huber penalty scaled by {@code 1 / sigma^2}: quadratic
         * {@code x^2 / (2 epsilon)} while |x| is below epsilon, linear {@code |x| - epsilon / 2} beyond.
         */
        HUBER,
        /** {@code sum_i x_i^4 / (2 sigma^4)}. */
        QUARTIC,
        /** {@code log(cosh(||x||_1 / sigma^2))}. */
        COSH,
        /** Applies another prior to the weights minus a vector of means. */
        ADAPT,
        /** {@code sum_i x_i^2 / (2 sigma_i^2)}, with a separate sigma squared per coordinate. */
        MULTIPLE_QUADRATIC
    }

    /**
     * Parses a prior type name, ignoring case.
     *
     * @param str one of "null", "quadratic", "huber", "quartic" or "cosh" (any case)
     * @return the matching type
     * @throws RuntimeException if {@code str} names none of those types; ADAPT and
     *     MULTIPLE_QUADRATIC cannot be obtained this way
     */
    public static LogPriorType getType(String str) {
        if (str.equalsIgnoreCase("null")) {
            return LogPriorType.NULL;
        }
        if (str.equalsIgnoreCase("quadratic")) {
            return LogPriorType.QUADRATIC;
        }
        if (str.equalsIgnoreCase("huber")) {
            return LogPriorType.HUBER;
        }
        if (str.equalsIgnoreCase("quartic")) {
            return LogPriorType.QUARTIC;
        }
        if (str.equalsIgnoreCase("cosh")) {
            return LogPriorType.COSH;
        }
        throw new RuntimeException("Unknown LogPriorType: " + str);
    }

    /**
     * Creates an ADAPT prior that applies {@code otherPrior} to the weights minus {@code means}.
     *
     * @param means values subtracted from the weights before delegating; stored by reference, not copied
     * @param otherPrior the prior that is evaluated; it also receives all sigma/epsilon getter and setter calls
     * @return a new ADAPT prior
     */
    public static LogPrior getAdaptationPrior(double[] means, LogPrior otherPrior) {
        LogPrior prior = new LogPrior(LogPriorType.ADAPT);
        prior.means = means;
        prior.otherPrior = otherPrior;
        return prior;
    }

    /** Returns the penalty shape of this prior. */
    public LogPriorType getType() {
        return this.type;
    }

    /** Creates a QUADRATIC prior with sigma = 1.0 and epsilon = 0.1. */
    public LogPrior() {
        this(LogPriorType.QUADRATIC);
    }

    /**
     * Creates a prior with sigma = 1.0 and epsilon = 0.1.
     *
     * @param i the ordinal of the desired {@link LogPriorType}
     * @throws IllegalArgumentException if {@code i} is not a valid ordinal
     */
    public LogPrior(int i) {
        this(i, 1.0d, 0.1d);
    }

    /** Creates a prior of the given type with sigma = 1.0 and epsilon = 0.1. */
    public LogPrior(LogPriorType logPriorType) {
        this(logPriorType, 1.0d, 0.1d);
    }

    /** Maps an ordinal to its {@link LogPriorType}, rejecting out-of-range values. */
    private static LogPriorType intToType(int i) {
        LogPriorType[] types = LogPriorType.values();
        if (i < 0 || i >= types.length) {
            throw new IllegalArgumentException(i + " is not a legal LogPrior.");
        }
        return types[i];
    }

    /**
     * Creates a prior selected by ordinal.
     *
     * @param i the ordinal of the desired {@link LogPriorType}
     * @param sigma the width parameter
     * @param eps the HUBER threshold
     * @throws IllegalArgumentException if {@code i} is not a valid ordinal
     */
    public LogPrior(int i, double sigma, double eps) {
        this(intToType(i), sigma, eps);
    }

    /**
     * Creates a prior of the given type. For ADAPT, sigma and epsilon are ignored, because those
     * values live in the delegate prior; use {@link #getAdaptationPrior} to build a usable ADAPT prior.
     *
     * @param priorType the penalty shape
     * @param sigma the width parameter (sigma, not sigma squared)
     * @param eps the HUBER threshold
     */
    public LogPrior(LogPriorType priorType, double sigma, double eps) {
        this.type = priorType;
        if (priorType != LogPriorType.ADAPT) {
            setSigma(sigma);
            setEpsilon(eps);
        }
    }

    /**
     * Creates a MULTIPLE_QUADRATIC prior. Coordinate i gets {@code sigma_i^2 = 1 / inverseSigmaSquared[i]},
     * so its penalty is {@code inverseSigmaSquared[i] * x_i^2 / 2}.
     *
     * @param inverseSigmaSquared per-coordinate reciprocals of sigma squared; not retained
     */
    public LogPrior(double[] inverseSigmaSquared) {
        this.type = LogPriorType.MULTIPLE_QUADRATIC;
        double[] sigmaSquared = new double[inverseSigmaSquared.length];
        for (int i = 0; i < sigmaSquared.length; i++) {
            sigmaSquared[i] = 1.0d / inverseSigmaSquared[i];
        }
        setSigmaSquaredM(sigmaSquared);
    }

    /** Returns sigma (the square root of the stored sigma squared); ADAPT delegates. */
    public double getSigma() {
        return this.type == LogPriorType.ADAPT ? this.otherPrior.getSigma() : Math.sqrt(this.sigmaSq);
    }

    /** Returns sigma squared; ADAPT delegates. */
    public double getSigmaSquared() {
        return this.type == LogPriorType.ADAPT ? this.otherPrior.getSigmaSquared() : this.sigmaSq;
    }

    /**
     * Returns the per-coordinate sigma squared values. The internal array is returned, not a copy.
     * ADAPT delegates to its underlying prior.
     *
     * @throws RuntimeException if this prior (or, for ADAPT, its delegate) is not MULTIPLE_QUADRATIC
     */
    public double[] getSigmaSquaredM() {
        if (this.type == LogPriorType.ADAPT) {
            return this.otherPrior.getSigmaSquaredM();
        }
        if (this.type != LogPriorType.MULTIPLE_QUADRATIC) {
            throw new RuntimeException("LogPrior.getSigmaSquaredM is undefined for any prior but MULTIPLE_QUADRATIC" + this);
        }
        return this.sigmaSqM;
    }

    /** Returns the HUBER threshold epsilon; ADAPT delegates. */
    public double getEpsilon() {
        return this.type == LogPriorType.ADAPT ? this.otherPrior.getEpsilon() : this.epsilon;
    }

    /** Sets sigma, which also updates sigma squared and sigma to the fourth; ADAPT delegates. */
    public void setSigma(double sigma) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setSigma(sigma);
        } else {
            this.sigmaSq = sigma * sigma;
            this.sigmaQu = this.sigmaSq * this.sigmaSq;
        }
    }

    /** Sets sigma squared, which also updates sigma to the fourth; ADAPT delegates. */
    public void setSigmaSquared(double sigmaSquared) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setSigmaSquared(sigmaSquared);
        } else {
            this.sigmaSq = sigmaSquared;
            this.sigmaQu = sigmaSquared * sigmaSquared;
        }
    }

    /**
     * Sets the per-coordinate sigma squared values. The array is copied. ADAPT delegates to its
     * underlying prior.
     *
     * @throws RuntimeException if this prior (or, for ADAPT, its delegate) is not MULTIPLE_QUADRATIC
     */
    public void setSigmaSquaredM(double[] sigmaSquaredM) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setSigmaSquaredM(sigmaSquaredM);
            // Previously fell through to the type check below and always threw for ADAPT.
            return;
        }
        if (this.type != LogPriorType.MULTIPLE_QUADRATIC) {
            throw new RuntimeException("LogPrior.setSigmaSquaredM is undefined for any prior but MULTIPLE_QUADRATIC" + this);
        }
        this.sigmaSqM = sigmaSquaredM.clone();
        double[] fourthPowers = new double[this.sigmaSqM.length];
        for (int i = 0; i < fourthPowers.length; i++) {
            fourthPowers[i] = this.sigmaSqM[i] * this.sigmaSqM[i];
        }
        this.sigmaQuM = fourthPowers;
    }

    /** Sets the HUBER threshold epsilon; ADAPT delegates. */
    public void setEpsilon(double eps) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setEpsilon(eps);
        } else {
            this.epsilon = eps;
        }
    }

    /**
     * Like {@link #compute(double[], double[])}, but with sigma squared (or every per-coordinate
     * sigma squared) divided by {@code fraction} for the duration of the call. The original
     * values are restored afterwards, even if the computation throws.
     *
     * @param x the weights
     * @param grad array into which the gradient is added
     * @param fraction divisor applied to sigma squared while computing
     * @return the penalty value under the temporarily rescaled sigma
     */
    public double computeStochastic(double[] x, double[] grad, double fraction) {
        if (this.type == LogPriorType.ADAPT) {
            return this.otherPrior.computeStochastic(ArrayMath.pairwiseSubtract(x, this.means), grad, fraction);
        }
        if (this.type == LogPriorType.MULTIPLE_QUADRATIC) {
            // setSigmaSquaredM stores a clone, so the saved array itself is never modified.
            double[] savedSigmaSqM = getSigmaSquaredM();
            double[] scaled = savedSigmaSqM.clone();
            for (int i = 0; i < scaled.length; i++) {
                scaled[i] /= fraction;
            }
            setSigmaSquaredM(scaled);
            try {
                return compute(x, grad);
            } finally {
                setSigmaSquaredM(savedSigmaSqM);
            }
        }
        double savedSigmaSq = getSigmaSquared();
        setSigmaSquared(savedSigmaSq / fraction);
        try {
            return compute(x, grad);
        } finally {
            setSigmaSquared(savedSigmaSq);
        }
    }

    /**
     * Computes the penalty for {@code x} and adds its gradient into {@code grad}.
     *
     * @param x the weights
     * @param grad array (at least {@code x.length} long) into which the gradient is added; it is
     *     not cleared first
     * @return the penalty value
     * @throws RuntimeException if the prior type is not handled
     */
    public double compute(double[] x, double[] grad) {
        double val = 0.0d;
        switch (this.type) {
            case NULL:
                return 0.0d;
            case QUADRATIC:
                for (int i = 0; i < x.length; i++) {
                    val += x[i] * x[i] / 2.0d / this.sigmaSq;
                    grad[i] += x[i] / this.sigmaSq;
                }
                return val;
            case HUBER:
                for (int i = 0; i < x.length; i++) {
                    if (x[i] < -this.epsilon) {
                        val += (-x[i] - this.epsilon / 2.0d) / this.sigmaSq;
                        grad[i] += -1.0d / this.sigmaSq;
                    } else if (x[i] < this.epsilon) {
                        val += x[i] * x[i] / 2.0d / this.epsilon / this.sigmaSq;
                        grad[i] += x[i] / this.epsilon / this.sigmaSq;
                    } else {
                        val += (x[i] - this.epsilon / 2.0d) / this.sigmaSq;
                        grad[i] += 1.0d / this.sigmaSq;
                    }
                }
                return val;
            case QUARTIC:
                for (int i = 0; i < x.length; i++) {
                    val += (x[i] * x[i]) * (x[i] * x[i]) / 2.0d / this.sigmaQu;
                    // d/dx [x^4 / (2 sigma^4)] = 2 x^3 / sigma^4 (previously added x / sigma^4).
                    grad[i] += 2.0d * x[i] * x[i] * x[i] / this.sigmaQu;
                }
                return val;
            case ADAPT:
                return this.otherPrior.compute(ArrayMath.pairwiseSubtract(x, this.means), grad);
            case COSH: {
                double norm = ArrayMath.norm_1(x) / this.sigmaSq;
                double value;
                double slope;
                if (norm > 30.0d) {
                    // For large n, log(cosh(n)) = n - log 2 and tanh(n) = 1 to double precision;
                    // this also avoids Math.cosh overflowing to infinity.
                    value = norm - Math.log(2.0d);
                    slope = 1.0d / this.sigmaSq;
                } else {
                    value = Math.log(Math.cosh(norm));
                    // 2 / (1 + e^{-2n}) - 1 == tanh(n), the derivative of log(cosh(n)).
                    slope = (2.0d * (1.0d / (Math.exp(-2.0d * norm) + 1.0d)) - 1.0d) / this.sigmaSq;
                }
                for (int i = 0; i < x.length; i++) {
                    grad[i] += Math.signum(x[i]) * slope;
                }
                return value;
            }
            case MULTIPLE_QUADRATIC:
                for (int i = 0; i < x.length; i++) {
                    val += x[i] * x[i] / 2.0d / this.sigmaSqM[i];
                    grad[i] += x[i] / this.sigmaSqM[i];
                }
                return val;
            default:
                throw new RuntimeException("LogPrior.valueAt is undefined for prior of type " + this);
        }
    }
}
