package org.wso2.carbon.ml.core.spark.algorithms;

import java.io.Serializable;
import java.util.Objects;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/SVM.class */
public class SVM implements Serializable {
    private static final long serialVersionUID = 7331560994085961463L;

    public SVMModel train(JavaRDD<LabeledPoint> javaRDD, int i, String str, double d, double d2, double d3) {
        SVMWithSGD sVMWithSGD = new SVMWithSGD();
        if (str.equals("L1")) {
            sVMWithSGD.optimizer().setUpdater(new L1Updater()).setRegParam(d);
        } else if (str.equals("L2")) {
            sVMWithSGD.optimizer().setUpdater(new SquaredL2Updater()).setRegParam(d);
        }
        sVMWithSGD.optimizer().setNumIterations(i).setStepSize(d2).setMiniBatchFraction(d3);
        return sVMWithSGD.run(javaRDD.rdd());
    }

    public SVMModel train(JavaRDD<LabeledPoint> javaRDD, int i, double d) {
        SVMWithSGD sVMWithSGD = new SVMWithSGD();
        sVMWithSGD.optimizer().setNumIterations(i).setRegParam(d);
        return sVMWithSGD.run(javaRDD.rdd());
    }

    public JavaRDD<Tuple2<Object, Object>> test(final SVMModel sVMModel, JavaRDD<LabeledPoint> javaRDD) {
        return javaRDD.map(new Function<LabeledPoint, Tuple2<Object, Object>>() { // from class: org.wso2.carbon.ml.core.spark.algorithms.SVM.1
            private static final long serialVersionUID = 4382737078158765112L;

            public Tuple2<Object, Object> call(LabeledPoint labeledPoint) {
                return new Tuple2<>(Double.valueOf(sVMModel.predict(labeledPoint.features())), Double.valueOf(labeledPoint.label()));
            }
        });
    }
}
