package hex;

import hex.Model;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.Value;
import water.api.ModelMetricsBase;
import water.api.ModelMetricsV3;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.ModelUtils;

/* loaded from: input_file:hex/ModelMetrics.class */
/**
 * Scoring metrics for one model evaluated on one frame.
 *
 * <p>Every instance is published to the DKV when constructed, under a key derived from the
 * model key/checksum and frame key/checksum, so it can later be retrieved with
 * {@link #getFromDKV(Model, Frame)}.
 */
public final class ModelMetrics extends Keyed {
    /** Key of the scored model. */
    final Key _modelKey;
    /** Key of the scored frame. */
    final Key _frameKey;
    /** Model category, captured from {@code model._output} at construction time. */
    final Model.ModelCategory _model_category;
    /** Model checksum at construction time; compared by {@link #isForModel(Model)}. */
    final long _model_checksum;
    /** Frame checksum at construction time; compared by {@link #isForFrame(Frame)}. */
    final long _frame_checksum;
    /** Cached model; transient, so it may be null and is re-fetched by {@link #model()}. */
    transient Model _model;
    /** Cached frame; transient, so it may be null and is re-fetched by {@link #frame()}. */
    transient Frame _frame;
    /** -1 until assigned elsewhere (not set by this class) — presumably "not yet measured". */
    long duration_in_ms = -1L;
    /** -1 until assigned elsewhere (not set by this class). */
    long scoring_time = -1L;
    /** Caller-supplied sigma; its square is the denominator in {@link #r2()}. */
    public final double _sigma;
    /** Mean squared error. */
    public final double _mse;
    /** AUC data when the metrics were built from several thresholds, otherwise {@code null}. */
    public final AUCData _aucdata;
    /** Confusion matrix, or {@code null} for metrics built with {@link #ModelMetrics(Model, Frame)}. */
    public final ConfusionMatrix _cm;
    /** Hit ratio; always {@code null} for instances created by the constructors in this class. */
    public final HitRatio _hr;

    /**
     * Accumulates confusion matrices and the sum of squared errors row by row, can be merged
     * with other builders via {@link #reduce(MetricBuilder)}, and finally converts the totals
     * into a {@link ModelMetrics}.
     *
     * <p>For two classes one confusion matrix is kept per entry of {@link #_thresholds};
     * otherwise exactly one matrix is kept.
     */
    public static final class MetricBuilder extends Iced {
        /** Class labels, or {@code null} (treated as a single class). */
        final String[] _domain;
        /** Number of classes: {@code _domain.length}, or 1 when {@code _domain} is null. */
        final int _nclasses;
        /** Decision thresholds; one confusion matrix is kept per threshold. */
        final float[] _thresholds;
        /** Row counts indexed as {@code _cms[threshold][actual][predicted]}. */
        long[][][] _cms;
        /** Running sum of squared errors over all scored rows. */
        public double _sumsqe;
        /** Scratch buffer of length {@code _nclasses + 1}; transient, not serialized. */
        public transient float[] _work;

        /** Creates a builder that uses the single threshold 0.5. */
        public MetricBuilder(String[] domain) {
            this(domain, new float[]{0.5f});
        }

        /**
         * @param domain class labels, or {@code null} to treat the response as a single class
         * @param thresholds decision thresholds; must be non-empty when there are two classes and
         *        contain exactly one element otherwise (checked only when assertions are enabled)
         */
        public MetricBuilder(String[] domain, float[] thresholds) {
            _domain = domain;
            int nclasses = domain == null ? 1 : domain.length;
            _nclasses = nclasses;
            assert (nclasses == 2 && thresholds.length > 0) || (nclasses != 2 && thresholds.length == 1);
            _thresholds = thresholds;
            _cms = new long[thresholds.length][nclasses][nclasses];
            _work = new float[nclasses + 1];
        }

        /**
         * Folds one scored row into the running statistics.
         *
         * <p>Assumed layout of {@code ds}, inferred from the indexing below (confirm against
         * callers): {@code ds[0]} is the prediction (a value, or a predicted class index when
         * there are several classes) and {@code ds[1..]} are per-class probabilities. Rows whose
         * actual value or {@code ds[0]} is NaN are ignored.
         *
         * @param ds prediction row; {@code ds.length - 1} is used as the number of classes
         * @param yact actual value, or actual class index when there are several classes
         * @return {@code ds}, unchanged
         */
        public float[] perRow(float[] ds, float yact) {
            if (Float.isNaN(yact) || Float.isNaN(ds[0])) {
                return ds; // missing actual or prediction: nothing to score
            }
            int nclasses = ds.length - 1;
            int iact = (int) yact;

            float err;
            if (nclasses > 1) {
                float sum = 0.0f;
                for (int c = 1; c < ds.length; c++) {
                    assert !(ds[c] < 0.0f || ds[c] > 1.0f);
                    sum += ds[c];
                }
                assert !(Math.abs(sum - 1.0f) >= 1.0E-6d);
                // Error is the probability mass not assigned to the actual class.
                err = 1.0f - ds[iact + 1];
            } else {
                err = yact - ds[0];
            }
            _sumsqe += err * err;
            assert !Double.isNaN(_sumsqe);

            if (nclasses == 1) {
                _cms[0][0][0]++; // single class: only the number of scored rows is tracked
            } else if (nclasses == 2) {
                float p1 = ds[2]; // probability assigned to class 1
                // Iterate the same thresholds that sized _cms and are later passed to AUC.
                // (Iterating ModelUtils.DEFAULT_THRESHOLDS here overflowed _cms whenever a
                // different threshold array, e.g. the default {0.5f}, was supplied.)
                for (int t = 0; t < _thresholds.length; t++) {
                    int pred = p1 >= _thresholds[t] ? 1 : 0;
                    _cms[t][iact][pred]++;
                }
            } else {
                _cms[0][iact][(int) ds[0]]++;
            }
            return ds;
        }

        /**
         * Merges another builder's counts and squared-error sum into this builder.
         * Relies on {@code ArrayUtils.add} updating {@code _cms} in place (its return value
         * is discarded) — confirm if that helper changes.
         */
        public void reduce(MetricBuilder mb) {
            ArrayUtils.add(_cms, mb._cms);
            _sumsqe += mb._sumsqe;
        }

        /**
         * Builds a {@link ModelMetrics} from the accumulated totals and passes it to
         * {@code model._output.addModelMetrics}, returning that call's result.
         *
         * <p>With more than one threshold an AUC is computed over the per-threshold matrices and
         * its {@code CM()} is reported; otherwise the single matrix is reported and the AUC data
         * is {@code null}. MSE is {@code _sumsqe} divided by the reported matrix's
         * {@code totalRows()}.
         *
         * @param sigma stored as {@link ModelMetrics#_sigma}
         */
        public ModelMetrics makeModelMetrics(Model model, Frame frame, double sigma) {
            AUCData aucdata;
            ConfusionMatrix cm;
            if (_cms.length > 1) {
                ConfusionMatrix[] cms = new ConfusionMatrix[_cms.length];
                for (int t = 0; t < cms.length; t++) {
                    cms[t] = new ConfusionMatrix(_cms[t], _domain);
                }
                aucdata = new AUC(cms, _thresholds, _domain).data();
                cm = aucdata.CM();
            } else {
                aucdata = null;
                cm = new ConfusionMatrix(_cms[0], _domain);
            }
            double mse = _sumsqe / cm.totalRows();
            return model._output.addModelMetrics(new ModelMetrics(model, frame, aucdata, cm, null, sigma, mse));
        }
    }

    /**
     * Creates metrics with no AUC, confusion matrix or hit ratio, and NaN sigma and MSE, and
     * stores them in the DKV.
     */
    public ModelMetrics(Model model, Frame frame) {
        this(model, frame, null, null, null, Double.NaN, Double.NaN);
    }

    /** Captures identifying data from {@code model}/{@code frame} and stores {@code this} in the DKV. */
    private ModelMetrics(Model model, Frame frame, AUCData aucdata, ConfusionMatrix cm, HitRatio hr, double sigma, double mse) {
        super(buildKey(model, frame));
        _modelKey = model._key;
        _frameKey = frame._key;
        _model_category = model._output.getModelCategory();
        _model = model;
        _frame = frame;
        _model_checksum = model.checksum();
        _frame_checksum = frame.checksum();
        _sigma = sigma;
        _mse = mse;
        _aucdata = aucdata;
        _cm = cm;
        _hr = hr;
        // Published last so every field is assigned first (the class is final, so no
        // subclass constructor can run afterwards).
        DKV.put(this);
    }

    /** Returns the scored model, fetching it from the DKV by key if it is not cached. */
    public Model model() {
        if (_model == null) {
            _model = (Model) DKV.getGet(_modelKey);
        }
        return _model;
    }

    /**
     * Returns the scored frame, fetching it from the DKV by key if it is not cached.
     * (Checks the cached frame itself; testing the cached model instead returned a null frame
     * whenever only {@link #model()} had been reloaded.)
     */
    public Frame frame() {
        if (_frame == null) {
            _frame = (Frame) DKV.getGet(_frameKey);
        }
        return _frame;
    }

    /** Returns {@code 1 - mse / sigma^2}. */
    public double r2() {
        return 1.0d - (_mse / (_sigma * _sigma));
    }

    /** Returns a new, empty API schema object for these metrics. */
    public ModelMetricsBase schema() {
        return new ModelMetricsV3();
    }

    /** Builds the DKV key {@code modelmetrics_<modelKey>@<modelChecksum>_on_<frameKey>@<frameChecksum>}. */
    private static Key buildKey(Key modelKey, long modelChecksum, Key frameKey, long frameChecksum) {
        return Key.make("modelmetrics_" + modelKey + "@" + modelChecksum + "_on_" + frameKey + "@" + frameChecksum);
    }

    /** Builds the DKV key for the given model and frame from their keys and current checksums. */
    private static Key buildKey(Model model, Frame frame) {
        return buildKey(model._key, model.checksum(), frame._key, frame.checksum());
    }

    /** Returns true if {@code model}'s current checksum equals the one captured at construction. */
    public boolean isForModel(Model model) {
        return _model_checksum == model.checksum();
    }

    /** Returns true if {@code frame}'s current checksum equals the one captured at construction. */
    public boolean isForFrame(Frame frame) {
        return _frame_checksum == frame.checksum();
    }

    /** Looks up previously stored metrics for this model/frame pair; returns {@code null} if absent. */
    public static ModelMetrics getFromDKV(Model model, Frame frame) {
        Value value = DKV.get(buildKey(model, frame));
        return value == null ? null : (ModelMetrics) value.get();
    }

    /** Combines the captured frame and model checksums. */
    @Override // water.Keyed
    public long checksum() {
        return (_frame_checksum * 13) + (_model_checksum * 17);
    }
}
