package org.apache.spark.mllib.regression;

import java.io.Serializable;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.jblas.DoubleMatrix;
import org.jblas.util.Random;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.class */
/**
 * Java-API tests for {@link RidgeRegressionWithSGD}.
 *
 * <p>Each test trains on the first half of a noisy synthetic linear dataset and checks that a
 * model trained with L2 regularization ({@code regParam > 0}) has a lower mean squared error on
 * the second half than a model trained with {@code regParam == 0}.
 */
public class JavaRidgeRegressionSuite implements Serializable {
    /** Number of points in each of the training and validation halves. */
    private static final int NUM_EXAMPLES = 50;
    /** Number of features (and generated weights) per point. */
    private static final int NUM_FEATURES = 20;
    /** Noise scale forwarded to {@link LinearDataGenerator} as its epsilon argument. */
    private static final double NOISE_EPS = 10.0;
    /** Seed shared by the jblas weight generation and {@link LinearDataGenerator}. */
    private static final int SEED = 42;
    /** SGD iterations used for every model trained in this suite. */
    private static final int NUM_ITERATIONS = 200;
    /** SGD step size used for every model trained in this suite. */
    private static final double STEP_SIZE = 1.0;
    /** Regularization strength of the ridge model compared against the unregularized one. */
    private static final double REG_PARAM = 0.1;

    private transient JavaSparkContext sc;

    @Before
    public void setUp() {
        sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
    }

    @After
    public void tearDown() {
        // JUnit runs @After even when setUp() threw; without this guard an NPE here would
        // replace the original, more useful failure.
        if (sc != null) {
            sc.stop();
            sc = null;
        }
    }

    /**
     * Computes the mean squared error of {@code model} over {@code validationData}.
     *
     * @param validationData labeled points to score; an empty list yields {@code NaN}
     * @param model the trained model to evaluate
     * @return the average of {@code (prediction - label)^2} across all points
     */
    double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) {
        double errorSum = 0.0;
        for (LabeledPoint point : validationData) {
            double diff = model.predict(point.features()) - point.label();
            errorSum += diff * diff;
        }
        return errorSum / validationData.size();
    }

    /**
     * Generates a reproducible synthetic linear dataset with zero intercept.
     *
     * @param numPoints number of points to generate
     * @param numFeatures number of features, i.e. the length of the random weight vector
     * @param std noise scale forwarded to {@link LinearDataGenerator} as its epsilon argument
     * @return the generated labeled points
     */
    List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
        // Seed jblas so the random weight vector is identical on every run.
        Random.seed(SEED);
        // Shift the DoubleMatrix.rand values by -0.5 so the weights are centered on zero.
        double[] weights = DoubleMatrix.rand(numFeatures, 1).subi(0.5).data;
        return LinearDataGenerator.generateLinearInputAsList(0.0, weights, numPoints, SEED, std);
    }

    @Test
    public void runRidgeRegressionUsingConstructor() {
        List<LabeledPoint> data = generateRidgeData(2 * NUM_EXAMPLES, NUM_FEATURES, NOISE_EPS);
        JavaRDD<LabeledPoint> trainingRDD = sc.parallelize(data.subList(0, NUM_EXAMPLES));
        List<LabeledPoint> validationData = data.subList(NUM_EXAMPLES, 2 * NUM_EXAMPLES);

        RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
        ridgeSGDImpl.optimizer()
            .setStepSize(STEP_SIZE)
            .setRegParam(0.0)
            .setNumIterations(NUM_ITERATIONS);
        RidgeRegressionModel unregularized = ridgeSGDImpl.run(trainingRDD.rdd());
        double unregularizedErr = predictionError(validationData, unregularized);

        // Same optimizer settings, now with L2 regularization enabled.
        ridgeSGDImpl.optimizer().setRegParam(REG_PARAM);
        RidgeRegressionModel regularized = ridgeSGDImpl.run(trainingRDD.rdd());
        double regularizedErr = predictionError(validationData, regularized);

        Assert.assertTrue(
            "regularized MSE " + regularizedErr + " should be < unregularized MSE " + unregularizedErr,
            regularizedErr < unregularizedErr);
    }

    @Test
    public void runRidgeRegressionUsingStaticMethods() {
        List<LabeledPoint> data = generateRidgeData(2 * NUM_EXAMPLES, NUM_FEATURES, NOISE_EPS);
        JavaRDD<LabeledPoint> trainingRDD = sc.parallelize(data.subList(0, NUM_EXAMPLES));
        List<LabeledPoint> validationData = data.subList(NUM_EXAMPLES, 2 * NUM_EXAMPLES);

        RidgeRegressionModel unregularized =
            RidgeRegressionWithSGD.train(trainingRDD.rdd(), NUM_ITERATIONS, STEP_SIZE, 0.0);
        double unregularizedErr = predictionError(validationData, unregularized);

        RidgeRegressionModel regularized =
            RidgeRegressionWithSGD.train(trainingRDD.rdd(), NUM_ITERATIONS, STEP_SIZE, REG_PARAM);
        double regularizedErr = predictionError(validationData, regularized);

        Assert.assertTrue(
            "regularized MSE " + regularizedErr + " should be < unregularized MSE " + unregularizedErr,
            regularizedErr < unregularizedErr);
    }
}
