package org.wso2.carbon.ml.core.spark.algorithms;

import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.stat.Statistics;
import org.json.JSONArray;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.core.exceptions.DatasetPreProcessingException;
import org.wso2.carbon.ml.core.spark.summary.ClassClassificationAndRegressionModelSummary;
import org.wso2.carbon.ml.core.spark.summary.PredictedVsActual;
import org.wso2.carbon.ml.core.spark.summary.ProbabilisticClassificationModelSummary;
import org.wso2.carbon.ml.core.spark.transformations.DiscardedRowsFilter;
import org.wso2.carbon.ml.core.spark.transformations.HeaderFilter;
import org.wso2.carbon.ml.core.spark.transformations.LineToTokens;
import org.wso2.carbon.ml.core.spark.transformations.MeanImputation;
import org.wso2.carbon.ml.core.spark.transformations.MissingValuesFilter;
import org.wso2.carbon.ml.core.spark.transformations.StringArrayToDoubleArray;
import org.wso2.carbon.ml.core.spark.transformations.TokensToVectors;
import org.wso2.carbon.ml.core.utils.MLUtils;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/SparkModelUtils.class */
/**
 * Static helpers that build model summaries from Spark prediction results and
 * pre-process raw text datasets into numeric feature rows.
 *
 * <p>The class is not instantiable; all members are static.
 */
public class SparkModelUtils {
    private static final Log log = LogFactory.getLog(SparkModelUtils.class);

    /** Pattern used to round reported values to two decimal places. */
    private static final String TWO_DECIMALS_PATTERN = "#.00";

    /** Fraction of rows sampled (without replacement) when estimating column means. */
    private static final double MEAN_SAMPLE_FRACTION = 0.01d;

    private SparkModelUtils() {
    }

    /**
     * Builds a summary for a probabilistic (binary) classifier: the rounded
     * predicted/actual pairs, the area under the ROC curve and the ROC curve
     * itself serialised as a JSON array of {@code [x, y]} string pairs.
     *
     * <p>The whole RDD is collected to the driver, so it has to fit in driver memory.
     *
     * @param javaRDD pairs whose first element is used as the predicted value and
     *                whose second element is used as the actual value
     * @return the populated summary
     */
    public static ProbabilisticClassificationModelSummary generateProbabilisticClassificationModelSummary(JavaRDD<Tuple2<Object, Object>> javaRDD) {
        ProbabilisticClassificationModelSummary summary = new ProbabilisticClassificationModelSummary();
        DecimalFormat twoDecimals = newTwoDecimalFormat();

        List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>();
        for (Tuple2<Object, Object> scoreAndLabel : javaRDD.collect()) {
            PredictedVsActual predictedVsActual = new PredictedVsActual();
            predictedVsActual.setPredicted(roundToTwoDecimals(twoDecimals, scoreAndLabel._1()));
            predictedVsActual.setActual(roundToTwoDecimals(twoDecimals, scoreAndLabel._2()));
            predictedVsActuals.add(predictedVsActual);
            if (log.isTraceEnabled()) {
                log.trace("Predicted: " + predictedVsActual.getPredicted() + " ------ Actual: " + predictedVsActual.getActual());
            }
        }
        summary.setPredictedVsActuals(predictedVsActuals);

        BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(JavaRDD.toRDD(javaRDD));
        summary.setAuc(metrics.areaUnderROC());

        // Each ROC point becomes a two-element JSON array of formatted strings.
        JSONArray rocPoints = new JSONArray();
        for (Tuple2<Object, Object> point : metrics.roc().toJavaRDD().collect()) {
            JSONArray rocPoint = new JSONArray();
            rocPoint.put(twoDecimals.format(point._1()));
            rocPoint.put(twoDecimals.format(point._2()));
            rocPoints.put(rocPoint);
        }
        summary.setRoc(rocPoints.toString());
        return summary;
    }

    /**
     * Builds a summary for a regression model: the rounded predicted/actual
     * pairs and the mean squared difference between the two (computed on the
     * unrounded values).
     *
     * <p>The whole RDD is collected to the driver, so it has to fit in driver memory.
     *
     * @param javaRDD pairs whose first element is used as the predicted value and
     *                whose second element is used as the actual value
     * @return the populated summary
     */
    public static ClassClassificationAndRegressionModelSummary generateRegressionModelSummary(JavaRDD<Tuple2<Double, Double>> javaRDD) {
        ClassClassificationAndRegressionModelSummary summary = new ClassClassificationAndRegressionModelSummary();
        DecimalFormat twoDecimals = newTwoDecimalFormat();

        List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>();
        for (Tuple2<Double, Double> pair : javaRDD.collect()) {
            PredictedVsActual predictedVsActual = new PredictedVsActual();
            predictedVsActual.setPredicted(roundToTwoDecimals(twoDecimals, pair._1()));
            predictedVsActual.setActual(roundToTwoDecimals(twoDecimals, pair._2()));
            predictedVsActuals.add(predictedVsActual);
        }
        summary.setPredictedVsActuals(predictedVsActuals);

        // Mean of the squared differences, evaluated by Spark on the original values.
        summary.setError(new JavaDoubleRDD(javaRDD.map(new Function<Tuple2<Double, Double>, Object>() {
            @Override
            public Object call(Tuple2<Double, Double> pair) {
                double difference = pair._1().doubleValue() - pair._2().doubleValue();
                return Double.valueOf(Math.pow(difference, 2.0d));
            }
        }).rdd()).mean().doubleValue());
        return summary;
    }

    /**
     * Builds a summary for a class classification model: the predicted/actual
     * pairs and the fraction of pairs whose two elements are not equal.
     *
     * <p>The whole RDD is collected to the driver, so it has to fit in driver
     * memory. The error is derived from that collected list rather than from
     * additional Spark jobs. For an empty input the error is {@code NaN}.
     *
     * @param javaPairRDD pairs whose first element is used as the predicted value
     *                    and whose second element is used as the actual value
     * @return the populated summary
     */
    public static ClassClassificationAndRegressionModelSummary getClassClassificationModelSummary(JavaPairRDD<Double, Double> javaPairRDD) {
        ClassClassificationAndRegressionModelSummary summary = new ClassClassificationAndRegressionModelSummary();

        List<Tuple2<Double, Double>> pairs = javaPairRDD.collect();
        List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>(pairs.size());
        long mismatches = 0L;
        for (Tuple2<Double, Double> pair : pairs) {
            PredictedVsActual predictedVsActual = new PredictedVsActual();
            predictedVsActual.setPredicted(pair._1().doubleValue());
            predictedVsActual.setActual(pair._2().doubleValue());
            predictedVsActuals.add(predictedVsActual);
            if (!pair._1().equals(pair._2())) {
                mismatches++;
            }
        }
        summary.setPredictedVsActuals(predictedVsActuals);
        // Floating-point division on purpose: 0 / 0 yields NaN instead of throwing.
        summary.setError((1.0d * mismatches) / pairs.size());
        return summary;
    }

    /**
     * Turns raw text lines into numeric feature rows.
     *
     * <p>Lines are filtered with a {@link HeaderFilter}, split with a
     * {@link LineToTokens} built from the given pattern, and filtered again with
     * a {@link DiscardedRowsFilter} built from the workflow's {@code "DISCARD"}
     * feature indices. If the workflow has any {@code "REPLACE_WTH_MEAN"}
     * feature indices, rows are mapped through a {@link MeanImputation} using
     * column means estimated from a sample; otherwise they are mapped through
     * {@link StringArrayToDoubleArray}.
     *
     * @param javaSparkContext Spark context, forwarded to the mean computation
     * @param workflow         workflow from which the impute feature indices are read
     * @param javaRDD          raw lines of the dataset
     * @param str              value handed to {@link HeaderFilter} (presumably the header row - confirm)
     * @param str2             regular expression used to split a line into tokens
     * @return the numeric rows
     * @throws DatasetPreProcessingException declared by the original signature; propagated from the helpers
     */
    public static JavaRDD<double[]> preProcess(JavaSparkContext javaSparkContext, Workflow workflow, JavaRDD<String> javaRDD, String str, String str2) throws DatasetPreProcessingException {
        JavaRDD<String[]> tokens = javaRDD
                .filter(new HeaderFilter(str))
                .map(new LineToTokens(Pattern.compile(str2)))
                .filter(new DiscardedRowsFilter(MLUtils.getImputeFeatureIndices(workflow, "DISCARD")));

        List<Integer> meanImputeIndices = MLUtils.getImputeFeatureIndices(workflow, "REPLACE_WTH_MEAN");
        if (meanImputeIndices.isEmpty()) {
            return tokens.map(new StringArrayToDoubleArray());
        }
        Map<Integer, Double> means = getMeans(javaSparkContext, tokens, meanImputeIndices, MEAN_SAMPLE_FRACTION);
        return tokens.map(new MeanImputation(means));
    }

    /**
     * Estimates the mean of each requested column.
     *
     * <p>Rows are first passed through {@link MissingValuesFilter} so that the
     * statistics are computed on the filtered rows only (the filter is assumed
     * to keep rows without missing values - confirm against its implementation).
     * When {@code d} is below 1.0 only a sample of that fraction, drawn without
     * replacement, is used.
     *
     * @param javaSparkContext unused; kept for signature compatibility
     * @param javaRDD          tokenised rows
     * @param list             feature indices whose means are required
     * @param d                sample fraction; values of 1.0 or more disable sampling
     * @return map from each index in {@code list} to the mean of the corresponding column
     * @throws DatasetPreProcessingException declared by the original signature
     */
    private static Map<Integer, Double> getMeans(JavaSparkContext javaSparkContext, JavaRDD<String[]> javaRDD, List<Integer> list, double d) throws DatasetPreProcessingException {
        // RDD transformations return a new RDD; the result must be used, not discarded.
        JavaRDD<String[]> completeRows = javaRDD.filter(new MissingValuesFilter());
        JavaRDD<String[]> rowsForStats = d < 1.0d ? completeRows.sample(false, d) : completeRows;

        double[] columnMeans = Statistics.colStats(rowsForStats.map(new TokensToVectors(list)).rdd()).mean().toArray();

        Map<Integer, Double> meansByIndex = new HashMap<Integer, Double>();
        for (int i = 0; i < columnMeans.length; i++) {
            meansByIndex.put(list.get(i), Double.valueOf(columnMeans[i]));
        }
        return meansByIndex;
    }

    /**
     * Creates a two-decimal formatter with fixed (US) symbols so that its output
     * always uses '.' as the decimal separator and can be read back by
     * {@link Double#parseDouble(String)} regardless of the JVM default locale.
     * A new instance is created per call because {@link DecimalFormat} is not thread-safe.
     */
    private static DecimalFormat newTwoDecimalFormat() {
        return new DecimalFormat(TWO_DECIMALS_PATTERN, DecimalFormatSymbols.getInstance(Locale.US));
    }

    /**
     * Rounds a numeric value to two decimals by formatting and re-parsing it.
     * NaN and infinite values are returned unchanged because their formatted
     * form is a symbol that cannot be parsed back into a double.
     */
    private static double roundToTwoDecimals(DecimalFormat format, Object value) {
        if (value instanceof Number) {
            double raw = ((Number) value).doubleValue();
            if (Double.isNaN(raw) || Double.isInfinite(raw)) {
                return raw;
            }
        }
        return Double.parseDouble(format.format(value));
    }
}
