package water.api;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import water.ConfusionMatrix2;
import water.DKV;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;

/* loaded from: input_file:water/api/AUC.class */
/**
 * Threshold-sweep evaluation of a binary classifier.
 *
 * <p>Given an actual-label column ({@link #vactual}) and a predicted-probability column
 * ({@link #vpredict}), {@link #execImpl()} builds one 2x2 {@link ConfusionMatrix2} per
 * threshold and hands them to {@link AUCData#compute} together with the thresholds, the
 * label domain and the selected {@link ThresholdCriterion}.
 */
/* loaded from: input_file:water/api/AUC.class */
public class AUC extends Iced {
    /** Frame holding the actual labels (not read by this class directly). */
    public Frame actual;
    /** Column of actual class labels; must be integer with cardinality 2 or -1 (see {@code init}). */
    public Vec vactual;
    /** Frame holding the predictions (not read by this class directly). */
    public Frame predict;
    /** Column of predicted values; must not be an enum column (see {@code init}). */
    public Vec vpredict;
    /** Thresholds to evaluate; if null, {@link #execImpl()} derives a default set. Sorted in place. */
    private float[] thresholds;
    /** Criterion handed to {@link AUCData#compute} and understood by {@link #isBetter}. */
    public ThresholdCriterion threshold_criterion;

    @API(help = "AUC Data", json = true)
    AUCData aucdata;

    /**
     * Distributed task that accumulates one 2x2 confusion matrix per threshold.
     * A row with predicted value {@code p} is counted as predicted class 1 for threshold
     * {@code t} iff {@code p >= t}. Rows whose prediction is NA are skipped; an NA actual
     * label is an error.
     */
    /* loaded from: input_file:water/api/AUC$AUCTask.class */
    private static class AUCTask extends MRTask<AUCTask> {
        private ConfusionMatrix2[] _cms;
        // NOTE(review): nullDev/resDev are summed in reduce() but never assigned in map(),
        // and ymu is stored but never read - they currently stay at 0 / unused.
        double nullDev;
        double resDev;
        final double ymu;
        private final float[] _thresh;

        /* JADX INFO: Access modifiers changed from: private */
        /** Returns the per-threshold confusion matrices (index-aligned with the thresholds). */
        public final ConfusionMatrix2[] getCMs() {
            return this._cms;
        }

        /**
         * @param thresh thresholds to evaluate; defensively copied
         * @param mean   mean of the actual-label column (currently unused, see NOTE above)
         */
        AUCTask(float[] thresh, double mean) {
            this._thresh = thresh.clone();
            this.ymu = mean;
        }

        /** Returns {@code y * log(y / mu)}, with 0 for y == 0 and mu clamped to MIN_NORMAL. */
        static final double y_log_y(double y, double mu) {
            if (y == 0.0d) {
                return 0.0d;
            }
            // Clamp to avoid division by zero / log of a denormal.
            double safeMu = Math.max(mu, Double.MIN_NORMAL);
            return y * Math.log(y / safeMu);
        }

        /** Returns the binomial deviance {@code 2*(ylog(y/mu) + (1-y)log((1-y)/(1-mu)))}. */
        public static double binomial_deviance(double y, double mu) {
            return 2.0d * (y_log_y(y, mu) + y_log_y(1.0d - y, 1.0d - mu));
        }

        @Override // water.MRTask
        public void map(Chunk actualChunk, Chunk predChunk) {
            this._cms = new ConfusionMatrix2[this._thresh.length];
            for (int t = 0; t < this._cms.length; t++) {
                this._cms[t] = new ConfusionMatrix2(2);
            }
            int rows = Math.min(actualChunk.len(), predChunk.len());
            for (int r = 0; r < rows; r++) {
                if (actualChunk.isNA0(r)) {
                    throw new UnsupportedOperationException("Actual class label cannot be a missing value!");
                }
                int label = (int) actualChunk.at80(r);
                assert label == 0 || label == 1 : "Invalid values in vactual: must be binary (0 or 1).";
                if (predChunk.isNA0(r)) {
                    continue; // rows without a prediction do not contribute
                }
                double p = predChunk.at0(r);
                for (int t = 0; t < this._cms.length; t++) {
                    this._cms[t].add(label, p >= this._thresh[t] ? 1 : 0);
                }
            }
        }

        @Override // water.MRTask
        public void reduce(AUCTask other) {
            for (int t = 0; t < this._cms.length; t++) {
                this._cms[t].add(other._cms[t]);
            }
            this.nullDev += other.nullDev;
            this.resDev += other.resDev;
        }

        @Override // water.MRTask
        public void postGlobal() {
        }
    }

    /** Criteria understood by {@link #isBetter} for comparing two confusion matrices. */
    /* loaded from: input_file:water/api/AUC$ThresholdCriterion.class */
    public enum ThresholdCriterion {
        maximum_F1,
        maximum_F2,
        maximum_F0point5,
        maximum_Accuracy,
        maximum_Precision,
        maximum_Recall,
        maximum_Specificity,
        maximum_absolute_MCC,
        minimizing_max_per_class_Error
    }

    /** Returns the computed AUC data, or null if nothing has been computed yet. */
    public AUCData data() {
        return this.aucdata;
    }

    /** Creates an empty instance using {@link ThresholdCriterion#maximum_F1}. */
    public AUC() {
        this.threshold_criterion = ThresholdCriterion.maximum_F1;
    }

    /** Builds AUC data directly from precomputed confusion matrices, without a label domain. */
    public AUC(ConfusionMatrix2[] cms, float[] thresh) {
        this(cms, thresh, null);
    }

    /**
     * Builds AUC data directly from precomputed confusion matrices.
     *
     * @param cms    one confusion matrix per threshold
     * @param thresh the thresholds matching {@code cms}
     * @param domain label names, may be null
     */
    public AUC(ConfusionMatrix2[] cms, float[] thresh, String[] domain) {
        this.threshold_criterion = ThresholdCriterion.maximum_F1;
        this.aucdata = new AUCData().compute(cms, thresh, domain, this.threshold_criterion);
    }

    /**
     * Validates the input columns.
     *
     * @throws IllegalArgumentException if a column is missing, lengths differ, the labels are
     *     not integer/binary, or the predictions are an enum column
     */
    private void init() throws IllegalArgumentException {
        if (this.vactual == null || this.vpredict == null) {
            throw new IllegalArgumentException("Missing vactual or vpredict!");
        }
        if (this.vactual.length() != this.vpredict.length()) {
            throw new IllegalArgumentException("Both arguments must have the same length (" + this.vactual.length() + "!=" + this.vpredict.length() + ")!");
        }
        if (!this.vactual.isInt()) {
            throw new IllegalArgumentException("Actual column must be integer class labels!");
        }
        if (this.vactual.cardinality() != -1 && this.vactual.cardinality() != 2) {
            throw new IllegalArgumentException("Actual column must contain binary class labels, but found cardinality " + this.vactual.cardinality() + "!");
        }
        if (this.vpredict.isEnum()) {
            throw new IllegalArgumentException("vpredict cannot be class labels, expect probabilities.");
        }
    }

    /**
     * Validates the inputs, resolves the thresholds and computes {@link #aucdata}.
     *
     * <p>The label Vec obtained from {@code vactual.toEnum()} is removed from DKV in a
     * {@code finally} block, so it is cleaned up on failure as well as on success.
     *
     * @throws IllegalArgumentException on invalid inputs or user-supplied thresholds outside [0, 1]
     */
    public void execImpl() {
        init();
        Vec labels = null;
        try {
            labels = this.vactual.toEnum();
            Vec preds = this.vpredict;
            if (!labels.group().equals(preds.group())) {
                preds = labels.align(preds);
            }
            if (this.thresholds != null) {
                validateUserThresholds();
            } else {
                this.thresholds = defaultThresholds();
            }
            ConfusionMatrix2[] cms = new AUCTask(this.thresholds, labels.mean()).doAll(labels, preds).getCMs();
            this.aucdata = new AUCData().compute(cms, this.thresholds, labels.factors(), this.threshold_criterion);
        } finally {
            if (labels != null) {
                DKV.remove(labels._key);
            }
        }
    }

    /** Sorts the user-supplied thresholds in place and checks they lie within [0, 1]. */
    private void validateUserThresholds() {
        Arrays.sort(this.thresholds);
        if (ArrayUtils.minValue(this.thresholds) < 0.0f) {
            throw new IllegalArgumentException("Minimum threshold cannot be negative.");
        }
        if (ArrayUtils.maxValue(this.thresholds) > 1.0f) {
            throw new IllegalArgumentException("Maximum threshold cannot be greater than 1.");
        }
    }

    /**
     * Builds the default threshold set: up to 200 prediction values sampled at a fixed
     * stride, plus the grid 0, 0.02, ..., 1.0, de-duplicated and sorted ascending.
     * Sampling is skipped for an empty prediction column (previously a division by zero).
     */
    private float[] defaultThresholds() {
        HashSet<Float> candidates = new HashSet<Float>();
        long rows = this.vpredict.length();
        if (rows > 0) {
            int bins = (int) Math.min(rows, 200L);
            long stride = Math.max(rows / bins, 1L);
            for (int i = 0; i < bins; i++) {
                candidates.add((float) this.vpredict.at(i * stride));
            }
        }
        for (int i = 0; i <= 50; i++) {
            candidates.add((float) (i / 50.0d));
        }
        float[] result = new float[candidates.size()];
        int k = 0;
        for (Float f : candidates) {
            result[k++] = f;
        }
        Arrays.sort(result);
        return result;
    }

    /**
     * Returns true if {@code candidate} strictly beats {@code incumbent} under {@code criterion}.
     * For the NaN-aware criteria, a NaN score never wins and any non-NaN score beats a NaN one.
     *
     * @throws IllegalArgumentException if {@code criterion} is null or unrecognized
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static boolean isBetter(ConfusionMatrix2 candidate, ConfusionMatrix2 incumbent, ThresholdCriterion criterion) {
        if (criterion == null) {
            throw new IllegalArgumentException("Unknown threshold criterion.");
        }
        switch (criterion) {
            case maximum_F1:
                return beatsNaNAware(candidate.F1(), incumbent.F1());
            case maximum_F2:
                return beatsNaNAware(candidate.F2(), incumbent.F2());
            case maximum_F0point5:
                return beatsNaNAware(candidate.F0point5(), incumbent.F0point5());
            case maximum_Recall:
                return beatsNaNAware(candidate.recall(), incumbent.recall());
            case maximum_Precision:
                return beatsNaNAware(candidate.precision(), incumbent.precision());
            case maximum_Accuracy:
                return candidate.accuracy() > incumbent.accuracy();
            case minimizing_max_per_class_Error:
                return candidate.max_per_class_error() < incumbent.max_per_class_error();
            case maximum_Specificity:
                return beatsNaNAware(candidate.specificity(), incumbent.specificity());
            case maximum_absolute_MCC:
                // |NaN| is NaN, so the NaN handling is unchanged by taking absolute values.
                return beatsNaNAware(Math.abs(candidate.mcc()), Math.abs(incumbent.mcc()));
            default:
                throw new IllegalArgumentException("Unknown threshold criterion.");
        }
    }

    /** True iff {@code a} is not NaN and either {@code b} is NaN or {@code a > b}. */
    private static boolean beatsNaNAware(double a, double b) {
        return !Double.isNaN(a) && (Double.isNaN(b) || a > b);
    }

    /** Renders the computed AUC data as HTML; assumes {@link #aucdata} has been computed. */
    public boolean toHTML(StringBuilder sb) {
        return this.aucdata.toHTML(sb);
    }

    /** Renders the computed AUC data as ASCII; assumes {@link #aucdata} has been computed. */
    public void toASCII(StringBuilder sb) {
        this.aucdata.toASCII(sb);
    }
}
