package org.wso2.carbon.ml.core.spark.algorithms;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.optimization.LBFGS;
import org.apache.spark.mllib.optimization.LogisticGradient;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/LogisticRegression.class */
/**
 * Trains Spark MLlib logistic regression models (via SGD or L-BFGS) and produces
 * (prediction, label) pairs for evaluating them.
 *
 * <p>The class holds no instance state; it is {@link Serializable} so an instance can safely
 * travel with a Spark closure should one ever capture it.
 */
public class LogisticRegression implements Serializable {
    private static final long serialVersionUID = -8590453832704525074L;

    /**
     * Trains a logistic regression model with mini-batch stochastic gradient descent.
     *
     * @param trainingData      labelled training points
     * @param stepSize          SGD step size
     * @param numIterations     number of SGD iterations
     * @param regType           {@code "L1"} installs an {@link L1Updater}, {@code "L2"} installs a
     *                          {@link SquaredL2Updater}; any other value (including {@code null})
     *                          leaves whatever updater {@link LogisticRegressionWithSGD} uses by default
     * @param regParam          regularization parameter
     * @param miniBatchFraction fraction of the data used in each SGD iteration
     * @return the trained model; it is always fitted with an intercept term
     */
    public LogisticRegressionModel trainWithSGD(JavaRDD<LabeledPoint> trainingData, double stepSize,
            int numIterations, String regType, double regParam, double miniBatchFraction) {
        LogisticRegressionWithSGD trainer =
                new LogisticRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction);
        // Constant-first equals() so a null regType simply falls through to the default updater.
        if ("L1".equals(regType)) {
            trainer.optimizer().setUpdater(new L1Updater());
        } else if ("L2".equals(regType)) {
            trainer.optimizer().setUpdater(new SquaredL2Updater());
        }
        trainer.setIntercept(true);
        return trainer.run(trainingData.rdd());
    }

    /**
     * Trains a logistic regression model with L-BFGS and L2 ({@link SquaredL2Updater}) regularization.
     *
     * <p>A bias column is appended to every feature vector and optimized together with the other
     * weights; the last optimized weight then becomes the model's intercept. NOTE(review): because
     * the bias weight is part of the vector handed to the updater, it is presumably regularized too.
     *
     * <p>The feature dimension is read from the first training point; all points are assumed to
     * share it.
     *
     * @param trainingData     labelled training points; must not be empty
     * @param numCorrections   number of corrections L-BFGS keeps between iterations
     * @param convergenceTol   convergence tolerance for L-BFGS
     * @param maxNumIterations maximum number of L-BFGS iterations
     * @param regParam         L2 regularization parameter
     * @return the trained model, with the bias weight split off as its intercept
     * @throws IllegalArgumentException if {@code trainingData} contains no points
     */
    public LogisticRegressionModel trainWithLBFGS(JavaRDD<LabeledPoint> trainingData, int numCorrections,
            double convergenceTol, int maxNumIterations, double regParam) {
        List<LabeledPoint> firstPoint = trainingData.take(1);
        if (firstPoint.isEmpty()) {
            throw new IllegalArgumentException(
                    "Cannot train a logistic regression model with L-BFGS on an empty data set");
        }
        int numFeatures = firstPoint.get(0).features().size();

        JavaRDD<Tuple2<Object, Vector>> withBias = trainingData.map(new AppendBiasFunction());
        // L-BFGS makes many passes over the data, so keep it cached while optimizing...
        withBias.cache();
        Vector weightsWithIntercept;
        try {
            // Start from all-zero weights: one per feature plus one for the appended bias column.
            weightsWithIntercept = (Vector) LBFGS.runLBFGS(withBias.rdd(), new LogisticGradient(),
                    new SquaredL2Updater(), numCorrections, convergenceTol, maxNumIterations, regParam,
                    Vectors.dense(new double[numFeatures + 1]))._1();
        } finally {
            // ...and release the cached blocks as soon as the weights are back on the driver.
            withBias.unpersist();
        }

        double[] weights = weightsWithIntercept.toArray();
        int interceptIndex = weights.length - 1;
        return new LogisticRegressionModel(
                Vectors.dense(Arrays.copyOf(weights, interceptIndex)), weights[interceptIndex]);
    }

    /**
     * Scores every point of {@code testData} with {@code model}.
     *
     * @param model    model used for prediction
     * @param testData labelled points to score
     * @return a lazily evaluated RDD of {@code (prediction, actualLabel)} pairs, both boxed {@link Double}s
     */
    public JavaRDD<Tuple2<Object, Object>> test(LogisticRegressionModel model, JavaRDD<LabeledPoint> testData) {
        return testData.map(new PredictAndLabelFunction(model));
    }

    /**
     * Converts a labelled point into the {@code (label, features + bias)} pair consumed by
     * {@link LBFGS#runLBFGS}. Static so Spark does not serialize an enclosing instance with it.
     */
    private static final class AppendBiasFunction implements Function<LabeledPoint, Tuple2<Object, Vector>> {
        private static final long serialVersionUID = 8486284563910067157L;

        @Override
        public Tuple2<Object, Vector> call(LabeledPoint point) {
            return new Tuple2<Object, Vector>(Double.valueOf(point.label()), MLUtils.appendBias(point.features()));
        }
    }

    /**
     * Maps a labelled point to a {@code (prediction, label)} pair using a fixed model. Static so only
     * the model, not an enclosing instance, is shipped to executors.
     */
    private static final class PredictAndLabelFunction implements Function<LabeledPoint, Tuple2<Object, Object>> {
        private static final long serialVersionUID = 910861043765821669L;

        private final LogisticRegressionModel model;

        PredictAndLabelFunction(LogisticRegressionModel model) {
            this.model = model;
        }

        @Override
        public Tuple2<Object, Object> call(LabeledPoint point) {
            return new Tuple2<Object, Object>(
                    Double.valueOf(model.predict(point.features())), Double.valueOf(point.label()));
        }
    }
}
