package hex;

import hex.AUC;
import hex.Model;
import hex.Model.Output;
import hex.SupervisedModel.SupervisedParameters;
import water.Key;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.ModelUtils;

/**
 * Base class for models trained against a response column.
 *
 * <p>Besides the generic {@link Model} machinery, a supervised model may carry two class
 * distributions. When both are set and the model is a classifier, {@code score0} rescales
 * each predicted class probability by {@code _priorClassDist[c] / _modelClassDist[c]} and
 * renormalizes the probabilities before recomputing the predicted label.
 *
 * @param <M> concrete model type
 * @param <P> supervised parameter type
 * @param <O> model output type
 */
public abstract class SupervisedModel<M extends Model<M, P, O>, P extends SupervisedParameters, O extends Model.Output> extends Model<M, P, O> {
    /** Reference class distribution used as the numerator of the correction; may be {@code null}. */
    protected float[] _priorClassDist;
    /** Class distribution the model was built on, used as the denominator; may be {@code null}. */
    protected float[] _modelClassDist;

    /** Parameters shared by all supervised model builders. */
    public static abstract class SupervisedParameters extends Model.Parameters {
        /** Name of the response column in the training frame. */
        public String _response_column;
        /** Response vector resolved from {@link #_response_column} by {@link #sanityCheckParameters()}. */
        public transient Vec _response;
        /** Size of the response domain, or 1 when the response has no domain. */
        public int _nclass;
        /** {@code true} when the response vector is categorical ({@code isEnum()}). */
        public boolean _classification;
        /** Number of columns in the training frame, including the response. */
        public int _ncols;
        /** Number of training rows, minus the rows whose response is NA. */
        public long _nrows;

        /**
         * Validates the training frame and the response column, and fills in the derived
         * fields ({@code _response}, {@code _nclass}, {@code _classification}, {@code _ncols},
         * {@code _nrows}).
         *
         * <p>Problems are reported through {@code validation_error}. If the response column
         * cannot be found, validation stops immediately. The derived fields cannot be computed
         * without it.
         *
         * @return {@code _validation_error_count} after all checks have run
         * @throws IllegalArgumentException if every column of the training frame is bad or constant
         */
        @Override // hex.Model.Parameters
        public int sanityCheckParameters() {
            assert _train != null;
            assert _response_column != null;
            int find = train().find(_response_column);
            if (find == -1) {
                validation_error("response_column", "Response column " + _response_column + " not found in frame: " + _train + ".");
                // Indexing vecs() with -1 below would throw instead of reporting the error.
                return _validation_error_count;
            }
            _response = train().vecs()[find];
            _nclass = _response.domain() == null ? 1 : _response.domain().length;
            _classification = _response.isEnum();
            _ncols = train().numCols();
            _nrows = train().numRows() - _response.naCnt();
            if (_ncols <= 1) {
                validation_error("_training_frame", "Training data must have at least 2 features (incl. response).");
            }
            if (_response.isBad()) {
                validation_error("_response_column", "Response column is all NAs!");
            }
            if (_response.isConst()) {
                validation_error("_response_column", "Response column is constant!");
            }
            // At least one column (the response included) must be neither bad nor constant.
            int usable = 0;
            for (Vec vec : train().vecs()) {
                if (!vec.isBad() && !vec.isConst()) {
                    usable++;
                }
            }
            if (usable == 0) {
                throw new IllegalArgumentException("There is no usable column to generate model!");
            }
            return _validation_error_count;
        }

        /**
         * Extends the parent checksum with the response column name.
         * An unset ({@code null}) response column contributes 0 instead of throwing.
         */
        @Override // hex.Model.Parameters
        public long checksum() {
            return super.checksum() + (_response_column == null ? 0 : _response_column.hashCode());
        }
    }

    /**
     * Sets the class distribution the model was built on.
     *
     * @param fArr distribution to store. It is copied, so later changes by the caller are not
     *     seen. Pass {@code null} to clear it, which disables the probability correction in
     *     {@code score0}.
     */
    public void setModelClassDistribution(float[] fArr) {
        _modelClassDist = fArr == null ? null : fArr.clone();
    }

    /** Always {@code true}: every subclass is trained against a response column. */
    @Override // hex.Model
    public boolean isSupervised() {
        return true;
    }

    /**
     * Creates a model whose column names and domains are taken from {@code frame}.
     *
     * @param fArr prior class distribution (stored by reference, not copied); may be {@code null}
     */
    public SupervisedModel(Key key, Frame frame, P p, O o, float[] fArr) {
        this(key, frame.names(), frame.domains(), p, o, fArr);
    }

    /**
     * Creates a model from explicit column names and domains.
     *
     * @param fArr prior class distribution (stored by reference, not copied); may be {@code null}
     */
    public SupervisedModel(Key key, String[] strArr, String[][] strArr2, P p, O o, float[] fArr) {
        super(key, strArr, strArr2, p, o);
        _priorClassDist = fArr;
    }

    /**
     * Runs the requested scoring metrics. Optionally logs their textual reports.
     *
     * @param frame frame holding the actual values
     * @param vec actual response vector
     * @param frame2 prediction frame. Column 0 is the predicted value used by the confusion
     *     matrix. Column 2 is the value fed to AUC; presumably the probability of the second
     *     class, to be confirmed against the caller.
     * @param frame3 prediction frame used for hit ratios
     * @param str label of the scored data, used only in the log header
     * @param z whether to write the collected reports to the log
     * @param i the confusion matrix is printed only if it has at most this many rows
     * @param confusionMatrix confusion-matrix job to run, or {@code null}
     * @param auc AUC job to run (binary classifiers only), or {@code null}
     * @param hitRatio hit-ratio job to run (classifiers only), or {@code null}
     * @return the error measure, chosen as follows:
     *     <ul>
     *       <li>the AUC error, when {@code auc} is given;</li>
     *       <li>otherwise, for a classifier, the confusion-matrix error;</li>
     *       <li>for a regression, {@code confusionMatrix.mse};</li>
     *       <li>{@code Double.POSITIVE_INFINITY} when neither {@code auc} nor
     *           {@code confusionMatrix} is given.</li>
     *     </ul>
     */
    public double calcError(Frame frame, Vec vec, Frame frame2, Frame frame3, String str, boolean z, int i, ConfusionMatrix confusionMatrix, AUC auc, HitRatio hitRatio) {
        StringBuilder sb = new StringBuilder();
        double error = Double.POSITIVE_INFINITY;
        if (auc != null) {
            assert _output.isClassifier();
            assert _output.nclasses() == 2;
            auc.actual = frame;
            auc.vactual = vec;
            auc.predict = frame2;
            auc.vpredict = frame2.vecs()[2];
            auc.threshold_criterion = AUC.ThresholdCriterion.maximum_F1;
            auc.execImpl();
            error = auc.data().err();
        }
        if (confusionMatrix != null) {
            confusionMatrix.actual = frame;
            confusionMatrix.vactual = vec;
            confusionMatrix.predict = frame2;
            confusionMatrix.vpredict = frame2.vecs()[0];
            confusionMatrix.execImpl();
            if (_output.isClassifier()) {
                if (auc != null) {
                    // Replace the computed matrix with the AUC's 2x2 matrix at its chosen
                    // threshold, inside a 3x3 array. Row 2 and column 2 stay zero.
                    confusionMatrix.cm = new long[3][3];
                    confusionMatrix.cm[0][0] = auc.data().cm()[0][0];
                    confusionMatrix.cm[1][0] = auc.data().cm()[1][0];
                    confusionMatrix.cm[0][1] = auc.data().cm()[0][1];
                    confusionMatrix.cm[1][1] = auc.data().cm()[1][1];
                    assert new ConfusionMatrix2(confusionMatrix.cm).err() == auc.data().err();
                } else {
                    error = new ConfusionMatrix2(confusionMatrix.cm).err();
                }
                if (confusionMatrix.cm.length <= i) {
                    confusionMatrix.toASCII(sb);
                }
            } else {
                assert auc == null;
                error = confusionMatrix.mse;
                confusionMatrix.toASCII(sb);
            }
        }
        if (hitRatio != null) {
            assert _output.isClassifier();
            hitRatio.actual = frame;
            hitRatio.vactual = vec;
            hitRatio.predict = frame3;
            hitRatio.execImpl();
            hitRatio.toASCII(sb);
        }
        if (z && sb.length() > 0) {
            Log.info("Scoring on " + str + " data:");
            for (String line : sb.toString().split("\n")) {
                Log.info(line);
            }
        }
        return error;
    }

    /**
     * Scores row {@code i} of the given chunks.
     *
     * <p>The first {@code _output._names.length - 1} columns of the row are copied into
     * {@code dArr}; the last name is skipped (presumably the response, to be confirmed).
     * They are then passed to {@code score0(double[], float[])}. For a classifier with both
     * class distributions set, the class probabilities in slots {@code 1..nclasses} are
     * rescaled by {@code prior / model} and renormalized to sum to 1. Slot 0 is then reset
     * via {@code ModelUtils.getPrediction}.
     *
     * @return the prediction array returned by {@code score0(double[], float[])}, corrected in place
     */
    @Override // hex.Model
    protected float[] score0(Chunk[] chunkArr, int i, double[] dArr, float[] fArr) {
        assert chunkArr.length >= _output._names.length;
        for (int col = 0; col < _output._names.length - 1; col++) {
            dArr[col] = chunkArr[col].at0(i);
        }
        float[] preds = score0(dArr, fArr);
        if (_output.isClassifier() && _priorClassDist != null && _modelClassDist != null) {
            assert preds.length == _output.nclasses() + 1;
            double probSum = 0.0d;
            for (int c = 1; c < preds.length; c++) {
                double prior = _priorClassDist[c - 1];
                assert prior > 0.0d;
                double modelled = _modelClassDist[c - 1];
                assert modelled > 0.0d;
                assert !Double.isNaN(preds[c]);
                // Rescale this class's own probability (the decompiled code referenced an
                // undeclared index here).
                preds[c] = (float) (preds[c] * (prior / modelled));
                probSum += preds[c];
            }
            for (int c = 1; c < preds.length; c++) {
                preds[c] = (float) (preds[c] / probSum);
            }
            preds[0] = ModelUtils.getPrediction(preds, dArr);
        }
        return preds;
    }
}
