package hex.genmodel.algos.gbm;

import hex.genmodel.GenModel;
import hex.genmodel.algos.deepwater.caffe.nano.Deepwater;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;

/* loaded from: input_file:hex/genmodel/algos/gbm/GbmMojoModel.class */
/**
 * MOJO scoring model for GBM.
 *
 * <p>Scoring sums the per-tree outputs via {@code scoreAllTrees} and then converts that raw,
 * link-space score into final predictions according to {@link #_family} and
 * {@link #_link_function} (see {@link #unifyPreds}).
 */
public final class GbmMojoModel extends SharedTreeMojoModelWithContributions implements SharedTreeGraphConverter {

    /**
     * Smallest magnitude allowed as the argument of the inverse link ({@code 1/x}); keeps the
     * result finite near zero on both sides.
     */
    private static final double MIN_ABS_INVERSE_ARG = 1.0E-5d;

    /** Distribution family; selects how the summed tree output is post-processed in {@link #unifyPreds}. */
    public DistributionFamily _family;

    /** Link function whose inverse maps the link-space score back to the response scale. */
    public LinkFunctionType _link_function;

    /**
     * Constant initial score added (together with the offset) to the summed tree output in
     * {@link #unifyPreds}. It is not added for multinomial models with more than two classes.
     */
    public double _init_f;

    /**
     * Creates the model; all arguments are forwarded unchanged to the superclass constructor.
     * NOTE(review): presumably column names, categorical domains and the response column name —
     * confirm against {@code SharedTreeMojoModel}.
     */
    public GbmMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /** Wraps the TreeSHAP predictor in the default contributions predictor, with no GBM-specific post-processing. */
    @Override
    protected SharedTreeMojoModelWithContributions.ContributionsPredictor getContributionsPredictor(TreeSHAPPredictor<double[]> treeSHAPPredictor) {
        return new SharedTreeMojoModelWithContributions.ContributionsPredictor(treeSHAPPredictor);
    }

    /** @return the constant initial score {@link #_init_f} */
    @Override
    public double getInitF() {
        return _init_f;
    }

    /**
     * Scores one row: accumulates all tree outputs into {@code preds}, then post-processes them.
     *
     * @param row    input feature values
     * @param offset value added to the link-space score before the inverse link
     * @param preds  output buffer, filled in place and returned
     * @return {@code preds}
     */
    @Override
    public final double[] score0(double[] row, double offset, double[] preds) {
        super.scoreAllTrees(row, preds);
        return unifyPreds(row, offset, preds);
    }

    /**
     * Converts the raw summed tree output in {@code preds} into final predictions, in place.
     *
     * <ul>
     *   <li>bernoulli / quasibinomial / modified_huber: the raw score in {@code preds[1]} (plus
     *       {@code _init_f} and {@code offset}) is passed through the inverse link to give
     *       {@code preds[2]}; {@code preds[1]} becomes its complement.</li>
     *   <li>multinomial: with two classes the single score in {@code preds[1]} is shifted by
     *       {@code _init_f + offset} and mirrored into {@code preds[2]}; then
     *       {@code GenModel.GBM_rescale} is applied.</li>
     *   <li>any other family: {@code preds[0]} is shifted and passed through the inverse link, and
     *       the method returns immediately (no class-probability handling).</li>
     * </ul>
     *
     * For the classification branches, probabilities are optionally corrected for class balancing,
     * and {@code preds[0]} is then set to the label chosen by {@code GenModel.getPrediction}.
     *
     * @param row    input feature values (used by {@code getPrediction})
     * @param offset value added to the link-space score
     * @param preds  raw scores on input; final predictions on output
     * @return {@code preds}
     */
    @Override
    public final double[] unifyPreds(double[] row, double offset, double[] preds) {
        if (_family == DistributionFamily.bernoulli
                || _family == DistributionFamily.quasibinomial
                || _family == DistributionFamily.modified_huber) {
            preds[2] = linkInv(_link_function, preds[1] + _init_f + offset);
            preds[1] = 1.0d - preds[2];
        } else if (_family == DistributionFamily.multinomial) {
            if (_nclasses == 2) {
                // Two-class model carries one score; the other class gets its negation.
                // Summation order kept as (preds[1] + _init_f) + offset to preserve rounding.
                preds[1] = preds[1] + _init_f + offset;
                preds[2] = -preds[1];
            }
            GenModel.GBM_rescale(preds);
        } else {
            // Non-classification families: single score in preds[0], no label selection.
            preds[0] = linkInv(_link_function, preds[0] + _init_f + offset);
            return preds;
        }

        if (_balanceClasses) {
            GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
        }
        preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold);
        return preds;
    }

    /**
     * Applies the inverse of {@code link} to the link-space value {@code f}.
     *
     * @param link link function type; must not be null
     * @param f    link-space value
     * @return {@code f} mapped back to the response scale
     */
    private static double linkInv(LinkFunctionType link, double f) {
        switch (link) {
            case log:
                return exp(f);
            case logit:
            case ologit:
                return 1.0d / (1.0d + exp(-f));
            case ologlog:
                return 1.0d - exp(-exp(f));
            case oprobit:
                // NOTE(review): no inverse is implemented for oprobit here; it always yields 0.
                return 0.0d;
            case inverse: {
                // Push f away from zero on whichever side it lies, so 1/f stays finite.
                // (Previously the non-negative side used max(-1e-5, f), which never clamped
                // and returned +Infinity for f == 0.)
                double clamped = f < 0.0d
                        ? Math.min(-MIN_ABS_INVERSE_ARG, f)
                        : Math.max(MIN_ABS_INVERSE_ARG, f);
                return 1.0d / clamped;
            }
            case identity:
            default:
                return f;
        }
    }

    /**
     * Exponential capped from above at {@code 1e19}.
     *
     * @param d exponent
     * @return {@code min(1e19, e^d)}
     */
    public static double exp(double d) {
        return Math.min(1.0E19d, Math.exp(d));
    }

    /**
     * Natural logarithm floored at {@code -19}; non-positive inputs yield {@code -19}.
     *
     * @param d argument
     * @return {@code max(-19, ln(d))} for positive {@code d}, otherwise {@code -19}
     */
    public static double log(double d) {
        double nonNegative = Math.max(0.0d, d);
        if (nonNegative == 0.0d) {
            return -19.0d;
        }
        return Math.max(-19.0d, Math.log(nonNegative));
    }

    /** Scores one row with a zero offset; see {@link #score0(double[], double, double[])}. */
    @Override
    public double[] score0(double[] row, double[] preds) {
        return score0(row, 0.0d, preds);
    }

    /**
     * Returns the result of {@code getDecisionPath} for {@code row}.
     * NOTE(review): presumably one leaf-path string per tree — confirm against the superclass.
     */
    public String[] leaf_node_assignment(double[] row) {
        return getDecisionPath(row);
    }
}
