package org.apache.spark.ml.classification;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.classification.ClassifierParams;
import org.apache.spark.ml.classification.ProbabilisticClassificationModel;
import org.apache.spark.ml.classification.ProbabilisticClassifierParams;
import org.apache.spark.ml.param.DoubleArrayParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.apache.spark.ml.param.shared.HasThresholds;
import org.apache.spark.ml.param.shared.HasThresholds$$anonfun$2;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;

/* compiled from: ProbabilisticClassifier.scala */
@DeveloperApi
@ScalaSignature(bytes = "\u0006\u0001\u00055b!B\u0001\u0003\u0003\u0003i!\u0001\t)s_\n\f'-\u001b7jgRL7m\u00117bgNLg-[2bi&|g.T8eK2T!a\u0001\u0003\u0002\u001d\rd\u0017m]:jM&\u001c\u0017\r^5p]*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001+\rqQCI\n\u0004\u0001=1\u0003\u0003\u0002\t\u0012'\u0005j\u0011AA\u0005\u0003%\t\u00111c\u00117bgNLg-[2bi&|g.T8eK2\u0004\"\u0001F\u000b\r\u0001\u0011)a\u0003\u0001b\u0001/\taa)Z1ukJ,7\u000fV=qKF\u0011\u0001D\b\t\u00033qi\u0011A\u0007\u0006\u00027\u0005)1oY1mC&\u0011QD\u0007\u0002\b\u001d>$\b.\u001b8h!\tIr$\u0003\u0002!5\t\u0019\u0011I\\=\u0011\u0005Q\u0011C!B\u0012\u0001\u0005\u0004!#!A'\u0012\u0005a)\u0003\u0003\u0002\t\u0001'\u0005\u0002\"\u0001E\u0014\n\u0005!\u0012!!\b)s_\n\f'-\u001b7jgRL7m\u00117bgNLg-[3s!\u0006\u0014\u0018-\\:\t\u000b)\u0002A\u0011A\u0016\u0002\rqJg.\u001b;?)\u0005)\u0003\"B\u0017\u0001\t\u0003q\u0013!E:fiB\u0013xNY1cS2LG/_\"pYR\u0011\u0011e\f\u0005\u0006a1\u0002\r!M\u0001\u0006m\u0006dW/\u001a\t\u0003eUr!!G\u001a\n\u0005QR\u0012A\u0002)sK\u0012,g-\u0003\u00027o\t11\u000b\u001e:j]\u001eT!\u0001\u000e\u000e\t\u000be\u0002A\u0011\u0001\u001e\u0002\u001bM,G\u000f\u00165sKNDw\u000e\u001c3t)\t\t3\bC\u00031q\u0001\u0007A\bE\u0002\u001a{}J!A\u0010\u000e\u0003\u000b\u0005\u0013(/Y=\u0011\u0005e\u0001\u0015BA!\u001b\u0005\u0019!u.\u001e2mK\")1\t\u0001C!\t\u0006IAO]1og\u001a|'/\u001c\u000b\u0003\u000b.\u0003\"AR%\u000e\u0003\u001dS!\u0001\u0013\u0004\u0002\u0007M\fH.\u0003\u0002K\u000f\nIA)\u0019;b\rJ\fW.\u001a\u0005\u0006\u0019\n\u0003\r!R\u0001\bI\u0006$\u0018m]3u\u0011\u0015q\u0005A\"\u0005P\u0003Y\u0011\u0018m\u001e\u001aqe>\u0014\u0017MY5mSRL\u0018J\u001c)mC\u000e,GC\u0001)Y!\t\tf+D\u0001S\u0015\t\u0019F+\u0001\u0004mS:\fGn\u001a\u0006\u0003+\u001a\tQ!\u001c7mS\nL!a\u0016*\u0003\rY+7\r^8s\u0011\u0015IV\n1\u0001Q\u00035\u0011\u0018m\u001e)sK\u0012L7\r^5p]\")1\f\u0001C\t9\u0006y!/Y<3aJ|'-\u00192jY&$\u0018\u0010\u0006\u0002Q;\")\u0011
L\u0017a\u0001!\")q\f\u0001C)A\u0006q!/Y<3aJ,G-[2uS>tGCA b\u0011\u0015If\f1\u0001Q\u0011\u0015\u0019\u0007\u0001\"\u0005e\u0003I\u0001(/\u001a3jGR\u0004&o\u001c2bE&d\u0017\u000e^=\u0015\u0005A+\u0007\"\u00024c\u0001\u0004\u0019\u0012\u0001\u00034fCR,(/Z:\t\u000b!\u0004A\u0011C5\u0002-A\u0014xNY1cS2LG/\u001f\u001aqe\u0016$\u0017n\u0019;j_:$\"a\u00106\t\u000b-<\u0007\u0019\u0001)\u0002\u0017A\u0014xNY1cS2LG/\u001f\u0015\u0003\u00015\u0004\"A\\9\u000e\u0003=T!\u0001\u001d\u0004\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0002s_\naA)\u001a<fY>\u0004XM]!qS\u001e1AO\u0001E\u0001\tU\f\u0001\u0005\u0015:pE\u0006\u0014\u0017\u000e\\5ti&\u001c7\t\\1tg&4\u0017nY1uS>tWj\u001c3fYB\u0011\u0001C\u001e\u0004\u0007\u0003\tA\t\u0001B<\u0014\u0007YD8\u0010\u0005\u0002\u001as&\u0011!P\u0007\u0002\u0007\u0003:L(+\u001a4\u0011\u0005ea\u0018BA?\u001b\u00051\u0019VM]5bY&T\u0018M\u00197f\u0011\u0015Qc\u000f\"\u0001��)\u0005)\bbBA\u0002m\u0012\u0005\u0011QA\u0001 ]>\u0014X.\u00197ju\u0016$v\u000e\u0015:pE\u0006\u0014\u0017\u000e\\5uS\u0016\u001c\u0018J\u001c)mC\u000e,G\u0003BA\u0004\u0003\u001b\u00012!GA\u0005\u0013\r\tYA\u0007\u0002\u0005+:LG\u000f\u0003\u0005\u0002\u0010\u0005\u0005\u0001\u0019AA\t\u0003\u00051\bcA)\u0002\u0014%\u0019\u0011Q\u0003*\u0003\u0017\u0011+gn]3WK\u000e$xN\u001d\u0005\n\u000331\u0018\u0011!C\u0005\u00037\t1B]3bIJ+7o\u001c7wKR\u0011\u0011Q\u0004\t\u0005\u0003?\tI#\u0004\u0002\u0002\")!\u00111EA\u0013\u0003\u0011a\u0017M\\4\u000b\u0005\u0005\u001d\u0012\u0001\u00026bm\u0006LA!a\u000b\u0002\"\t1qJ\u00196fGR\u0004")
/* loaded from: input_file:org/apache/spark/ml/classification/ProbabilisticClassificationModel.class */
/*
 * Base class for classification models that, besides a predicted label, can expose a raw
 * prediction vector and a per-class probability vector, each as an optional output column.
 *
 * This source was decompiled from Scala (ProbabilisticClassifier.scala). Classes named
 * ProbabilisticClassificationModel$$anonfun$N are compiler-generated closures whose bodies are not
 * visible in this file.
 *
 * Type parameters:
 *   FeaturesType - type of the features passed to a single prediction (see predictProbability)
 *   M            - concrete model type, returned by the fluent setters
 */
public abstract class ProbabilisticClassificationModel<FeaturesType, M extends ProbabilisticClassificationModel<FeaturesType, M>> extends ClassificationModel<FeaturesType, M> implements ProbabilisticClassifierParams {
    // Both params are assigned exactly once during construction through the trait setter methods
    // below. thresholds is set by the constructor. probabilityCol is presumably set by
    // HasProbabilityCol.Cclass.$init$ (NOTE(review): confirm). They cannot be `final` in Java source
    // because the assignment happens outside a constructor body.
    private DoubleArrayParam thresholds;
    private Param<String> probabilityCol;

    /**
     * Delegates to the companion object's {@code normalizeToProbabilitiesInPlace}.
     *
     * <p>NOTE(review): the name suggests it rescales {@code denseVector} in place into a
     * probability distribution; confirm against the Scala companion object.
     *
     * @param denseVector vector to normalize (mutated by the delegate, per its name)
     */
    public static void normalizeToProbabilitiesInPlace(DenseVector denseVector) {
        ProbabilisticClassificationModel$.MODULE$.normalizeToProbabilitiesInPlace(denseVector);
    }

    /**
     * Synthetic accessor that lets {@code ProbabilisticClassifierParams} invoke its "super"
     * implementation, which is the {@code ClassifierParams} schema validation.
     */
    @Override // org.apache.spark.ml.classification.ProbabilisticClassifierParams
    public /* synthetic */ StructType org$apache$spark$ml$classification$ProbabilisticClassifierParams$$super$validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return ClassifierParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /** Validates and extends {@code structType} using the {@code ProbabilisticClassifierParams} rules. */
    @Override // org.apache.spark.ml.classification.ClassificationModel, org.apache.spark.ml.PredictionModel, org.apache.spark.ml.PredictorParams
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return ProbabilisticClassifierParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /** @return the per-class thresholds param */
    @Override // org.apache.spark.ml.param.shared.HasThresholds
    public final DoubleArrayParam thresholds() {
        return this.thresholds;
    }

    /** Trait setter used once during construction to install the thresholds param. */
    @Override // org.apache.spark.ml.param.shared.HasThresholds
    public final void org$apache$spark$ml$param$shared$HasThresholds$_setter_$thresholds_$eq(DoubleArrayParam doubleArrayParam) {
        this.thresholds = doubleArrayParam;
    }

    /** @return the current value of the thresholds param */
    @Override // org.apache.spark.ml.param.shared.HasThresholds
    public double[] getThresholds() {
        return HasThresholds.Cclass.getThresholds(this);
    }

    /** @return the probability output column param */
    @Override // org.apache.spark.ml.param.shared.HasProbabilityCol
    public final Param<String> probabilityCol() {
        return this.probabilityCol;
    }

    /** Trait setter used once during construction to install the probability column param. */
    @Override // org.apache.spark.ml.param.shared.HasProbabilityCol
    public final void org$apache$spark$ml$param$shared$HasProbabilityCol$_setter_$probabilityCol_$eq(Param param) {
        this.probabilityCol = param;
    }

    /** @return the current value of the probability column param */
    @Override // org.apache.spark.ml.param.shared.HasProbabilityCol
    public final String getProbabilityCol() {
        return HasProbabilityCol.Cclass.getProbabilityCol(this);
    }

    /**
     * Sets the name of the probability output column. An empty name disables that column in
     * {@link #transform}.
     *
     * @param str the column name
     * @return this model, typed as the concrete model
     */
    @SuppressWarnings("unchecked") // set(...) returns this instance, which is an M by construction
    public M setProbabilityCol(String str) {
        return (M) set(probabilityCol(), str);
    }

    /**
     * Sets the per-class thresholds. {@link #transform} requires the length to equal
     * {@code numClasses()}.
     *
     * @param dArr one threshold per class
     * @return this model, typed as the concrete model
     */
    @SuppressWarnings("unchecked") // set(...) returns this instance, which is an M by construction
    public M setThresholds(double[] dArr) {
        return (M) set(thresholds(), dArr);
    }

    /**
     * Appends the enabled output columns (raw prediction, probability, prediction) to
     * {@code dataFrame}. A column is enabled when its configured name is non-empty.
     *
     * <ul>
     *   <li>Raw prediction is computed from the features column.
     *   <li>Probability is computed from the raw prediction column if that column is enabled,
     *       otherwise from the features column.
     *   <li>Prediction is computed from the raw prediction column if enabled, otherwise from the
     *       probability column if enabled, otherwise from the features column.
     * </ul>
     *
     * <p>If no column is enabled, a warning is logged and the input is returned unchanged.
     *
     * @param dataFrame input data
     * @return {@code dataFrame} with the enabled output columns appended
     */
    @Override // org.apache.spark.ml.classification.ClassificationModel, org.apache.spark.ml.PredictionModel, org.apache.spark.ml.Transformer
    public DataFrame transform(DataFrame dataFrame) {
        transformSchema(dataFrame.schema(), true);
        if (isDefined(thresholds())) {
            // anonfun$transform$1 supplies the failure message (body not visible here).
            Predef$.MODULE$.require(((double[]) $(thresholds())).length == numClasses(), new ProbabilisticClassificationModel$$anonfun$transform$1(this));
        }

        String rawPredictionColName = (String) $(rawPredictionCol());
        String probabilityColName = (String) $(probabilityCol());
        String predictionColName = (String) $(predictionCol());
        boolean outputRaw = !rawPredictionColName.isEmpty();
        boolean outputProbability = !probabilityColName.isEmpty();

        DataFrame outputData = dataFrame;
        int numColsOutput = 0;

        if (outputRaw) {
            // Vector-valued UDF over the features column (anonfun$1).
            Column rawColumn = functions$.MODULE$.udf(new ProbabilisticClassificationModel$$anonfun$1(this), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.classification.ProbabilisticClassificationModel$$typecreator1$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.mllib.linalg.Vector").asType().toTypeConstructor();
                }
            }), package$.MODULE$.universe().TypeTag().Any()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(getFeaturesCol())}));
            outputData = outputData.withColumn(rawPredictionColName, rawColumn);
            numColsOutput++;
        }

        if (outputProbability) {
            Column probabilityColumn;
            if (outputRaw) {
                // Vector -> Vector UDF over the raw prediction column (anonfun$2).
                probabilityColumn = functions$.MODULE$.udf(new ProbabilisticClassificationModel$$anonfun$2(this), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.classification.ProbabilisticClassificationModel$$typecreator2$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.mllib.linalg.Vector").asType().toTypeConstructor();
                    }
                }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.classification.ProbabilisticClassificationModel$$typecreator3$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.mllib.linalg.Vector").asType().toTypeConstructor();
                    }
                })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(rawPredictionColName)}));
            } else {
                // No raw column requested: Vector-valued UDF over the features column (anonfun$3).
                probabilityColumn = functions$.MODULE$.udf(new ProbabilisticClassificationModel$$anonfun$3(this), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.classification.ProbabilisticClassificationModel$$typecreator4$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.mllib.linalg.Vector").asType().toTypeConstructor();
                    }
                }), package$.MODULE$.universe().TypeTag().Any()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(featuresCol()))}));
            }
            outputData = outputData.withColumn(probabilityColName, probabilityColumn);
            numColsOutput++;
        }

        if (!predictionColName.isEmpty()) {
            Column predictionColumn;
            if (outputRaw) {
                // Double-valued UDF over the raw prediction column (anonfun$4).
                predictionColumn = functions$.MODULE$.udf(new ProbabilisticClassificationModel$$anonfun$4(this), package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.classification.ProbabilisticClassificationModel$$typecreator5$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.mllib.linalg.Vector").asType().toTypeConstructor();
                    }
                })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(rawPredictionColName)}));
            } else if (outputProbability) {
                // Double-valued UDF over the probability column (anonfun$5).
                predictionColumn = functions$.MODULE$.udf(new ProbabilisticClassificationModel$$anonfun$5(this), package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.classification.ProbabilisticClassificationModel$$typecreator6$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.mllib.linalg.Vector").asType().toTypeConstructor();
                    }
                })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(probabilityColName)}));
            } else {
                // Neither intermediate column requested: Double-valued UDF over the features column (anonfun$6).
                predictionColumn = functions$.MODULE$.udf(new ProbabilisticClassificationModel$$anonfun$6(this), package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Any()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(featuresCol()))}));
            }
            outputData = outputData.withColumn(predictionColName, predictionColumn);
            numColsOutput++;
        }

        if (numColsOutput == 0) {
            logWarning(new ProbabilisticClassificationModel$$anonfun$transform$2(this));
        }
        return outputData;
    }

    /**
     * Converts a raw prediction vector into a probability vector. Implementations may overwrite
     * and return their argument, which is why {@link #raw2probability} passes a copy.
     *
     * @param vector raw prediction
     * @return probability vector
     */
    public abstract Vector raw2probabilityInPlace(Vector vector);

    /**
     * Non-mutating variant of {@link #raw2probabilityInPlace}: it operates on a copy of
     * {@code vector}.
     */
    public Vector raw2probability(Vector vector) {
        return raw2probabilityInPlace(vector.copy());
    }

    /**
     * Predicts a label from a raw prediction. Without thresholds this is the argmax of the raw
     * vector. With thresholds, the raw vector is first converted to probabilities and then passed to
     * {@link #probability2prediction}.
     */
    @Override // org.apache.spark.ml.classification.ClassificationModel
    public double raw2prediction(Vector vector) {
        return isDefined(thresholds()) ? probability2prediction(raw2probability(vector)) : vector.argmax();
    }

    /**
     * Predicts the probability vector for one example. The raw prediction is converted in place
     * (NOTE(review): this assumes predictRaw returns a vector not shared elsewhere).
     */
    public Vector predictProbability(FeaturesType featurestype) {
        return raw2probabilityInPlace(predictRaw(featurestype));
    }

    /**
     * Predicts a label from a probability vector.
     *
     * <p>Without thresholds this is the argmax of {@code vector}. Otherwise each probability is
     * paired with its class threshold. The pairs are mapped through anonfun$7 (per the thresholds
     * param description, to p/t), and the argmax of the result is returned.
     *
     * @param vector per-class probabilities
     * @return index of the predicted class, as a double
     */
    public double probability2prediction(Vector vector) {
        if (!isDefined(thresholds())) {
            return vector.argmax();
        }
        Object[] probabilityThresholdPairs = (Object[]) Predef$.MODULE$.doubleArrayOps(vector.toArray()).zip(Predef$.MODULE$.wrapDoubleArray(getThresholds()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
        double[] scaledScores = (double[]) Predef$.MODULE$.refArrayOps(probabilityThresholdPairs).map(new ProbabilisticClassificationModel$$anonfun$7(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        return Vectors$.MODULE$.dense(scaledScores).argmax();
    }

    /**
     * Installs the trait-provided params in this order:
     * <ol>
     *   <li>the probability column;
     *   <li>the thresholds param (HasThresholds' initializer, inlined here);
     *   <li>the ProbabilisticClassifierParams initializer.
     * </ol>
     * NOTE(review): the later initializers may rely on the earlier ones, so keep this order.
     */
    public ProbabilisticClassificationModel() {
        HasProbabilityCol.Cclass.$init$(this);
        org$apache$spark$ml$param$shared$HasThresholds$_setter_$thresholds_$eq(new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", new HasThresholds$$anonfun$2(this)));
        ProbabilisticClassifierParams.Cclass.$init$(this);
    }
}
