package hex;

import hex.Model;
import hex.SupervisedModel.SupervisedOutput;
import hex.SupervisedModel.SupervisedParameters;
import water.Key;
import water.fvec.Chunk;
import water.util.MRUtils;
import water.util.ModelUtils;

/**
 * Base class for models trained against a response column.
 *
 * <p>Adds supervised-specific output state (the response distribution observed in the training
 * frame), supervised parameters (response column, class-balancing options), and a chunk-based
 * row-scoring path that applies prior/model class-distribution correction for classifiers.
 *
 * @param <M> concrete model type
 * @param <P> concrete parameter type
 * @param <O> concrete output type
 */
public abstract class SupervisedModel<M extends Model<M, P, O>, P extends SupervisedParameters, O extends SupervisedOutput> extends Model<M, P, O> {

    /**
     * Output shared by all supervised models: column metadata plus the training response
     * distribution.
     */
    public static abstract class SupervisedOutput extends Model.Output {
        /**
         * For classifiers, the per-class response counts from {@code MRUtils.ClassDist.dist()}.
         * For regression, a single-element array holding the training row count.
         */
        public long[] _distribution;

        /**
         * Prior class distribution from {@code MRUtils.ClassDist.rel_dist()} for classifiers.
         * For regression it is {@code {1.0f}}.
         */
        public float[] _priorClassDist;

        /**
         * Class distribution used for probability correction in scoring. The builder constructor
         * initializes it to the very same array as {@link #_priorClassDist}, so the two alias
         * until this field is reassigned.
         */
        public float[] _modelClassDist;

        /** Creates an output with no builder; the supervised fields above stay {@code null}. */
        public SupervisedOutput() {
            this(null);
        }

        /**
         * Initializes column metadata and the response distribution from a model builder.
         *
         * @param supervisedModelBuilder builder supplying the training frame and response; when
         *     {@code null}, only the superclass constructor runs
         */
        public SupervisedOutput(SupervisedModelBuilder supervisedModelBuilder) {
            super(supervisedModelBuilder);
            if (supervisedModelBuilder == null) {
                return;
            }
            this._names = supervisedModelBuilder._train.names();
            this._domains = supervisedModelBuilder._train.domains();
            if (supervisedModelBuilder.isClassifier()) {
                MRUtils.ClassDist classDist =
                        new MRUtils.ClassDist(supervisedModelBuilder._nclass).doAll(supervisedModelBuilder._response);
                this._distribution = classDist.dist();
                this._priorClassDist = classDist.rel_dist();
            } else {
                // Regression: one "class" that covers every training row.
                this._distribution = new long[]{supervisedModelBuilder._train.numRows()};
                this._priorClassDist = new float[]{1.0f};
            }
            this._modelClassDist = this._priorClassDist;
        }

        /**
         * Returns the number of input columns, i.e. all names except one.
         *
         * <p>NOTE(review): this presumably excludes the response column, which would be the
         * last name; confirm against the frame layout.
         */
        @Override
        public int nfeatures() {
            return this._names.length - 1;
        }

        /** Returns the length of {@link #_distribution}; this is 1 for regression. */
        @Override
        public int nclasses() {
            return this._distribution.length;
        }

        /** Returns {@code true} when the response has more than one class. */
        @Override
        public boolean isClassifier() {
            return nclasses() > 1;
        }

        /**
         * Maps the class count to a model category.
         *
         * @return {@code Regression} for 1 class, {@code Binomial} for 2, and
         *     {@code Multinomial} otherwise
         */
        @Override
        public Model.ModelCategory getModelCategory() {
            final int classes = nclasses();
            if (classes == 1) {
                return Model.ModelCategory.Regression;
            }
            return classes == 2 ? Model.ModelCategory.Binomial : Model.ModelCategory.Multinomial;
        }
    }

    /** Parameters common to supervised model builders. */
    public static abstract class SupervisedParameters extends Model.Parameters {
        /** Name of the response column. */
        public String _response_column;
        public float[] _class_sampling_factors;
        public boolean _convert_to_enum = false;
        public boolean _balance_classes = false;
        public float _max_after_balance_size = 5.0f;
        public int _max_hit_ratio_k = 10;

        /**
         * Folds the response column name and the enum-conversion and balancing flags into the
         * superclass checksum.
         *
         * <p>A {@code null} response column contributes 0 instead of throwing. Non-null values
         * give exactly the same checksum as before. The method is declared public here; the
         * decompiler noted the original was protected, and public is kept for compatibility.
         */
        @Override
        public long checksum_impl() {
            final int responseHash = this._response_column == null ? 0 : this._response_column.hashCode();
            return super.checksum_impl()
                    ^ responseHash
                    ^ (this._convert_to_enum ? 31 : 33)
                    ^ (this._balance_classes ? 37 : 39);
        }
    }

    public SupervisedModel(Key key, P p, O o) {
        super(key, p, o);
    }

    /** Supervised models always report {@code true}. */
    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Scores row {@code i} read directly from the given chunks.
     *
     * <p>The first {@code _names.length - 1} chunks are copied into {@code dArr} and passed to
     * {@link #score0(double[], float[])}. For classifiers, when both prior and model class
     * distributions are present, the predictions are then adjusted with
     * {@link ModelUtils#correctProbabilities}. After that, element 0 is overwritten with
     * {@link ModelUtils#getPrediction}.
     *
     * @param chunkArr column chunks; must contain at least {@code _names.length} entries
     * @param i row index within the chunks
     * @param dArr scratch buffer receiving the feature values; needs at least
     *     {@code _names.length - 1} slots
     * @param fArr buffer handed through to {@link #score0(double[], float[])}
     * @return the array returned by {@link #score0(double[], float[])}, possibly corrected
     */
    @Override
    public float[] score0(Chunk[] chunkArr, int i, double[] dArr, float[] fArr) {
        final SupervisedOutput out = (SupervisedOutput) this._output;
        final int nameCount = out._names.length;
        assert chunkArr.length >= nameCount
                : "expected at least " + nameCount + " chunks, got " + chunkArr.length;
        for (int col = 0; col < nameCount - 1; col++) {
            dArr[col] = chunkArr[col].atd(i);
        }
        final float[] preds = score0(dArr, fArr);
        if (out.isClassifier() && out._priorClassDist != null && out._modelClassDist != null) {
            ModelUtils.correctProbabilities(preds, out._priorClassDist, out._modelClassDist);
            preds[0] = ModelUtils.getPrediction(preds, dArr);
        }
        return preds;
    }
}
