package edu.stanford.nlp.classify;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.semgraph.semgrex.ssurgeon.AddNode;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.IntTriple;
import edu.stanford.nlp.util.IntTuple;
import edu.stanford.nlp.util.IntUni;
import java.util.Arrays;

/* loaded from: input_file:edu/stanford/nlp/classify/LogConditionalEqConstraintFunction.class */
/**
 * Conditional log-likelihood objective for a classifier with naive-Bayes structure.
 *
 * <p>The parameter vector holds unnormalized log-weights for two kinds of multinomial
 * distributions:
 * <ul>
 *   <li>a class prior {@code p(c)}, one entry per class, keyed by {@code IntUni(c)};</li>
 *   <li>for every class {@code c} and feature {@code f}, a distribution over that feature's
 *       values {@code p(v | c, f)}, keyed by {@code IntTriple(c, f, v)}.</li>
 * </ul>
 * Each distribution is log-normalized (see {@link #normalize}) before use. A datum {@code x}
 * is scored for class {@code c} as {@code log p(c) + sum_f log p(x_f | c, f)}. The objective
 * is {@code -sum_i log P(y_i | x_i)} over the training data, plus an optional regularization
 * penalty applied to the raw (unnormalized) parameters.
 */
public class LogConditionalEqConstraintFunction extends AbstractCachingDiffFunction {

    /** No regularization penalty. */
    public static final int NO_PRIOR = 0;
    /** Penalty {@code w^2 / (2 sigma^2)} per parameter. */
    public static final int QUADRATIC_PRIOR = 1;
    /** Huber penalty: quadratic for {@code |w| < epsilon}, linear beyond it. */
    public static final int HUBER_PRIOR = 2;
    /** Penalty {@code w^4 / (2 sigma^4)} per parameter. */
    public static final int QUARTIC_PRIOR = 3;

    /** Number of features; feature indices range over {@code [0, numFeatures)}. */
    protected int numFeatures = 0;
    /** Number of classes; class indices range over {@code [0, numClasses)}. */
    protected int numClasses = 0;
    /** {@code data[i][f]} is the value index taken by feature {@code f} in datum {@code i}. */
    protected int[][] data = null;
    /** {@code labels[i]} is the gold class index of datum {@code i}. */
    protected int[] labels = null;
    /**
     * {@code numValues[f]} is the size of feature {@code f}'s value range, as returned by
     * {@code NaiveBayesClassifierFactory.numberValues}; values are indexed {@code [0, numValues[f])}.
     */
    protected int[] numValues = null;

    /** One of the {@code *_PRIOR} constants. */
    private int prior;
    /** Scale of the regularization penalty. */
    private double sigma = 1.0;
    /** Quadratic/linear switch-over point of the Huber penalty. */
    private double epsilon;
    /** Maps parameter keys ({@code IntUni} / {@code IntTriple}) to positions in the parameter vector. */
    private Index<IntTuple> featureIndex;

    @Override
    public int domainDimension() {
        return featureIndex.size();
    }

    /** Returns the class component of the key stored at parameter position {@code index}. */
    int classOf(int index) {
        return featureIndex.get(index).get(0);
    }

    /**
     * Returns the feature component of the key stored at parameter position {@code index},
     * or {@code -1} if that parameter is a class-prior weight (a one-element key).
     */
    int featureOf(int index) {
        IntTuple key = featureIndex.get(index);
        return key.length() == 1 ? -1 : key.get(1);
    }

    /** Returns the parameter position of the prior weight for class {@code clazz}. */
    protected int indexOf(int clazz) {
        return featureIndex.indexOf(new IntUni(clazz));
    }

    /**
     * Returns the parameter position of the weight for value {@code value} of feature
     * {@code feature} under class {@code clazz}.
     *
     * <p>Note the argument order is (feature, class, value), while the stored key is
     * {@code IntTriple(class, feature, value)}.
     */
    protected int indexOf(int feature, int clazz, int value) {
        return featureIndex.indexOf(new IntTriple(clazz, feature, value));
    }

    /**
     * Builds the parameter index. For each class it adds the class-prior key, then one key
     * per (feature, value) pair.
     */
    protected Index<IntTuple> createIndex() {
        HashIndex<IntTuple> index = new HashIndex<IntTuple>();
        for (int c = 0; c < numClasses; c++) {
            index.add(new IntUni(c));
            for (int f = 0; f < numFeatures; f++) {
                for (int v = 0; v < numValues[f]; v++) {
                    index.add(new IntTriple(c, f, v));
                }
            }
        }
        return index;
    }

    /**
     * Returns the normalized conditional log-probabilities as {@code result[c][f][v] = log p(v | c, f)}.
     *
     * @param x raw parameter vector of length {@link #domainDimension()}
     */
    public double[][][] to3D(double[] x) {
        double[] logProbs = normalize(x);
        double[][][] result = new double[numClasses][numFeatures][];
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                double[] row = new double[numValues[f]];
                for (int v = 0; v < row.length; v++) {
                    row[v] = logProbs[indexOf(f, c, v)];
                }
                result[c][f] = row;
            }
        }
        return result;
    }

    /**
     * Returns the normalized class log-priors, {@code result[c] = log p(c)}.
     *
     * @param x raw parameter vector of length {@link #domainDimension()}
     */
    public double[] priors(double[] x) {
        double[] logProbs = normalize(x);
        double[] result = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            result[c] = logProbs[indexOf(c)];
        }
        return result;
    }

    /**
     * Returns a copy of {@code x} in which every distribution has been log-normalized.
     *
     * <p>Both the class-prior weights and, for each (class, feature), the value weights have
     * their log-sum-exp subtracted, so the exponentiated entries of each group sum to one.
     * {@code x} itself is not modified.
     */
    private double[] normalize(double[] x) {
        double[] normalized = x.clone();

        double[] priorWeights = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            priorWeights[c] = x[indexOf(c)];
        }
        double priorLogZ = ArrayMath.logSum(priorWeights);
        for (int c = 0; c < numClasses; c++) {
            normalized[indexOf(c)] -= priorLogZ;
        }

        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                double[] valueWeights = new double[numValues[f]];
                for (int v = 0; v < valueWeights.length; v++) {
                    valueWeights[v] = x[indexOf(f, c, v)];
                }
                double logZ = ArrayMath.logSum(valueWeights);
                for (int v = 0; v < valueWeights.length; v++) {
                    normalized[indexOf(f, c, v)] -= logZ;
                }
            }
        }
        return normalized;
    }

    /**
     * Computes the negative conditional log-likelihood of the training data, plus the selected
     * regularization penalty, and its gradient with respect to the raw parameters.
     *
     * <p>Per datum, with {@code coeff_c = P(c|x) - [c == y]}, the gradient contributions are:
     * <ul>
     *   <li>{@code coeff_c} for the prior weight of class {@code c};</li>
     *   <li>{@code coeff_c * ([v == x_f] - p(v|c,f))} for the weight of value {@code v} of
     *       feature {@code f} under class {@code c}.</li>
     * </ul>
     * The class-prior normalizer's own contribution cancels, because the posteriors sum to one.
     */
    @Override
    protected void calculate(double[] x) {
        double[] logProbs = normalize(x);
        double[] probs = new double[logProbs.length];
        for (int i = 0; i < logProbs.length; i++) {
            probs[i] = Math.exp(logProbs[i]);
        }

        value = 0.0;
        Arrays.fill(derivative, 0.0);

        double[] scores = new double[numClasses];
        for (int d = 0; d < data.length; d++) {
            int[] datum = data[d];
            int label = labels[d];

            // Unnormalized log joint score of each class for this datum.
            for (int c = 0; c < numClasses; c++) {
                double score = logProbs[indexOf(c)];
                for (int f = 0; f < datum.length; f++) {
                    score += logProbs[indexOf(f, c, datum[f])];
                }
                scores[c] = score;
            }
            double logZ = ArrayMath.logSum(scores);
            value -= scores[label] - logZ;

            for (int c = 0; c < numClasses; c++) {
                // d(-log P(y|x)) / d(score_c)
                double coeff = Math.exp(scores[c] - logZ) - (c == label ? 1.0 : 0.0);
                derivative[indexOf(c)] += coeff;
                for (int f = 0; f < datum.length; f++) {
                    // Chain rule through the per-(class, feature) log-normalization.
                    for (int v = 0; v < numValues[f]; v++) {
                        int idx = indexOf(f, c, v);
                        derivative[idx] -= coeff * probs[idx];
                    }
                    derivative[indexOf(f, c, datum[f])] += coeff;
                }
            }
        }

        addRegularization(x);
    }

    /** Adds the configured prior's penalty on the raw parameters {@code x} to value and gradient. */
    private void addRegularization(double[] x) {
        switch (prior) {
            case QUADRATIC_PRIOR: {
                double sigmaSq = sigma * sigma;
                for (int i = 0; i < x.length; i++) {
                    double w = x[i];
                    value += w * w / 2.0 / sigmaSq;
                    derivative[i] += w / sigmaSq;
                }
                break;
            }
            case HUBER_PRIOR: {
                double sigmaSq = sigma * sigma;
                for (int i = 0; i < x.length; i++) {
                    double w = x[i];
                    double abs = Math.abs(w);
                    if (abs < epsilon) {
                        value += w * w / 2.0 / epsilon / sigmaSq;
                        derivative[i] += w / epsilon / sigmaSq;
                    } else {
                        value += (abs - epsilon / 2.0) / sigmaSq;
                        derivative[i] += (w < 0.0 ? -1.0 : 1.0) / sigmaSq;
                    }
                }
                break;
            }
            case QUARTIC_PRIOR: {
                double sigmaQu = sigma * sigma * sigma * sigma;
                for (int i = 0; i < x.length; i++) {
                    double w = x[i];
                    value += w * w * w * w / 2.0 / sigmaQu;
                    // d/dw [w^4 / (2 sigma^4)] = 2 w^3 / sigma^4 (previously w / sigma^4, which
                    // did not match the value term).
                    derivative[i] += 2.0 * w * w * w / sigmaQu;
                }
                break;
            }
            default:
                // NO_PRIOR: no penalty.
                break;
        }
    }

    /** Uses a quadratic prior with {@code sigma = 1.0}. */
    public LogConditionalEqConstraintFunction(int numFeatures, int numClasses, int[][] data, int[] labels) {
        this(numFeatures, numClasses, data, labels, 1.0);
    }

    /** Uses a quadratic prior with the given {@code sigma}. */
    public LogConditionalEqConstraintFunction(int numFeatures, int numClasses, int[][] data, int[] labels,
                                              double sigma) {
        this(numFeatures, numClasses, data, labels, QUADRATIC_PRIOR, sigma, 0.0);
    }

    /**
     * Creates the objective and builds its parameter index.
     *
     * @param numFeatures number of features per datum
     * @param numClasses  number of classes
     * @param data        {@code data[i][f]} is the value index of feature {@code f} in datum {@code i}
     * @param labels      {@code labels[i]} is the gold class of datum {@code i}
     * @param prior       one of {@link #NO_PRIOR}, {@link #QUADRATIC_PRIOR}, {@link #HUBER_PRIOR},
     *                    {@link #QUARTIC_PRIOR}
     * @param sigma       regularization scale
     * @param epsilon     Huber switch-over point (used only by {@link #HUBER_PRIOR})
     * @throws IllegalArgumentException if {@code prior} is not one of the known constants
     */
    public LogConditionalEqConstraintFunction(int numFeatures, int numClasses, int[][] data, int[] labels,
                                              int prior, double sigma, double epsilon) {
        if (prior < NO_PRIOR || prior > QUARTIC_PRIOR) {
            throw new IllegalArgumentException("Invalid prior: " + prior);
        }
        this.numFeatures = numFeatures;
        this.numClasses = numClasses;
        this.data = data;
        this.labels = labels;
        this.prior = prior;
        this.epsilon = epsilon;
        this.sigma = sigma;
        this.numValues = NaiveBayesClassifierFactory.numberValues(data, numFeatures);
        // NOTE(review): diagnostic output to stdout kept unchanged from the original.
        for (int f = 0; f < this.numValues.length; f++) {
            System.out.println("numValues " + f + AddNode.ATOM_DELIMITER + this.numValues[f]);
        }
        this.featureIndex = createIndex();
    }

    /** Returns a starting point with each coordinate drawn uniformly from {@code [-0.5, 0.5)}. */
    @Override
    public double[] initial() {
        double[] start = new double[domainDimension()];
        for (int i = 0; i < start.length; i++) {
            start[i] = Math.random() - 0.5;
        }
        return start;
    }
}
