package org.wso2.carbon.ml.core.spark.algorithms;

import java.io.Serializable;
import java.util.Objects;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/LinearRegression.class */
public class LinearRegression implements Serializable {
    private static final long serialVersionUID = -5137378340857656687L;

    public LinearRegressionModel train(JavaRDD<LabeledPoint> javaRDD, int i, double d, double d2) {
        return LinearRegressionWithSGD.train(javaRDD.rdd(), i, d, d2);
    }

    public LinearRegressionModel train(JavaRDD<LabeledPoint> javaRDD, int i) {
        return LinearRegressionWithSGD.train(javaRDD.rdd(), i);
    }

    public JavaRDD<Tuple2<Double, Double>> test(final LinearRegressionModel linearRegressionModel, JavaRDD<LabeledPoint> javaRDD) {
        return javaRDD.map(new Function<LabeledPoint, Tuple2<Double, Double>>() { // from class: org.wso2.carbon.ml.core.spark.algorithms.LinearRegression.1
            private static final long serialVersionUID = 2027559237268104710L;

            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) {
                return new Tuple2<>(Double.valueOf(linearRegressionModel.predict(labeledPoint.features())), Double.valueOf(labeledPoint.label()));
            }
        });
    }
}
