package org.wso2.carbon.ml.core.spark.algorithms;

import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.json.JSONArray;
import org.json.JSONException;
import org.wso2.carbon.ml.commons.domain.Feature;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.summary.ClassClassificationAndRegressionModelSummary;
import org.wso2.carbon.ml.core.spark.summary.PredictedVsActual;
import org.wso2.carbon.ml.core.spark.summary.ProbabilisticClassificationModelSummary;
import org.wso2.carbon.ml.core.spark.summary.TestResultDataPoint;
import org.wso2.carbon.ml.core.utils.MLConstants;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/SparkModelUtils.class */
/**
 * Static helpers that turn Spark MLlib prediction results into WSO2 ML model summaries, and that derive
 * categorical encodings and numeric summary statistics from a model configuration context.
 */
public class SparkModelUtils {
    private static final Log log = LogFactory.getLog(SparkModelUtils.class);

    /** Pattern used to round predicted/actual values and ROC points to two decimal places. */
    private static final String TWO_DECIMALS_PATTERN = "#.00";

    private SparkModelUtils() {
    }

    /**
     * Builds the summary of a probabilistic (binary) classification model.
     *
     * @param javaSparkContext Spark context; not needed by this implementation, kept for API compatibility
     * @param javaRDD          test data points; assumed to be in the same order as {@code javaRDD2}
     * @param javaRDD2         (prediction, label) pairs; cached for the duration of this call and unpersisted
     *                         before returning (even on failure)
     * @return summary holding a sample of predicted-vs-actual points with their feature values, the area under
     *         the ROC curve, and the ROC curve as a JSON array of two-element string arrays
     */
    public static ProbabilisticClassificationModelSummary generateProbabilisticClassificationModelSummary(JavaSparkContext javaSparkContext, JavaRDD<LabeledPoint> javaRDD, JavaRDD<Tuple2<Object, Object>> javaRDD2) {
        int sampleSize = getSampleSize();
        DecimalFormat decimalFormat = newTwoDecimalFormat();
        ProbabilisticClassificationModelSummary summary = new ProbabilisticClassificationModelSummary();

        // The sample and every metric below read javaRDD2. BinaryClassificationMetrics evaluates lazily
        // (only when areaUnderROC()/roc() are called), so the RDD must stay cached until the ROC curve
        // has been collected; unpersisting earlier made the cache useless.
        javaRDD2.cache();
        try {
            List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>();
            for (Tuple2<Object, Object> predictionAndLabel : javaRDD2.take(sampleSize)) {
                PredictedVsActual predictedVsActual = new PredictedVsActual();
                predictedVsActual.setPredicted(roundToTwoDecimals(decimalFormat, predictionAndLabel._1()));
                predictedVsActual.setActual(roundToTwoDecimals(decimalFormat, predictionAndLabel._2()));
                predictedVsActuals.add(predictedVsActual);
                if (log.isTraceEnabled()) {
                    log.trace("Predicted: " + predictedVsActual.getPredicted() + " ------ Actual: " + predictedVsActual.getActual());
                }
            }
            summary.setTestResultDataPointsSample(buildTestResultDataPoints(predictedVsActuals, javaRDD.take(sampleSize)));

            BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(JavaRDD.toRDD(javaRDD2));
            summary.setAuc(metrics.areaUnderROC());
            JSONArray roc = new JSONArray();
            // roc() is collected exactly once, so it is not cached.
            for (Tuple2<Object, Object> rocPoint : metrics.roc().toJavaRDD().collect()) {
                JSONArray rocEntry = new JSONArray();
                rocEntry.put(decimalFormat.format(rocPoint._1()));
                rocEntry.put(decimalFormat.format(rocPoint._2()));
                roc.put(rocEntry);
            }
            summary.setRoc(roc.toString());
        } finally {
            javaRDD2.unpersist();
        }
        return summary;
    }

    /**
     * Builds the summary of a regression model.
     *
     * @param javaSparkContext Spark context; not needed by this implementation, kept for API compatibility
     * @param javaRDD          test data points; assumed to be in the same order as {@code javaRDD2}
     * @param javaRDD2         (prediction, label) pairs
     * @return summary holding a sample of predicted-vs-actual points (rounded to two decimals) with their
     *         feature values, and the mean squared error over all of {@code javaRDD2} as the error
     */
    public static ClassClassificationAndRegressionModelSummary generateRegressionModelSummary(JavaSparkContext javaSparkContext, JavaRDD<LabeledPoint> javaRDD, JavaRDD<Tuple2<Double, Double>> javaRDD2) {
        int sampleSize = getSampleSize();
        DecimalFormat decimalFormat = newTwoDecimalFormat();
        List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>();
        for (Tuple2<Double, Double> predictionAndLabel : javaRDD2.take(sampleSize)) {
            PredictedVsActual predictedVsActual = new PredictedVsActual();
            predictedVsActual.setPredicted(roundToTwoDecimals(decimalFormat, predictionAndLabel._1()));
            predictedVsActual.setActual(roundToTwoDecimals(decimalFormat, predictionAndLabel._2()));
            predictedVsActuals.add(predictedVsActual);
        }
        ClassClassificationAndRegressionModelSummary summary = new ClassClassificationAndRegressionModelSummary();
        summary.setTestResultDataPointsSample(buildTestResultDataPoints(predictedVsActuals, javaRDD.take(sampleSize)));
        summary.setError(meanSquaredError(javaRDD2));
        return summary;
    }

    /**
     * Builds the summary of a (multi-)class classification model.
     *
     * @param javaSparkContext Spark context; not needed by this implementation, kept for API compatibility
     * @param javaRDD          test data points; assumed to be in the same order as {@code javaPairRDD}
     * @param javaPairRDD      (prediction, label) pairs
     * @return summary holding a sample of predicted-vs-actual points (unrounded) with their feature values, and
     *         the fraction of pairs whose prediction differs from the label as the error (NaN when empty)
     */
    public static ClassClassificationAndRegressionModelSummary getClassClassificationModelSummary(JavaSparkContext javaSparkContext, JavaRDD<LabeledPoint> javaRDD, JavaPairRDD<Double, Double> javaPairRDD) {
        int sampleSize = getSampleSize();
        List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>();
        for (Tuple2<Double, Double> predictionAndLabel : javaPairRDD.take(sampleSize)) {
            PredictedVsActual predictedVsActual = new PredictedVsActual();
            predictedVsActual.setPredicted(predictionAndLabel._1().doubleValue());
            predictedVsActual.setActual(predictionAndLabel._2().doubleValue());
            predictedVsActuals.add(predictedVsActual);
        }
        ClassClassificationAndRegressionModelSummary summary = new ClassClassificationAndRegressionModelSummary();
        summary.setTestResultDataPointsSample(buildTestResultDataPoints(predictedVsActuals, javaRDD.take(sampleSize)));
        summary.setError(misclassificationRate(javaPairRDD));
        return summary;
    }

    /**
     * Builds one categorical encoding map per model input plus one for the response variable.
     *
     * The returned list has {@code newToOldIndicesList.size() + 1} entries: entry {@code i} belongs to the
     * feature whose original index is {@code newToOldIndicesList.get(i)}, and the last entry belongs to the
     * response. For a CATEGORICAL feature the map assigns each distinct value (sorted in natural String order)
     * its position in that order. Every other slot holds an empty map.
     *
     * @param mLModelConfigurationContext context supplying features, their summary stats and index mappings
     * @return encodings in new-index order followed by the response encoding
     */
    public static List<Map<String, Integer>> buildEncodings(MLModelConfigurationContext mLModelConfigurationContext) {
        List<Feature> features = mLModelConfigurationContext.getFacts().getFeatures();
        Map<String, String> summaryStatsOfFeatures = mLModelConfigurationContext.getSummaryStatsOfFeatures();
        List<Integer> newToOldIndicesList = mLModelConfigurationContext.getNewToOldIndicesList();
        int responseIndex = mLModelConfigurationContext.getResponseIndex();

        int slotCount = newToOldIndicesList.size() + 1;
        List<Map<String, Integer>> encodings = new ArrayList<Map<String, Integer>>(slotCount);
        for (int i = 0; i < slotCount; i++) {
            encodings.add(new HashMap<String, Integer>());
        }
        for (Feature feature : features) {
            // Constant-first comparison: a feature without a type is simply treated as non-categorical.
            if (!"CATEGORICAL".equals(feature.getType())) {
                continue;
            }
            List<String> uniqueValues = getUniqueValues(summaryStatsOfFeatures.get(feature.getName()));
            Collections.sort(uniqueValues);
            Map<String, Integer> encoding = new HashMap<String, Integer>();
            for (int code = 0; code < uniqueValues.size(); code++) {
                encoding.put(uniqueValues.get(code), Integer.valueOf(code));
            }
            int newIndex = newToOldIndicesList.indexOf(Integer.valueOf(feature.getIndex()));
            if (newIndex != -1) {
                encodings.set(newIndex, encoding);
            } else if (feature.getIndex() == responseIndex) {
                encodings.set(encodings.size() - 1, encoding);
            }
        }
        return encodings;
    }

    /**
     * Extracts the first element of every entry of the {@code MLConstants.BAM_DATA_VALUES} array in the first
     * object of a summary-stats JSON array.
     *
     * @param str summary-stats JSON, may be {@code null}
     * @return the extracted values; empty when {@code str} is null. On a JSON error, whatever was collected
     *         before the failure is returned (and a warning is logged)
     */
    private static List<String> getUniqueValues(String str) {
        List<String> uniqueValues = new ArrayList<String>();
        if (str == null) {
            return uniqueValues;
        }
        try {
            JSONArray values = new JSONArray(str).getJSONObject(0).getJSONArray(MLConstants.BAM_DATA_VALUES);
            if (values == null) {
                return uniqueValues;
            }
            for (int i = 0; i < values.length(); i++) {
                JSONArray entry = values.getJSONArray(i);
                if (entry != null) {
                    uniqueValues.add(entry.getString(0));
                }
            }
            return uniqueValues;
        } catch (JSONException e) {
            log.warn("Failed to extract unique values from summary stats: " + str, e);
            return uniqueValues;
        }
    }

    /**
     * Reads the "mean" statistic from a summary-stats JSON string.
     *
     * @param str summary-stats JSON, may be {@code null}
     * @return the parsed mean, or 0.0 when absent or unparsable
     */
    public static double getMean(String str) {
        return getStatistic(str, "mean");
    }

    /**
     * Reads the "min" statistic from a summary-stats JSON string.
     *
     * @param str summary-stats JSON, may be {@code null}
     * @return the parsed minimum, or 0.0 when absent or unparsable
     */
    public static double getMin(String str) {
        return getStatistic(str, "min");
    }

    /**
     * Reads the "max" statistic from a summary-stats JSON string.
     *
     * @param str summary-stats JSON, may be {@code null}
     * @return the parsed maximum, or 0.0 when absent or unparsable
     */
    public static double getMax(String str) {
        return getStatistic(str, "max");
    }

    /** Reads the configured sample size for summary statistics. */
    private static int getSampleSize() {
        return MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize();
    }

    /**
     * Creates a two-decimal formatter with locale-independent symbols. With the default locale, a comma
     * decimal separator (e.g. de_DE) would make {@link Double#parseDouble} fail on the formatted text.
     */
    private static DecimalFormat newTwoDecimalFormat() {
        return new DecimalFormat(TWO_DECIMALS_PATTERN, DecimalFormatSymbols.getInstance(Locale.ROOT));
    }

    /**
     * Rounds a numeric value to two decimals using {@code format}. NaN and infinite values are returned as-is,
     * because DecimalFormat renders them with symbols that {@link Double#parseDouble} cannot read back.
     */
    private static double roundToTwoDecimals(DecimalFormat format, Object value) {
        double number = ((Number) value).doubleValue();
        if (Double.isNaN(number) || Double.isInfinite(number)) {
            return number;
        }
        return Double.parseDouble(format.format(number));
    }

    /**
     * Pairs sampled predictions with the feature values of the test points at the same position. Pairing stops
     * at the shorter list. Null points, or points without features, are skipped without shifting later pairs.
     */
    private static List<TestResultDataPoint> buildTestResultDataPoints(List<PredictedVsActual> predictedVsActuals, List<LabeledPoint> testPoints) {
        int pairCount = Math.min(predictedVsActuals.size(), testPoints.size());
        List<TestResultDataPoint> dataPoints = new ArrayList<TestResultDataPoint>(pairCount);
        for (int i = 0; i < pairCount; i++) {
            LabeledPoint testPoint = testPoints.get(i);
            if (testPoint == null || testPoint.features() == null) {
                continue;
            }
            TestResultDataPoint dataPoint = new TestResultDataPoint();
            dataPoint.setPredictedVsActual(predictedVsActuals.get(i));
            dataPoint.setFeatureValues(testPoint.features().toArray());
            dataPoints.add(dataPoint);
        }
        // Both inputs were obtained with take(sampleSize), so this list never exceeds the sample size
        // and needs no further sampling.
        return dataPoints;
    }

    /** Mean of (prediction - label)^2 over all pairs; NaN for an empty RDD. */
    private static double meanSquaredError(JavaRDD<Tuple2<Double, Double>> predictionsAndLabels) {
        JavaRDD<Object> squaredErrors = predictionsAndLabels.map(new Function<Tuple2<Double, Double>, Object>() {
            private static final long serialVersionUID = -162193633199074816L;

            @Override
            public Object call(Tuple2<Double, Double> predictionAndLabel) {
                double difference = predictionAndLabel._1().doubleValue() - predictionAndLabel._2().doubleValue();
                return Double.valueOf(Math.pow(difference, 2.0d));
            }
        });
        return new JavaDoubleRDD(squaredErrors.rdd()).mean().doubleValue();
    }

    /** Fraction of pairs whose prediction is not {@code equals} to the label; NaN for an empty RDD. */
    private static double misclassificationRate(JavaPairRDD<Double, Double> predictionsAndLabels) {
        long misclassified = predictionsAndLabels.filter(new Function<Tuple2<Double, Double>, Boolean>() {
            private static final long serialVersionUID = -3063364114286182333L;

            @Override
            public Boolean call(Tuple2<Double, Double> predictionAndLabel) {
                return Boolean.valueOf(!predictionAndLabel._1().equals(predictionAndLabel._2()));
            }
        }).count();
        return (double) misclassified / predictionsAndLabels.count();
    }

    /**
     * Reads a numeric statistic stored as a string under {@code key} in the first object of a summary-stats
     * JSON array.
     *
     * @return the parsed value, or 0.0 when the input is null, the JSON cannot be read, or the value is not numeric
     */
    private static double getStatistic(String summaryStats, String key) {
        if (summaryStats == null) {
            return 0.0d;
        }
        String value;
        try {
            value = new JSONArray(summaryStats).getJSONObject(0).getString(key);
        } catch (JSONException e) {
            log.warn("Failed to extract " + key + " values from summary stats: " + summaryStats, e);
            return 0.0d;
        }
        if (value == null) {
            return 0.0d;
        }
        try {
            return Double.parseDouble(value);
        } catch (NumberFormatException e) {
            log.warn("Summary stat '" + key + "' is not numeric: " + value);
            return 0.0d;
        }
    }
}
