package hex;

import java.util.Arrays;
import water.Iced;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/ConfusionMatrix.class */
public class ConfusionMatrix extends Iced {
    /** Table rendering of this matrix; (re)built by {@link #toASCII()}. */
    public TwoDimTable _cmTable;
    /** Count matrix indexed as {@code _arr[actual][predicted]}; may be null. */
    public long[][] _arr;
    /**
     * Per-class error rates as of construction. Not updated automatically by {@link #add(int, int)}
     * or {@link #add(ConfusionMatrix)}; call {@link #reComputeErrors()} after mutating counts.
     */
    public final double[] _classErr;
    /** Overall error rate as of construction; refresh with {@link #reComputeErrors()}. */
    public double _predErr;
    /** Class labels, one per row/column of {@link #_arr}; may be null. */
    public String[] _domain;

    /**
     * Map/reduce task that tallies (actual, predicted) class pairs into a square count matrix.
     * Each {@code map} call builds a fresh partial matrix for its chunk pair; {@code reduce}
     * sums partial matrices element-wise.
     */
    private static class CMBuilder extends MRTask<CMBuilder> {
        final int _len;
        long[][] _arr;

        CMBuilder(int len) {
            this._len = len;
        }

        @Override // water.MRTask
        public void map(Chunk actuals, Chunk predicted) {
            this._arr = new long[this._len][this._len];
            for (int row = 0; row < actuals._len; row++) {
                // Rows with a missing actual label are skipped.
                // NOTE(review): a missing prediction is not checked here; confirm how at80
                // behaves on NA for the adapted prediction vec.
                if (!actuals.isNA0(row)) {
                    this._arr[(int) actuals.at80(row)][(int) predicted.at80(row)]++;
                }
            }
        }

        @Override // water.MRTask
        public void reduce(CMBuilder other) {
            ArrayUtils.add(this._arr, other._arr);
        }
    }

    /**
     * Wraps an existing count matrix and computes its error rates.
     *
     * @param jArr   counts indexed as {@code [actual][predicted]}; may be null
     * @param strArr class labels for the rows/columns; may be null
     */
    public ConfusionMatrix(long[][] jArr, String[] strArr) {
        this._arr = jArr;
        this._domain = strArr;
        // These must run after _arr is assigned: both helpers read it. (As field
        // initializers they ran before the constructor body and always hit a null _arr.)
        this._classErr = jArr == null ? new double[0] : classErr();
        this._predErr = jArr == null ? Double.NaN : err();
    }

    /**
     * Builds a confusion matrix from an actuals vec and a predictions vec.
     *
     * @param vec  actual labels; must be categorical ({@code isEnum()})
     * @param vec2 predicted labels; must be categorical. It is passed through
     *             {@code adaptTo(vec.domain())} before counting (presumably to align its levels
     *             with the actuals' domain; confirm against {@code Vec.adaptTo}).
     * @return a matrix whose domain is {@code vec.domain()}
     * @throws IllegalArgumentException if either vec is not categorical
     */
    public static ConfusionMatrix buildCM(Vec vec, Vec vec2) {
        if (!vec.isEnum()) {
            throw new IllegalArgumentException("actuals must be enum.");
        }
        if (!vec2.isEnum()) {
            throw new IllegalArgumentException("predictions must be enum.");
        }
        Scope.enter();
        try {
            String[] domain = vec.domain();
            Vec adapted = vec2.adaptTo(domain);
            long[][] counts = new CMBuilder(domain.length).doAll(vec, adapted)._arr;
            return new ConfusionMatrix(counts, domain);
        } finally {
            // Runs on both the normal and the exceptional path.
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Increments the count for actual class {@code i} predicted as class {@code i2}.
     * Does not refresh {@link #_classErr} / {@link #_predErr}.
     */
    public void add(int i, int i2) {
        this._arr[i][i2]++;
    }

    /** Returns a new array holding {@link #classErr(int)} for every class. */
    public double[] classErr() {
        double[] errors = new double[this._arr.length];
        for (int c = 0; c < errors.length; c++) {
            errors[c] = classErr(c);
        }
        return errors;
    }

    /** Returns the number of classes (rows) in the matrix. */
    public final int size() {
        return this._arr.length;
    }

    /** Recomputes {@link #_classErr} and {@link #_predErr} from the current counts, in place. */
    public void reComputeErrors() {
        for (int c = 0; c < this._arr.length; c++) {
            this._classErr[c] = classErr(c);
        }
        this._predErr = err();
    }

    /** Returns the number of rows of actual class {@code i} that were predicted as another class. */
    public final long classErrCount(int i) {
        return ArrayUtils.sum(this._arr[i]) - this._arr[i][i];
    }

    /**
     * Returns the fraction of actual class {@code i} that was misclassified,
     * or {@code 0.0} when class {@code i} has no rows.
     */
    public final double classErr(int i) {
        long rowTotal = ArrayUtils.sum(this._arr[i]);
        if (rowTotal == 0) {
            return 0.0d;
        }
        // Floating-point division: long / long would truncate the rate to 0 or 1.
        return (double) (rowTotal - this._arr[i][i]) / rowTotal;
    }

    /** Returns the sum of all counts in the matrix. */
    public long totalRows() {
        long total = 0;
        for (long[] row : this._arr) {
            total += ArrayUtils.sum(row);
        }
        return total;
    }

    /**
     * Adds {@code confusionMatrix}'s counts into this matrix element-wise.
     * Does not refresh {@link #_classErr} / {@link #_predErr}.
     */
    public void add(ConfusionMatrix confusionMatrix) {
        ArrayUtils.add(this._arr, confusionMatrix._arr);
    }

    /**
     * Returns the overall misclassification rate (off-diagonal counts / all counts).
     * Returns {@code NaN} for an all-zero matrix instead of throwing.
     */
    public double err() {
        return (double) errCount() / totalRows();
    }

    /** Returns the total number of off-diagonal (misclassified) counts. */
    public long errCount() {
        long wrong = totalRows();
        for (int c = 0; c < this._arr.length; c++) {
            wrong -= this._arr[c][c];
        }
        return wrong;
    }

    /** Returns {@code 1 - err()}. */
    public double accuracy() {
        return 1.0d - err();
    }

    /**
     * Returns {@code arr[0][0] / (arr[0][0] + arr[0][1])} (true negative rate, with class 0 as
     * negative).
     *
     * @throws UnsupportedOperationException unless the matrix has exactly 2 classes
     */
    public double specificity() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("specificity is only implemented for 2 class problems.");
        }
        double tn = this._arr[0][0];
        return tn / (tn + this._arr[0][1]);
    }

    /**
     * Returns {@code arr[1][1] / (arr[1][1] + arr[1][0])} (true positive rate, with class 1 as
     * positive).
     *
     * @throws UnsupportedOperationException unless the matrix has exactly 2 classes
     */
    public double recall() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("recall is only implemented for 2 class problems.");
        }
        double tp = this._arr[1][1];
        return tp / (tp + this._arr[1][0]);
    }

    /**
     * Returns {@code arr[1][1] / (arr[1][1] + arr[0][1])}, with class 1 as positive.
     *
     * @throws UnsupportedOperationException unless the matrix has exactly 2 classes
     */
    public double precision() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("precision is only implemented for 2 class problems.");
        }
        double tp = this._arr[1][1];
        return tp / (tp + this._arr[0][1]);
    }

    /**
     * Returns the Matthews correlation coefficient, with class 1 as positive.
     * The result is {@code NaN} when any marginal sum is zero.
     *
     * @throws UnsupportedOperationException unless the matrix has exactly 2 classes
     */
    public double mcc() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("mcc is only implemented for 2 class problems.");
        }
        double tn = this._arr[0][0];
        double fp = this._arr[0][1];
        double tp = this._arr[1][1];
        double fn = this._arr[1][0];
        return ((tp * tn) - (fp * fn)) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
    }

    /**
     * Returns the largest {@link #classErr(int)} over all classes.
     *
     * @throws UnsupportedOperationException if the matrix has no classes
     */
    public double max_per_class_error() {
        int n = nclasses();
        if (n == 0) {
            throw new UnsupportedOperationException("max per class error is only defined for classification problems");
        }
        double worst = classErr(0);
        for (int c = 1; c < n; c++) {
            worst = Math.max(worst, classErr(c));
        }
        return worst;
    }

    /** Returns the number of classes, or 0 when {@link #_arr} is null. */
    public final int nclasses() {
        if (this._arr == null) {
            return 0;
        }
        return this._arr.length;
    }

    /** Returns true when the matrix has exactly 2 classes. */
    public final boolean isBinary() {
        return nclasses() == 2;
    }

    /** Returns the F1 score (harmonic mean of precision and recall). */
    public double F1() {
        return fScore(1.0d);
    }

    /** Returns the F2 score (recall weighted higher than precision). */
    public double F2() {
        return fScore(2.0d);
    }

    /** Returns the F0.5 score (precision weighted higher than recall). */
    public double F0point5() {
        return fScore(0.5d);
    }

    /**
     * Computes {@code (1 + beta^2) * P * R / (beta^2 * P + R)} from {@link #precision()} and
     * {@link #recall()}; for beta in {1, 2, 0.5} this is arithmetically identical to the
     * formulas previously inlined in F1/F2/F0point5.
     */
    private double fScore(double beta) {
        double precision = precision();
        double recall = recall();
        double betaSq = beta * beta;
        return ((1.0d + betaSq) * (precision * recall)) / ((betaSq * precision) + recall);
    }

    /** Returns one line per actual class, each formatted with {@link Arrays#toString(long[])}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (long[] row : this._arr) {
            sb.append(Arrays.toString(row)).append('\n');
        }
        return sb.toString();
    }

    /**
     * Builds the row/column labels for {@link #toTable()}.
     *
     * @param jArr   per-row or per-column count totals
     * @param strArr domain labels
     * @return labels of length {@code jArr.length}; an extra trailing slot is labelled "NA" when
     *         the domain is one shorter than the totals and that slot has a positive count
     */
    private static String[] createConfusionMatrixHeader(long[] jArr, String[] strArr) {
        String[] labels = new String[jArr.length];
        // Bounded by both lengths so a domain longer than the matrix cannot index past jArr.
        int n = Math.min(strArr.length, jArr.length);
        for (int i = 0; i < n; i++) {
            // NOTE(review): totals are sums of counts and presumably non-negative, so the first
            // clause is effectively always true and every domain label is copied. Kept as-is
            // to preserve behavior.
            if (jArr[i] >= 0 || (strArr[i] != null && strArr[i].length() > 0 && !Integer.toString(i).equals(strArr[i]))) {
                labels[i] = strArr[i];
            }
        }
        if (strArr.length == jArr.length - 1 && jArr[jArr.length - 1] > 0) {
            labels[jArr.length - 1] = "NA";
        }
        return labels;
    }

    /**
     * Renders the matrix as text via {@link #toTable()}, storing the result in
     * {@link #_cmTable}. The table is rebuilt on every call so it reflects the current counts.
     *
     * @return the rendered table, or {@code ""} when {@link #_domain} is null
     */
    public String toASCII() {
        if (this._domain == null) {
            return "";
        }
        this._cmTable = toTable();
        return this._cmTable.toString();
    }

    /**
     * Builds a table with one row per actual class plus a "Totals" row. It has one column per
     * predicted class, then an "Error" rate column and an "errors / total" string column.
     * Rows and columns whose header label is null are left unset.
     *
     * <p>Rows or the totals row with zero counts get a {@code NaN} error rate.
     */
    public TwoDimTable toTable() {
        assert this._arr != null && this._domain != null;
        for (long[] row : this._arr) {
            assert this._arr.length == row.length;
        }

        final int nRows = this._arr.length;
        long[] rowTotals = new long[nRows];
        long[] colTotals = new long[nRows == 0 ? 0 : this._arr[0].length];
        for (int r = 0; r < nRows; r++) {
            long sum = 0;
            for (int c = 0; c < this._arr[r].length; c++) {
                sum += this._arr[r][c];
                colTotals[c] += this._arr[r][c];
            }
            rowTotals[r] = sum;
        }

        String[] rowLabels = createConfusionMatrixHeader(rowTotals, this._domain);
        String[] colLabels = createConfusionMatrixHeader(colTotals, this._domain);
        assert rowLabels.length == colLabels.length
                : "The confusion matrix should have the same length for both directions.";

        // Arrays.copyOf pads the new trailing slots with null before they are filled below.
        String[] rowHeaders = Arrays.copyOf(rowLabels, rowLabels.length + 1);
        rowHeaders[rowLabels.length] = "Totals";
        String[] colHeaders = Arrays.copyOf(colLabels, colLabels.length + 2);
        final int errCol = colLabels.length;
        final int ratioCol = errCol + 1;
        colHeaders[errCol] = "Error";
        colHeaders[ratioCol] = "";

        String[] colTypes = new String[colHeaders.length];
        String[] colFormats = new String[colHeaders.length];
        Arrays.fill(colTypes, 0, errCol, "integer");
        Arrays.fill(colFormats, 0, errCol, "%d");
        colTypes[errCol] = "double";
        colFormats[errCol] = "%.4f";
        colTypes[ratioCol] = "string";
        colFormats[ratioCol] = "= %s";

        TwoDimTable table = new TwoDimTable("Confusion Matrix (Act/Pred)", rowHeaders, colHeaders, colTypes, colFormats);

        long totalErrors = 0;
        for (int r = 0; r < nRows; r++) {
            if (rowLabels[r] == null) {
                continue;
            }
            long correct = 0;
            for (int c = 0; c < colLabels.length; c++) {
                if (colLabels[c] == null) {
                    continue;
                }
                if (rowLabels[r].equals(colLabels[c])) {
                    correct = this._arr[r][c];
                }
                table.set(r, c, this._arr[r][c]);
            }
            long wrong = rowTotals[r] - correct;
            totalErrors += wrong;
            // Floating-point division: long / long truncated the rate and threw on empty rows.
            table.set(r, errCol, (double) wrong / rowTotals[r]);
            table.set(r, ratioCol, String.format("%,d / %,d", wrong, rowTotals[r]));
        }

        final int totalsRow = rowLabels.length;
        for (int c = 0; c < colLabels.length; c++) {
            if (colLabels[c] != null) {
                table.set(totalsRow, c, colTotals[c]);
            }
        }
        long grandTotal = 0;
        for (long t : rowTotals) {
            grandTotal += t;
        }
        table.set(totalsRow, errCol, (double) totalErrors / grandTotal);
        table.set(totalsRow, ratioCol, String.format("%,d / %,d", totalErrors, grandTotal));
        return table;
    }
}
