package mulan.classifier.neural;

import java.io.Serializable;
import java.util.Arrays;
import weka.core.Utils;
import weka.core.matrix.Matrix;

/* loaded from: input_file:mulan/classifier/neural/ThresholdFunction.class */
/**
 * A linear function of a vector of label confidences.
 *
 * <p>The function value is the dot product of the input vector with the first
 * {@code n} learned parameters plus the last learned parameter, which acts as a
 * constant offset. The parameters are fitted by {@link #build(double[][], double[][])}.
 *
 * <p>Instances are mutable: calling {@code build} again replaces the learned
 * parameters. No synchronization is performed by this class.
 */
public class ThresholdFunction implements Serializable {
    private static final long serialVersionUID = 5347411552628371402L;

    /** Value left in the "minimum" scratch array when no column of a row equals 1.0 in the ideal labels. */
    private static final double EMPTY_GROUP_MIN = Double.MAX_VALUE;

    /** Value left in the "maximum" scratch array when every column of a row equals 1.0 in the ideal labels. */
    private static final double EMPTY_GROUP_MAX = -Double.MAX_VALUE;

    /** Offset applied to the single available bound when one of the two groups is empty. */
    private static final double MARGIN = 0.1d;

    private static final String MSG_NULL_INPUT = "Non of the input parameters can be null.";
    private static final String MSG_DIMENSION_MISMATCH = "Matrix dimensions of input parameters does not agree.";

    /** Learned coefficients; the last element is the constant offset. */
    private double[] parameters;

    /**
     * Creates a threshold function and immediately fits it via
     * {@link #build(double[][], double[][])}.
     *
     * @param idealLabels    per-example indicator rows; see {@code build}
     * @param modelOutputs   per-example output rows; see {@code build}
     * @throws IllegalArgumentException if the inputs are null, empty or their dimensions disagree
     */
    public ThresholdFunction(double[][] idealLabels, double[][] modelOutputs) {
        build(idealLabels, modelOutputs);
    }

    /**
     * Evaluates the learned linear function for the given input vector.
     *
     * @param labelConfidences input vector; its length must be one less than the
     *                         number of learned parameters
     * @return the sum of {@code labelConfidences[i] * parameters[i]} over all inputs,
     *         plus the last learned parameter
     * @throws IllegalArgumentException if {@code labelConfidences} is null or has the wrong length
     */
    public double computeThreshold(double[] labelConfidences) {
        if (labelConfidences == null) {
            throw new IllegalArgumentException("The array of label confidences can not be null.");
        }
        int expectedLength = this.parameters.length - 1;
        if (labelConfidences.length != expectedLength) {
            throw new IllegalArgumentException("The array of label confidences has wrong dimension.The function expect parameters of length : " + expectedLength);
        }
        double threshold = 0.0d;
        for (int i = 0; i < expectedLength; i++) {
            threshold += labelConfidences[i] * this.parameters[i];
        }
        return threshold + this.parameters[expectedLength];
    }

    /**
     * Fits the parameters of this function from the given data, replacing any
     * previously learned parameters. The stored parameters are only overwritten
     * once fitting has completed.
     *
     * <p>For every row a target value is chosen from that row of {@code modelOutputs}:
     * <ul>
     *   <li>{@code lo} is the smallest output in the columns whose ideal label equals exactly 1.0;</li>
     *   <li>{@code hi} is the largest output in all remaining columns;</li>
     *   <li>if {@code lo == hi} the target is that value; if no column equals 1.0 the
     *       target is {@code hi + 0.1}; if every column equals 1.0 the target is
     *       {@code lo - 0.1}; otherwise it is the midpoint {@code (lo + hi) / 2}.</li>
     * </ul>
     * The parameters are then obtained by solving the linear system whose matrix is
     * {@code modelOutputs} extended by a trailing column of ones and whose right-hand
     * side is the vector of targets, using Weka's {@code Matrix.solve} (documented by
     * Weka as a least-squares solution when the system is not square).
     *
     * <p>Note: an output exactly equal to {@code Double.MAX_VALUE} or
     * {@code -Double.MAX_VALUE} cannot be told apart from the "empty group" marker.
     *
     * @param idealLabels    indicator matrix, one row per example and one column per
     *                       label (presumably the desired label assignment); only
     *                       equality with 1.0 is tested
     * @param modelOutputs   matrix of the same shape holding the values to separate
     *                       (presumably the outputs of the model being calibrated)
     * @throws IllegalArgumentException if an argument or any row is null, if there are
     *         no examples or no labels, or if any row length differs from the others
     */
    public void build(double[][] idealLabels, double[][] modelOutputs) {
        checkDimensions(idealLabels, modelOutputs);
        int numExamples = idealLabels.length;
        int numLabels = idealLabels[0].length;

        double[] targets = new double[numExamples];
        // Scratch buffers reused for every example; reset at the top of each iteration.
        double[] relevantOutputs = new double[numLabels];
        double[] irrelevantOutputs = new double[numLabels];
        for (int example = 0; example < numExamples; example++) {
            Arrays.fill(relevantOutputs, EMPTY_GROUP_MIN);
            Arrays.fill(irrelevantOutputs, EMPTY_GROUP_MAX);
            for (int label = 0; label < numLabels; label++) {
                if (idealLabels[example][label] == 1.0d) {
                    relevantOutputs[label] = modelOutputs[example][label];
                } else {
                    irrelevantOutputs[label] = modelOutputs[example][label];
                }
            }
            double minRelevant = relevantOutputs[Utils.minIndex(relevantOutputs)];
            double maxIrrelevant = irrelevantOutputs[Utils.maxIndex(irrelevantOutputs)];
            targets[example] = targetThreshold(minRelevant, maxIrrelevant);
        }

        // Design matrix [modelOutputs | 1]: the extra column of ones yields the constant offset.
        Matrix design = new Matrix(numExamples, numLabels + 1, 1.0d);
        design.setMatrix(0, numExamples - 1, 0, numLabels - 1, new Matrix(modelOutputs));
        Matrix weights = design.solve(new Matrix(targets, targets.length));
        double[][] weightRows = weights.transpose().getArray();
        this.parameters = Arrays.copyOf(weightRows[0], weightRows[0].length);
    }

    /**
     * Returns a copy of the learned parameters; the last element is the constant offset.
     *
     * @return a fresh array, so callers cannot modify the internal state
     */
    protected double[] getFunctionParameters() {
        return Arrays.copyOf(this.parameters, this.parameters.length);
    }

    /** Rejects null, empty or non-rectangular inputs before any fitting work starts. */
    private static void checkDimensions(double[][] idealLabels, double[][] modelOutputs) {
        if (idealLabels == null || modelOutputs == null) {
            throw new IllegalArgumentException(MSG_NULL_INPUT);
        }
        int numExamples = idealLabels.length;
        if (numExamples == 0) {
            throw new IllegalArgumentException("Input parameters must contain at least one example and one label.");
        }
        if (modelOutputs.length != numExamples) {
            throw new IllegalArgumentException(MSG_DIMENSION_MISMATCH);
        }
        for (int example = 0; example < numExamples; example++) {
            if (idealLabels[example] == null || modelOutputs[example] == null) {
                throw new IllegalArgumentException(MSG_NULL_INPUT);
            }
        }
        int numLabels = idealLabels[0].length;
        if (numLabels == 0) {
            throw new IllegalArgumentException("Input parameters must contain at least one example and one label.");
        }
        for (int example = 0; example < numExamples; example++) {
            if (idealLabels[example].length != numLabels || modelOutputs[example].length != numLabels) {
                throw new IllegalArgumentException(MSG_DIMENSION_MISMATCH);
            }
        }
    }

    /** Picks the target value for one example from the two group bounds (see {@code build}). */
    private static double targetThreshold(double minRelevant, double maxIrrelevant) {
        if (minRelevant == maxIrrelevant) {
            return minRelevant;
        }
        if (minRelevant == EMPTY_GROUP_MIN) {
            // No column was marked 1.0: place the target just above every output.
            return maxIrrelevant + MARGIN;
        }
        if (maxIrrelevant == EMPTY_GROUP_MAX) {
            // Every column was marked 1.0: place the target just below every output.
            return minRelevant - MARGIN;
        }
        return (minRelevant + maxIrrelevant) / 2.0d;
    }
}
