package org.apache.spark.ml.tree.impl;

import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.mllib.tree.loss.Loss;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import org.slf4j.Logger;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.TraversableLike;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Range;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.ScalaRunTime$;

/* compiled from: GradientBoostedTrees.scala */
/* loaded from: input_file:org/apache/spark/ml/tree/impl/GradientBoostedTrees$.class */
/**
 * Decompiled Java rendering of the Scala singleton object {@code GradientBoostedTrees}
 * (see the "compiled from" header): gradient boosting whose base learners are
 * {@link DecisionTreeRegressionModel}s.
 *
 * <p>Tree 0 is fit directly to the input labels with weight 1.0. Every further boosting round
 * fits a regression tree (variance impurity) to the pseudo-residuals
 * {@code -loss.gradient(prediction, label)} and appends it to the ensemble with weight
 * {@code learningRate}. A point's ensemble prediction is the weighted sum of the tree predictions.
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java. Several constructs will not
 * compile as-is, because they are renderings of Scala bytecode:
 * <ul>
 *   <li>{@code new Tuple2.mcDD.sp(...)}, which stands for the specialized {@code Tuple2$mcDD$sp};</li>
 *   <li>locals that shadow lambda parameters;</li>
 *   <li>lambdas over Scala collections whose parameter types would not infer in Java.</li>
 * </ul>
 * Treat the original {@code GradientBoostedTrees.scala} as the source of truth.
 */
public final class GradientBoostedTrees$ implements Logging {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static GradientBoostedTrees$ MODULE$;
    /** Backing field for the Logging trait's logger; read/written via the two accessors further down. */
    private transient Logger org$apache$spark$internal$Logging$$log_;

    static {
        // Constructing the instance is enough: the constructor stores itself into MODULE$.
        new GradientBoostedTrees$();
    }

    // ---- Logging trait forwarders ----------------------------------------------------------
    // Every method in this section delegates to the matching static `Logging.<name>$(this, ...)`
    // implementation method of the mixed-in Logging trait; none adds behavior of its own.

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    // Getter/setter pair for the logger backing field, used by the Logging trait implementation.
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    /**
     * Trains a gradient-boosted ensemble on {@code rdd} without a validation set.
     *
     * <p>How the input is prepared depends on the tree strategy's algorithm:
     * <ul>
     *   <li>{@code Algo.Regression}: the input is used as-is.</li>
     *   <li>{@code Algo.Classification}: every label is first remapped to {@code 2 * label - 1},
     *       so 0 becomes -1 and 1 becomes +1.</li>
     *   <li>Any other algorithm is rejected.</li>
     * </ul>
     *
     * @param rdd training data
     * @param boostingStrategy boosting parameters; {@code treeStrategy().algo()} selects the path above
     * @param j base random seed; tree {@code m} is trained with seed {@code j + m}
     * @param str forwarded unchanged to {@code DecisionTreeRegressor.train} as its third argument
     *     (presumably the feature-subset strategy; confirm)
     * @return (trees, tree weights) as produced by {@code boost} with validation disabled
     * @throws IllegalArgumentException if the algorithm is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> run(RDD<LabeledPoint> rdd, BoostingStrategy boostingStrategy, long j, String str) {
        Tuple2<DecisionTreeRegressionModel[], double[]> boost;
        Enumeration.Value algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value Regression = Algo$.MODULE$.Regression();
        // Decompiled null-safe equality check: true when algo is NOT Regression.
        if (Regression != null ? !Regression.equals(algo) : algo != null) {
            Enumeration.Value Classification = Algo$.MODULE$.Classification();
            // True when algo is neither Regression nor Classification.
            if (Classification != null ? !Classification.equals(algo) : algo != null) {
                throw new IllegalArgumentException(new StringBuilder(39).append(algo).append(" is not supported by gradient boosting.").toString());
            }
            // Classification: map labels via 2 * label - 1 before boosting.
            RDD<LabeledPoint> map = rdd.map(labeledPoint -> {
                return new LabeledPoint((labeledPoint.label() * 2) - 1, labeledPoint.features());
            }, ClassTag$.MODULE$.apply(LabeledPoint.class));
            // No validation: the training RDD is passed for both inputs and z = false.
            boost = boost(map, map, boostingStrategy, false, j, str);
        } else {
            boost = boost(rdd, rdd, boostingStrategy, false, j, str);
        }
        return boost;
    }

    /**
     * Trains a gradient-boosted ensemble on {@code rdd}, using {@code rdd2} for early stopping.
     *
     * <p>Label handling matches {@link #run}: for Classification, labels in both the training and
     * the validation data are remapped to {@code 2 * label - 1}. Any algorithm other than
     * Regression or Classification is rejected.
     *
     * <p>NOTE(review): the error text here ("... is not supported by the gradient boosting.")
     * differs in wording from the one in {@link #run} ("... is not supported by gradient
     * boosting.").
     *
     * @param rdd training data
     * @param rdd2 validation data
     * @param boostingStrategy boosting parameters, including {@code validationTol}
     * @param j base random seed; tree {@code m} is trained with seed {@code j + m}
     * @param str forwarded unchanged to {@code DecisionTreeRegressor.train} (presumably the
     *     feature-subset strategy; confirm)
     * @return (trees, tree weights) truncated by {@code boost} to the best validation iteration count
     * @throws IllegalArgumentException if the algorithm is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> runWithValidation(RDD<LabeledPoint> rdd, RDD<LabeledPoint> rdd2, BoostingStrategy boostingStrategy, long j, String str) {
        Tuple2<DecisionTreeRegressionModel[], double[]> boost;
        Enumeration.Value algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value Regression = Algo$.MODULE$.Regression();
        // Decompiled null-safe equality check: true when algo is NOT Regression.
        if (Regression != null ? !Regression.equals(algo) : algo != null) {
            Enumeration.Value Classification = Algo$.MODULE$.Classification();
            if (Classification != null ? !Classification.equals(algo) : algo != null) {
                throw new IllegalArgumentException(new StringBuilder(43).append(algo).append(" is not supported by the gradient boosting.").toString());
            }
            // Classification: remap labels of both training and validation data, then boost with z = true.
            boost = boost(rdd.map(labeledPoint -> {
                return new LabeledPoint((labeledPoint.label() * 2) - 1, labeledPoint.features());
            }, ClassTag$.MODULE$.apply(LabeledPoint.class)), rdd2.map(labeledPoint2 -> {
                return new LabeledPoint((labeledPoint2.label() * 2) - 1, labeledPoint2.features());
            }, ClassTag$.MODULE$.apply(LabeledPoint.class)), boostingStrategy, true, j, str);
        } else {
            boost = boost(rdd, rdd2, boostingStrategy, true, j, str);
        }
        return boost;
    }

    /**
     * Builds the per-point (prediction, error) RDD for an ensemble that so far holds only one tree.
     *
     * <p>For each point: {@code prediction = 0.0 + tree(x) * d}, and
     * {@code error = loss.computeError(prediction, label)}. Each pair is emitted as a
     * double-specialized Tuple2.
     *
     * @param rdd input points
     * @param d weight of the single tree
     * @param decisionTreeRegressionModel the tree
     * @param loss loss used to compute the per-point error
     * @return RDD of (prediction, error) pairs, one per input point and in input order
     */
    public RDD<Tuple2<Object, Object>> computeInitialPredictionAndError(RDD<LabeledPoint> rdd, double d, DecisionTreeRegressionModel decisionTreeRegressionModel, Loss loss) {
        return rdd.map(labeledPoint -> {
            double updatePrediction = MODULE$.updatePrediction(labeledPoint.features(), 0.0d, decisionTreeRegressionModel, d);
            return new Tuple2.mcDD.sp(updatePrediction, loss.computeError(updatePrediction, labeledPoint.label()));
        }, ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Adds one more weighted tree to existing per-point predictions and recomputes the errors.
     *
     * <p>The input points are zipped with the previous (prediction, error) RDD. For each pair:
     * {@code newPrediction = prevPrediction + tree(x) * d}, and
     * {@code error = loss.computeError(newPrediction, label)}. The previous error value is ignored.
     *
     * <p>NOTE(review): this relies on {@code rdd} and {@code rdd2} being element-aligned for
     * {@code zip}. In this file {@code rdd2} is always derived from {@code rdd} by
     * {@code map}/{@code mapPartitions}; confirm for any other caller.
     *
     * @param rdd input points
     * @param rdd2 previous (prediction, error) pairs, aligned with {@code rdd}
     * @param d weight of the new tree
     * @param decisionTreeRegressionModel the new tree
     * @param loss loss used to compute the per-point error
     * @return updated (prediction, error) pairs
     */
    public RDD<Tuple2<Object, Object>> updatePredictionError(RDD<LabeledPoint> rdd, RDD<Tuple2<Object, Object>> rdd2, double d, DecisionTreeRegressionModel decisionTreeRegressionModel, Loss loss) {
        RDD zip = rdd.zip(rdd2, ClassTag$.MODULE$.apply(Tuple2.class));
        return zip.mapPartitions(iterator -> {
            return iterator.map(tuple2 -> {
                // Decompiled pattern match on ((LabeledPoint), (prevPrediction, prevError)).
                // NOTE(review): the inner `tuple2` local shadows the lambda parameter. This is a
                // decompiler artifact and is not legal Java as written.
                if (tuple2 != null) {
                    LabeledPoint labeledPoint = (LabeledPoint) tuple2._1();
                    Tuple2 tuple2 = (Tuple2) tuple2._2();
                    if (tuple2 != null) {
                        double updatePrediction = MODULE$.updatePrediction(labeledPoint.features(), tuple2._1$mcD$sp(), decisionTreeRegressionModel, d);
                        return new Tuple2.mcDD.sp(updatePrediction, loss.computeError(updatePrediction, labeledPoint.label()));
                    }
                }
                throw new MatchError(tuple2);
            });
            // mapPartitions$default$2(): the Scala default value for mapPartitions' second argument.
        }, zip.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Adds one weighted tree's contribution to a running prediction.
     *
     * @param vector feature vector
     * @param d running prediction before this tree
     * @param decisionTreeRegressionModel tree whose root node is evaluated on {@code vector}
     * @param d2 weight of the tree
     * @return {@code d + tree(vector) * d2}
     */
    public double updatePrediction(Vector vector, double d, DecisionTreeRegressionModel decisionTreeRegressionModel, double d2) {
        return d + (decisionTreeRegressionModel.rootNode().predictImpl(vector).prediction() * d2);
    }

    /**
     * Computes the mean loss of the full weighted ensemble over {@code rdd}.
     *
     * <p>For each point, the prediction is built by folding {@code zip(trees, weights)} from 0.0
     * (see {@code $anonfun$computeError$1}). The per-point errors are then averaged with
     * {@code DoubleRDDFunctions.mean}. The tree and weight arrays are captured in the task closure
     * rather than broadcast.
     *
     * @param rdd evaluation data (labels used as given; no remapping is done here)
     * @param decisionTreeRegressionModelArr the ensemble's trees
     * @param dArr per-tree weights
     * @param loss loss used for the per-point error
     * @return mean per-point loss
     */
    public double computeError(RDD<LabeledPoint> rdd, DecisionTreeRegressionModel[] decisionTreeRegressionModelArr, double[] dArr, Loss loss) {
        return RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(rdd.map(labeledPoint -> {
            return BoxesRunTime.boxToDouble($anonfun$computeError$1(decisionTreeRegressionModelArr, dArr, loss, labeledPoint));
        }, ClassTag$.MODULE$.Double())).mean();
    }

    /**
     * Computes the ensemble's mean loss on {@code rdd} after each boosting iteration.
     *
     * <p>Entry {@code k} of the result is the average over all points of
     * {@code loss.computeError(p_k, label)}, where {@code p_k} is the weighted sum of the first
     * {@code k + 1} trees' predictions. When {@code value} is {@code Algo.Classification}, labels
     * are first remapped to {@code 2 * label - 1}.
     *
     * <p>The tree array is broadcast for the duration of the computation, and the broadcast is
     * destroyed before returning. If {@code rdd} is empty, {@code count} is 0 and every entry is
     * {@code 0.0 / 0}, i.e. NaN.
     *
     * @param rdd evaluation data
     * @param decisionTreeRegressionModelArr the ensemble's trees
     * @param dArr per-tree weights; indexed by tree position, so it must be at least as long as
     *     the tree array
     * @param loss loss used to score each prefix of the ensemble
     * @param value the algorithm the ensemble was trained for (decides whether labels are remapped)
     * @return array with one mean loss per tree (same length as the tree array)
     */
    public double[] evaluateEachIteration(RDD<LabeledPoint> rdd, DecisionTreeRegressionModel[] decisionTreeRegressionModelArr, double[] dArr, Loss loss, Enumeration.Value value) {
        SparkContext sparkContext = rdd.sparkContext();
        Enumeration.Value Classification = Algo$.MODULE$.Classification();
        // Remap labels only when value equals Classification (decompiled null-safe equality).
        RDD<LabeledPoint> map = (Classification != null ? !Classification.equals(value) : value != null) ? rdd : rdd.map(labeledPoint -> {
            return new LabeledPoint((labeledPoint.label() * 2) - 1, labeledPoint.features());
        }, ClassTag$.MODULE$.apply(LabeledPoint.class));
        Broadcast broadcast = sparkContext.broadcast(decisionTreeRegressionModelArr, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(DecisionTreeRegressionModel.class)));
        // 0 until trees.length
        Range indices = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(decisionTreeRegressionModelArr)).indices();
        long count = map.count();
        // The expression below runs in three stages:
        //   1. Per point: compute the weighted prediction of each tree, take prefix sums
        //      (scanLeft from 0.0, then drop the leading 0.0), and map each prefix sum to its loss.
        //   2. aggregate: sum the per-point loss vectors element-wise, starting from all zeros.
        //   3. Divide each summed entry by the point count to get the mean.
        IndexedSeq indexedSeq = (IndexedSeq) ((TraversableLike) map.map(labeledPoint2 -> {
            return (IndexedSeq) ((TraversableLike) ((IterableLike) ((TraversableLike) indices.map(i -> {
                return ((DecisionTreeRegressionModel[]) broadcast.value())[i].rootNode().predictImpl(labeledPoint2.features()).prediction() * dArr[i];
            }, IndexedSeq$.MODULE$.canBuildFrom())).scanLeft(BoxesRunTime.boxToDouble(0.0d), (d, d2) -> {
                return d + d2;
            }, IndexedSeq$.MODULE$.canBuildFrom())).drop(1)).map(d3 -> {
                return loss.computeError(d3, labeledPoint2.label());
            }, IndexedSeq$.MODULE$.canBuildFrom());
        }, ClassTag$.MODULE$.apply(IndexedSeq.class)).aggregate(indices.map(i -> {
            return 0.0d;
        }, IndexedSeq$.MODULE$.canBuildFrom()), (indexedSeq2, indexedSeq3) -> {
            return (IndexedSeq) indices.map(i2 -> {
                return BoxesRunTime.unboxToDouble(indexedSeq2.apply(i2)) + BoxesRunTime.unboxToDouble(indexedSeq3.apply(i2));
            }, IndexedSeq$.MODULE$.canBuildFrom());
        }, (indexedSeq4, indexedSeq5) -> {
            return (IndexedSeq) indices.map(i2 -> {
                return BoxesRunTime.unboxToDouble(indexedSeq4.apply(i2)) + BoxesRunTime.unboxToDouble(indexedSeq5.apply(i2));
            }, IndexedSeq$.MODULE$.canBuildFrom());
        }, ClassTag$.MODULE$.apply(IndexedSeq.class))).map(d -> {
            return d / count;
        }, IndexedSeq$.MODULE$.canBuildFrom());
        // Release the broadcast tree array (argument false; presumably non-blocking, so confirm).
        broadcast.destroy(false);
        return (double[]) indexedSeq.toArray(ClassTag$.MODULE$.Double());
    }

    /**
     * Core boosting loop shared by {@link #run} and {@link #runWithValidation}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Validate the strategy, then copy the tree strategy and force it to regression with
     *       variance impurity.</li>
     *   <li>Persist {@code rdd} with MEMORY_AND_DISK if it is not already persisted.</li>
     *   <li>Fit tree 0 to the raw labels with weight 1.0 and seed {@code j}.</li>
     *   <li>For each further iteration {@code m}, fit a tree to the pseudo-residuals
     *       {@code -loss.gradient(prediction, label)}, with seed {@code j + m} and weight
     *       {@code learningRate}. Then update the running (prediction, error) RDDs.</li>
     * </ol>
     * The successive prediction/error RDDs are passed to two {@code PeriodicRDDCheckpointer}s (one
     * for training, one for validation), both built with the tree strategy's checkpoint interval.
     * Both checkpointers are cleaned up at the end.
     *
     * <p>When {@code z} is true, the mean validation error is recomputed after every iteration.
     * Training stops early once {@code best - current < validationTol * max(current, 0.01)}; for a
     * non-negative validationTol, this also covers an increase in error.
     *
     * @param rdd training data ({@code run}/{@code runWithValidation} remap classification labels
     *     before calling)
     * @param rdd2 validation data; its prediction/error RDD is only checkpointed and reduced when
     *     {@code z} is true
     * @param boostingStrategy supplies numIterations, loss, learningRate, treeStrategy and validationTol
     * @param z whether to use {@code rdd2} for early stopping
     * @param j base random seed
     * @param str forwarded unchanged to {@code DecisionTreeRegressor.train} (presumably the
     *     feature-subset strategy; confirm)
     * @return (trees, weights). With validation, both arrays are truncated to their first
     *     {@code i} entries, where {@code i} (at least 1) is the iteration count with the lowest
     *     validation error seen. Otherwise both arrays have length {@code numIterations}.
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> boost(RDD<LabeledPoint> rdd, RDD<LabeledPoint> rdd2, BoostingStrategy boostingStrategy, boolean z, long j, String str) {
        boolean z2;
        TimeTracker timeTracker = new TimeTracker();
        timeTracker.start("total");
        timeTracker.start("init");
        boostingStrategy.assertValid();
        int numIterations = boostingStrategy.numIterations();
        // Ensemble storage: trees and their weights, one slot per iteration.
        DecisionTreeRegressionModel[] decisionTreeRegressionModelArr = new DecisionTreeRegressionModel[numIterations];
        double[] dArr = new double[numIterations];
        Loss loss = boostingStrategy.loss();
        double learningRate = boostingStrategy.learningRate();
        // Work on a copy so the caller's tree strategy is not mutated by the overrides below.
        Strategy copy = boostingStrategy.treeStrategy().copy();
        double validationTol = boostingStrategy.validationTol();
        // Every base learner is a regression tree with variance impurity, whatever the outer algo is.
        copy.algo_$eq(Algo$.MODULE$.Regression());
        copy.impurity_$eq(Variance$.MODULE$);
        copy.assertValid();
        // Persist the training input only if nobody persisted it yet; z3 remembers whether this
        // method must unpersist it at the end.
        StorageLevel storageLevel = rdd.getStorageLevel();
        StorageLevel NONE = StorageLevel$.MODULE$.NONE();
        if (storageLevel != null ? !storageLevel.equals(NONE) : NONE != null) {
            z2 = false;
        } else {
            rdd.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
            z2 = true;
        }
        boolean z3 = z2;
        // Checkpointers for the training and validation (prediction, error) RDD chains.
        PeriodicRDDCheckpointer periodicRDDCheckpointer = new PeriodicRDDCheckpointer(copy.getCheckpointInterval(), rdd.sparkContext());
        PeriodicRDDCheckpointer periodicRDDCheckpointer2 = new PeriodicRDDCheckpointer(copy.getCheckpointInterval(), rdd.sparkContext());
        timeTracker.stop("init");
        logDebug(() -> {
            return "##########";
        });
        logDebug(() -> {
            return "Building tree 0";
        });
        logDebug(() -> {
            return "##########";
        });
        timeTracker.start("building tree 0");
        // Tree 0: fit to the raw labels with seed j and weight 1.0.
        DecisionTreeRegressionModel train = new DecisionTreeRegressor().setSeed(j).train(rdd, copy, str);
        decisionTreeRegressionModelArr[0] = train;
        dArr[0] = 1.0d;
        // ObjectRef/IntRef: Scala's boxing of local vars that the lambdas below capture and the
        // loop reassigns. `create` holds the current training (prediction, error) RDD.
        ObjectRef create = ObjectRef.create(computeInitialPredictionAndError(rdd, 1.0d, train, loss));
        periodicRDDCheckpointer.update((RDD) create.elem);
        // The mean error is only computed if debug logging evaluates this closure.
        logDebug(() -> {
            return new StringBuilder(15).append("error of gbt = ").append(RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions((RDD) create.elem, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean()).toString();
        });
        timeTracker.stop("building tree 0");
        // Validation (prediction, error) RDD; only checkpointed and reduced when validating.
        RDD<Tuple2<Object, Object>> computeInitialPredictionAndError = computeInitialPredictionAndError(rdd2, 1.0d, train, loss);
        if (z) {
            periodicRDDCheckpointer2.update(computeInitialPredictionAndError);
        }
        // Best (lowest) mean validation error seen so far; 0.0 placeholder when not validating.
        double mean = z ? RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(computeInitialPredictionAndError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean() : 0.0d;
        // i: number of trees in the best ensemble so far (the value returned when validating).
        int i = 1;
        // create2.elem: current iteration index m.
        IntRef create2 = IntRef.create(1);
        // z4: early-stop flag, only ever set when validating.
        boolean z4 = false;
        while (create2.elem < numIterations && !z4) {
            // Pseudo-residuals: zip current (prediction, error) with the points and relabel each
            // point with -loss.gradient(prediction, label).
            // NOTE(review): the inner `tuple2` local shadows the lambda parameter (decompiler artifact).
            RDD<LabeledPoint> map = ((RDD) create.elem).zip(rdd, ClassTag$.MODULE$.apply(LabeledPoint.class)).map(tuple2 -> {
                if (tuple2 != null) {
                    Tuple2 tuple2 = (Tuple2) tuple2._1();
                    LabeledPoint labeledPoint = (LabeledPoint) tuple2._2();
                    if (tuple2 != null) {
                        return new LabeledPoint(-loss.gradient(tuple2._1$mcD$sp(), labeledPoint.label()), labeledPoint.features());
                    }
                }
                throw new MatchError(tuple2);
            }, ClassTag$.MODULE$.apply(LabeledPoint.class));
            timeTracker.start(new StringBuilder(14).append("building tree ").append(create2.elem).toString());
            logDebug(() -> {
                return "###################################################";
            });
            logDebug(() -> {
                return new StringBuilder(33).append("Gradient boosting tree iteration ").append(create2.elem).toString();
            });
            logDebug(() -> {
                return "###################################################";
            });
            // Tree m: fit to the pseudo-residuals with seed j + m.
            DecisionTreeRegressionModel train2 = new DecisionTreeRegressor().setSeed(j + create2.elem).train(map, copy, str);
            timeTracker.stop(new StringBuilder(14).append("building tree ").append(create2.elem).toString());
            decisionTreeRegressionModelArr[create2.elem] = train2;
            dArr[create2.elem] = learningRate;
            // Fold the new weighted tree into the running training predictions/errors.
            create.elem = updatePredictionError(rdd, (RDD) create.elem, dArr[create2.elem], decisionTreeRegressionModelArr[create2.elem], loss);
            periodicRDDCheckpointer.update((RDD) create.elem);
            logDebug(() -> {
                return new StringBuilder(15).append("error of gbt = ").append(RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions((RDD) create.elem, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean()).toString();
            });
            if (z) {
                // Update validation predictions/errors and read the current mean validation error.
                computeInitialPredictionAndError = updatePredictionError(rdd2, computeInitialPredictionAndError, dArr[create2.elem], decisionTreeRegressionModelArr[create2.elem], loss);
                periodicRDDCheckpointer2.update(computeInitialPredictionAndError);
                double mean2 = RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(computeInitialPredictionAndError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean();
                // Stop when the improvement over the best error falls below a tolerance relative
                // to the current error (floored at 0.01). Otherwise record a new best.
                if (mean - mean2 < validationTol * Math.max(mean2, 0.01d)) {
                    z4 = true;
                } else if (mean2 < mean) {
                    mean = mean2;
                    i = create2.elem + 1;
                }
            }
            create2.elem++;
        }
        timeTracker.stop("total");
        logInfo(() -> {
            return "Internal timing for DecisionTree:";
        });
        logInfo(() -> {
            return String.valueOf(timeTracker);
        });
        // Cleanup: drop cached/checkpointed intermediate RDDs, and unpersist the input if this
        // method persisted it.
        periodicRDDCheckpointer.unpersistDataSet();
        periodicRDDCheckpointer.deleteAllCheckpoints();
        periodicRDDCheckpointer2.unpersistDataSet();
        periodicRDDCheckpointer2.deleteAllCheckpoints();
        if (z3) {
            rdd.unpersist(rdd.unpersist$default$1());
        } else {
            // Decompiler artifact of a Unit-valued Scala if/else; has no effect.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        // With validation, keep only the best-performing prefix of the ensemble.
        return z ? new Tuple2<>(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(decisionTreeRegressionModelArr)).slice(0, i), new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dArr)).slice(0, i)) : new Tuple2<>(decisionTreeRegressionModelArr, dArr);
    }

    /**
     * Compiler-generated fold step for {@link #computeError}: adds one (tree, weight) pair's
     * contribution to the running prediction {@code d} for {@code labeledPoint}.
     *
     * @param labeledPoint point being scored
     * @param d running prediction
     * @param tuple2 (DecisionTreeRegressionModel, weight) pair
     * @return {@code d + tree(features) * weight}
     * @throws MatchError if the pair is null (decompiled pattern-match failure)
     */
    public static final /* synthetic */ double $anonfun$computeError$2(LabeledPoint labeledPoint, double d, Tuple2 tuple2) {
        Tuple2 tuple22 = new Tuple2(BoxesRunTime.boxToDouble(d), tuple2);
        if (tuple22 != null) {
            double _1$mcD$sp = tuple22._1$mcD$sp();
            Tuple2 tuple23 = (Tuple2) tuple22._2();
            if (tuple23 != null) {
                return MODULE$.updatePrediction(labeledPoint.features(), _1$mcD$sp, (DecisionTreeRegressionModel) tuple23._1(), tuple23._2$mcD$sp());
            }
        }
        throw new MatchError(tuple22);
    }

    /**
     * Compiler-generated per-point body of {@link #computeError}.
     *
     * <p>Zips the trees with their weights and folds the pairs from 0.0 with
     * {@code $anonfun$computeError$2} to get the ensemble prediction. It then returns
     * {@code loss.computeError(prediction, label)}. Scala's {@code zip} stops at the shorter of
     * the two arrays.
     */
    public static final /* synthetic */ double $anonfun$computeError$1(DecisionTreeRegressionModel[] decisionTreeRegressionModelArr, double[] dArr, Loss loss, LabeledPoint labeledPoint) {
        return loss.computeError(BoxesRunTime.unboxToDouble(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(decisionTreeRegressionModelArr)).zip(Predef$.MODULE$.wrapDoubleArray(dArr), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).foldLeft(BoxesRunTime.boxToDouble(0.0d), (obj, tuple2) -> {
            return BoxesRunTime.boxToDouble($anonfun$computeError$2(labeledPoint, BoxesRunTime.unboxToDouble(obj), tuple2));
        })), labeledPoint.label());
    }

    /** Registers this instance as {@code MODULE$} and runs the Logging trait's initializer. */
    private GradientBoostedTrees$() {
        MODULE$ = this;
        Logging.$init$(this);
    }
}
