package hex;

import java.util.Arrays;
import water.H2O;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.PrettyPrint;

/* loaded from: input_file:hex/ConfusionMatrix2.class */
/**
 * A square confusion matrix of row counts, together with cached per-class and overall error
 * rates.
 *
 * <p>Cell {@code _arr[r][c]} is a count. The binary metrics ({@link #recall()},
 * {@link #precision()}, {@link #specificity()}) treat the row index as the actual class and the
 * column index as the predicted class, with class 1 as the positive class. The distributed builder
 * {@link CM} fills rows from its first column and columns from its second.
 *
 * <p>{@link #_classErr} and {@link #_predErr} are computed once at construction. They are NOT
 * refreshed by {@link #add(int, int)} or {@link #add(ConfusionMatrix2)}; call
 * {@link #reComputeErrors()} after mutating the counts.
 */
public class ConfusionMatrix2 extends Iced {
    /** Square matrix of counts, indexed {@code [row][column]}. */
    public long[][] _arr;
    /** Per-class error rates cached at construction; refresh with {@link #reComputeErrors()}. */
    public final double[] _classErr;
    /** Overall error rate cached at construction; refresh with {@link #reComputeErrors()}. */
    public double _predErr;

    /**
     * Distributed task that tallies a confusion matrix from two columns.
     *
     * <p>For each row, the value in the first chunk selects the matrix row and the value in the
     * second chunk selects the matrix column. Rows whose first-chunk value is NA are skipped.
     */
    private static class CM extends MRTask<CM> {
        /** Number of rows (and columns) of the matrix being built. */
        final int _len;
        /** Partial counts accumulated by this task instance. */
        long[][] _arr;

        CM(int len) {
            this._len = len;
        }

        @Override // water.MRTask
        public void map(Chunk rowChunk, Chunk colChunk) {
            _arr = new long[_len][_len];
            for (int row = 0; row < rowChunk._len; row++) {
                // Only an NA in the first column is skipped. NOTE(review): an NA in the second
                // column goes straight to at80(); confirm how Chunk handles that case.
                if (!rowChunk.isNA0(row)) {
                    _arr[(int) rowChunk.at80(row)][(int) colChunk.at80(row)]++;
                }
            }
        }

        @Override // water.MRTask
        public void reduce(CM other) {
            // A side that never ran map() still has a null matrix; don't hand null to add().
            if (_arr == null) {
                _arr = other._arr;
            } else if (other._arr != null) {
                ArrayUtils.add(_arr, other._arr);
            }
        }
    }

    /** Scalar error summaries that can be derived from a {@link ConfusionMatrix2}. */
    public enum ErrMetric {
        /** Maximum of the per-class error rates. */
        MAXC,
        /** Sum of the per-class error rates. */
        SUMC,
        /** Overall error rate, as returned by {@link ConfusionMatrix2#err()}. */
        TOTAL;

        /**
         * Evaluates this metric on a confusion matrix.
         *
         * @param confusionMatrix2 the matrix to evaluate
         * @return the metric value
         */
        public double computeErr(ConfusionMatrix2 confusionMatrix2) {
            switch (this) {
                case MAXC:
                    return ArrayUtils.maxValue(confusionMatrix2.classErr());
                case SUMC:
                    return ArrayUtils.sum(confusionMatrix2.classErr());
                case TOTAL:
                    return confusionMatrix2.err();
                default:
                    throw H2O.unimpl();
            }
        }
    }

    /**
     * Creates an all-zero {@code n x n} matrix.
     *
     * <p>Because there are no rows yet, every cached per-class error is 0 and the cached overall
     * error is NaN.
     *
     * @param n number of classes
     */
    public ConfusionMatrix2(int n) {
        this._arr = new long[n][n];
        this._classErr = classErr();
        this._predErr = err();
    }

    /**
     * Wraps an existing count matrix. The array is used directly, not copied.
     *
     * @param counts square matrix of counts, indexed {@code [row][column]}
     */
    public ConfusionMatrix2(long[][] counts) {
        this._arr = counts;
        this._classErr = classErr();
        this._predErr = err();
    }

    /**
     * Creates a matrix from a copy of the top-left {@code n x n} corner of {@code counts}.
     *
     * @param counts source matrix; must have at least {@code n} rows of at least {@code n} entries
     * @param n number of classes to keep
     */
    public ConfusionMatrix2(long[][] counts, int n) {
        this._arr = new long[n][n];
        for (int r = 0; r < n; r++) {
            System.arraycopy(counts[r], 0, this._arr[r], 0, n);
        }
        this._classErr = classErr();
        this._predErr = err();
    }

    /**
     * Builds the matrix with a distributed {@link CM} pass.
     *
     * <p>Rows come from {@code vec} and columns from the first column of {@code frame}. The
     * matrix size is {@code vec.domain().length}.
     *
     * @param vec column supplying the row index
     * @param frame frame whose first column supplies the column index
     */
    public ConfusionMatrix2(Vec vec, Frame frame) {
        this(new CM(vec.domain().length).doAll(vec, frame.vecs()[0])._arr);
    }

    /**
     * Increments the count in cell {@code [row][col]}. Cached errors are not refreshed.
     *
     * @param row row (actual class) index
     * @param col column (predicted class) index
     */
    public void add(int row, int col) {
        _arr[row][col]++;
    }

    /**
     * Computes the error rate of every class.
     *
     * @return a fresh array whose element {@code i} equals {@link #classErr(int) classErr(i)}
     */
    public double[] classErr() {
        double[] errs = new double[_arr.length];
        for (int i = 0; i < errs.length; i++) {
            errs[i] = classErr(i);
        }
        return errs;
    }

    /** @return the number of rows (and columns) of the matrix */
    public final int size() {
        return _arr.length;
    }

    /** Recomputes {@link #_classErr} and {@link #_predErr} from the current counts. */
    public void reComputeErrors() {
        for (int i = 0; i < _arr.length; i++) {
            _classErr[i] = classErr(i);
        }
        _predErr = err();
    }

    /**
     * @param i row (class) index
     * @return number of rows of class {@code i} that fall outside the diagonal cell
     */
    public final long classErrCount(int i) {
        return ArrayUtils.sum(_arr[i]) - _arr[i][i];
    }

    /**
     * Computes the fraction of row {@code i}'s count that lies off the diagonal.
     *
     * @param i row (class) index
     * @return error rate in {@code [0, 1]}, or 0.0 when row {@code i} is empty
     */
    public final double classErr(int i) {
        long rowTotal = ArrayUtils.sum(_arr[i]);
        if (rowTotal == 0) {
            return 0.0d;
        }
        // Cast before dividing: long/long division would truncate the rate to 0 or 1.
        return (double) (rowTotal - _arr[i][i]) / rowTotal;
    }

    /** @return sum of all cells */
    public long totalRows() {
        long total = 0;
        for (long[] row : _arr) {
            total += ArrayUtils.sum(row);
        }
        return total;
    }

    /**
     * Adds another matrix's counts into this one via {@code ArrayUtils.add}. Cached errors are
     * not refreshed.
     *
     * @param other matrix whose counts are added
     */
    public void add(ConfusionMatrix2 other) {
        ArrayUtils.add(_arr, other._arr);
    }

    /**
     * Computes the overall error rate: the off-diagonal count divided by the total count.
     *
     * @return error rate in {@code [0, 1]}, or NaN when the matrix holds no rows
     */
    public double err() {
        // Floating-point division: long/long would truncate to 0 or 1, and would throw
        // ArithmeticException on an empty matrix (e.g. one built by the int constructor).
        return (double) errCount() / totalRows();
    }

    /** @return total count outside the diagonal */
    public long errCount() {
        long wrong = totalRows();
        for (int i = 0; i < _arr.length; i++) {
            wrong -= _arr[i][i];
        }
        return wrong;
    }

    /** @return {@code 1 - err()}; NaN when the matrix holds no rows */
    public double accuracy() {
        return 1.0d - err();
    }

    /**
     * Computes specificity, TN / (TN + FP), as {@code _arr[0][0] / (_arr[0][0] + _arr[0][1])}.
     *
     * @return specificity; NaN when row 0 is empty
     * @throws UnsupportedOperationException if the matrix is not 2x2
     */
    public double specificity() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("specificity is only implemented for 2 class problems.");
        }
        double tn = _arr[0][0];
        return tn / (tn + _arr[0][1]);
    }

    /**
     * Computes recall, TP / (TP + FN), as {@code _arr[1][1] / (_arr[1][1] + _arr[1][0])}.
     *
     * @return recall; NaN when row 1 is empty
     * @throws UnsupportedOperationException if the matrix is not 2x2
     */
    public double recall() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("recall is only implemented for 2 class problems.");
        }
        double tp = _arr[1][1];
        return tp / (tp + _arr[1][0]);
    }

    /**
     * Computes precision, TP / (TP + FP), as {@code _arr[1][1] / (_arr[1][1] + _arr[0][1])}.
     *
     * @return precision; NaN when column 1 is empty
     * @throws UnsupportedOperationException if the matrix is not 2x2
     */
    public double precision() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("precision is only implemented for 2 class problems.");
        }
        double tp = _arr[1][1];
        return tp / (tp + _arr[0][1]);
    }

    /**
     * Computes the Matthews correlation coefficient of a 2x2 matrix.
     *
     * @return MCC in {@code [-1, 1]}; NaN when any row or column total is zero
     * @throws UnsupportedOperationException if the matrix is not 2x2
     */
    public double mcc() {
        if (!isBinary()) {
            throw new UnsupportedOperationException("mcc is only implemented for 2 class problems.");
        }
        double tn = _arr[0][0];
        double fp = _arr[0][1];
        double tp = _arr[1][1];
        double fn = _arr[1][0];
        return ((tp * tn) - (fp * fn)) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
    }

    /**
     * @return the largest per-class error rate
     * @throws UnsupportedOperationException if the matrix has no classes
     */
    public double max_per_class_error() {
        int n = nclasses();
        if (n == 0) {
            throw new UnsupportedOperationException("max per class error is only defined for classification problems");
        }
        double worst = classErr(0);
        for (int i = 1; i < n; i++) {
            worst = Math.max(worst, classErr(i));
        }
        return worst;
    }

    /** @return number of classes, or 0 when the matrix array is null */
    public final int nclasses() {
        return _arr == null ? 0 : _arr.length;
    }

    /** @return {@code true} when the matrix has exactly two classes */
    public final boolean isBinary() {
        return nclasses() == 2;
    }

    /**
     * @return F1 score, the harmonic mean of precision and recall; NaN when both are 0
     * @throws UnsupportedOperationException if the matrix is not 2x2
     */
    public double F1() {
        double p = precision();
        double r = recall();
        return (2.0d * (p * r)) / (p + r);
    }

    /**
     * @return F-beta score with beta = 2, which weights recall more heavily than precision
     * @throws UnsupportedOperationException if the matrix is not 2x2
     */
    public double F2() {
        double p = precision();
        double r = recall();
        return (5.0d * (p * r)) / ((4.0d * p) + r);
    }

    /**
     * @return F-beta score with beta = 0.5, which weights precision more heavily than recall
     * @throws UnsupportedOperationException if the matrix is not 2x2
     */
    public double F0point5() {
        double p = precision();
        double r = recall();
        return (1.25d * (p * r)) / ((0.25d * p) + r);
    }

    /** @return one line per matrix row, each rendered by {@link Arrays#toString(long[])} */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (long[] row : _arr) {
            sb.append(Arrays.toString(row)).append('\n');
        }
        return sb.toString();
    }

    /**
     * Renders the matrix as an ASCII table via {@code PrettyPrint.printConfusionMatrix}.
     *
     * @param domain class labels passed through to the printer
     * @return the rendered table
     */
    public String toASCII(String[] domain) {
        // NOTE(review): the meaning of the trailing `false` flag is defined by PrettyPrint.
        return PrettyPrint.printConfusionMatrix(new StringBuilder(), _arr, domain, false).toString();
    }
}
