package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.SeqClassifierFlags;

/* loaded from: input_file:edu/stanford/nlp/ie/crf/NonLinearCliquePotentialFunction.class */
/**
 * Clique potential function whose size-one (node) cliques are scored by a small non-linear
 * network, while larger (edge) cliques keep a plain linear potential.
 *
 * <p>For a node clique, the active features are summed into every first-layer unit and optionally
 * squashed by a sigmoid or tanh (see {@link #hiddenLayerOutput}). If {@code flags.useOutputLayer}
 * is set, those activations are then combined with an output-layer weight vector; otherwise the
 * activation of the unit at the label's index is the score.
 */
public class NonLinearCliquePotentialFunction implements CliquePotentialFunction {
    /** Linear weights used for cliques of size greater than one, indexed {@code [feature][label]}. */
    double[][] linearWeights;
    /** First-layer weights, indexed {@code [unit][feature]}. */
    double[][] inputLayerWeights;
    /**
     * Output-layer weights: row 0 is shared by all labels when {@code flags.tieOutputLayer} is
     * set, otherwise row {@code labelIndex} is used for that label.
     */
    double[][] outputLayerWeights;
    /** Flags selecting the activation function and the output-layer variant. */
    SeqClassifierFlags flags;

    /** Logistic sigmoid {@code 1 / (1 + e^-x)}. */
    private static double sigmoid(double x) {
        return 1.0d / (1.0d + Math.exp(-x));
    }

    /**
     * Creates a potential function over the given weights. The arrays are stored by reference,
     * not copied.
     *
     * @param linearWeights weights for cliques of size greater than one
     * @param inputLayerWeights first-layer weights for node cliques
     * @param outputLayerWeights output-layer weights for node cliques
     * @param flags classifier flags selecting the network variant
     */
    public NonLinearCliquePotentialFunction(double[][] linearWeights, double[][] inputLayerWeights, double[][] outputLayerWeights, SeqClassifierFlags flags) {
        this.linearWeights = linearWeights;
        this.inputLayerWeights = inputLayerWeights;
        this.outputLayerWeights = outputLayerWeights;
        this.flags = flags;
    }

    /**
     * Computes the first-layer activations for one node clique.
     *
     * <p>Each unit's pre-activation is the sum, over the active features, of that unit's weight for
     * the feature, multiplied by the feature's value when {@code featureVal} is non-null. If
     * {@code flags.useHiddenLayer} is false the pre-activations are returned as-is. Otherwise each
     * one is passed through a sigmoid (when {@code flags.useSigmoid}) or {@code tanh}.
     *
     * @param inputLayerWeights first-layer weights, indexed {@code [unit][feature]}
     * @param cliqueFeatures indices of the features active on the clique
     * @param flags classifier flags selecting the activation
     * @param featureVal per-feature values parallel to {@code cliqueFeatures}, or {@code null} to
     *     treat every active feature as having value 1
     * @return a newly allocated array holding one activation per first-layer unit
     */
    public static double[] hiddenLayerOutput(double[][] inputLayerWeights, int[] cliqueFeatures, SeqClassifierFlags flags, double[] featureVal) {
        int numUnits = inputLayerWeights.length;
        double[] activations = new double[numUnits];
        for (int unit = 0; unit < numUnits; unit++) {
            double[] unitWeights = inputLayerWeights[unit];
            double preActivation = 0.0d;
            for (int k = 0; k < cliqueFeatures.length; k++) {
                double term = unitWeights[cliqueFeatures[k]];
                if (featureVal != null) {
                    term *= featureVal[k];
                }
                preActivation += term;
            }
            activations[unit] = activate(preActivation, flags);
        }
        return activations;
    }

    /** Applies the activation selected by {@code flags} (identity, sigmoid or tanh) to one value. */
    private static double activate(double preActivation, SeqClassifierFlags flags) {
        if (!flags.useHiddenLayer) {
            return preActivation;
        }
        return flags.useSigmoid ? sigmoid(preActivation) : Math.tanh(preActivation);
    }

    /**
     * Scores one clique for one label index.
     *
     * <p>When {@code cliqueSize > 1} the score is the sum of
     * {@code linearWeights[feature][labelIndex]} over the clique's features; {@code featureVal} is
     * not used in that case. Otherwise the clique is scored through {@link #hiddenLayerOutput}.
     * Without an output layer, the score is the activation of unit {@code labelIndex}; with one,
     * it is a weighted combination of the activations (see {@link #outputLayerScore}).
     *
     * @param cliqueSize size of the clique; values above 1 select the linear potential
     * @param labelIndex index of the label assignment being scored (for edge cliques presumably a
     *     combined label index — confirm against callers)
     * @param cliqueFeatures indices of the features active on the clique
     * @param featureVal per-feature values parallel to {@code cliqueFeatures}, or {@code null}
     * @return the potential for this clique and label index
     */
    @Override // edu.stanford.nlp.ie.crf.CliquePotentialFunction
    public double computeCliquePotential(int cliqueSize, int labelIndex, int[] cliqueFeatures, double[] featureVal) {
        if (cliqueSize > 1) {
            double sum = 0.0d;
            for (int feature : cliqueFeatures) {
                sum += this.linearWeights[feature][labelIndex];
            }
            return sum;
        }
        double[] hidden = hiddenLayerOutput(this.inputLayerWeights, cliqueFeatures, this.flags, featureVal);
        if (!this.flags.useOutputLayer) {
            return hidden[labelIndex];
        }
        return outputLayerScore(hidden, labelIndex);
    }

    /**
     * Combines first-layer activations through the output layer into a score for one label.
     *
     * <p>The weight row is {@code outputLayerWeights[0]} when tied, else
     * {@code outputLayerWeights[labelIndex]}, optionally replaced by its softmax. In the dense
     * layout every unit {@code i} contributes {@code w[i] * hidden[i]}. In the sparse or tied
     * layout the units form interleaved groups of size
     * {@code inputLayerWeights.length / outputLayerWeights[0].length}. Only units with
     * {@code i % groupSize == labelIndex} contribute, weighted by {@code w[i / groupSize]}.
     */
    private double outputLayerScore(double[] hidden, int labelIndex) {
        int numUnits = this.inputLayerWeights.length;
        double[] outputWeights = this.flags.tieOutputLayer ? this.outputLayerWeights[0] : this.outputLayerWeights[labelIndex];
        if (this.flags.softmaxOutputLayer) {
            outputWeights = ArrayMath.softmax(outputWeights);
        }
        double score = 0.0d;
        if (this.flags.sparseOutputLayer || this.flags.tieOutputLayer) {
            // Group size is only meaningful in this layout, so compute it here rather than
            // unconditionally (it dereferences outputLayerWeights[0] and divides by its length).
            int groupSize = numUnits / this.outputLayerWeights[0].length;
            for (int i = 0; i < numUnits; i++) {
                if (i % groupSize == labelIndex) {
                    score += outputWeights[i / groupSize] * hidden[i];
                }
            }
        } else {
            for (int i = 0; i < numUnits; i++) {
                score += outputWeights[i] * hidden[i];
            }
        }
        return score;
    }
}
