package org.wso2.carbon.ml.core.spark.algorithms;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.RidgeRegressionModel;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.rdd.RDD;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.Feature;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.factories.AlgorithmType;
import org.wso2.carbon.ml.core.interfaces.MLModelBuilder;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.MulticlassConfusionMatrix;
import org.wso2.carbon.ml.core.spark.models.MLClassificationModel;
import org.wso2.carbon.ml.core.spark.models.MLDecisionTreeModel;
import org.wso2.carbon.ml.core.spark.models.MLGeneralizedLinearModel;
import org.wso2.carbon.ml.core.spark.models.MLRandomForestModel;
import org.wso2.carbon.ml.core.spark.summary.ClassClassificationAndRegressionModelSummary;
import org.wso2.carbon.ml.core.spark.summary.FeatureImportance;
import org.wso2.carbon.ml.core.spark.summary.ProbabilisticClassificationModelSummary;
import org.wso2.carbon.ml.core.spark.transformations.BasicEncoder;
import org.wso2.carbon.ml.core.spark.transformations.DiscardedRowsFilter;
import org.wso2.carbon.ml.core.spark.transformations.DoubleArrayToLabeledPoint;
import org.wso2.carbon.ml.core.spark.transformations.HeaderFilter;
import org.wso2.carbon.ml.core.spark.transformations.LineToTokens;
import org.wso2.carbon.ml.core.spark.transformations.MeanImputation;
import org.wso2.carbon.ml.core.spark.transformations.RemoveDiscardedFeatures;
import org.wso2.carbon.ml.core.spark.transformations.StringArrayToDoubleArray;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.core.utils.MLUtils;
import org.wso2.carbon.ml.database.DatabaseService;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/SupervisedSparkModelBuilder.class */
public class SupervisedSparkModelBuilder extends MLModelBuilder {

    /* renamed from: org.wso2.carbon.ml.core.spark.algorithms.SupervisedSparkModelBuilder$1, reason: invalid class name */
    /* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/SupervisedSparkModelBuilder$1.class */
    /**
     * Decompiler-recovered lookup table mapping {@code MLConstants.SUPERVISED_ALGORITHM} ordinals to
     * switch case numbers 1-9, in the shape javac generates for a {@code switch} over an enum.
     * It is read by an ordinal-based switch elsewhere in this class; if nothing reads it any more it
     * can be deleted.
     */
    static /* synthetic */ class AnonymousClass1 {
        // Indexed by enum ordinal; value is the case number. Ordinals not assigned below remain 0,
        // which matches no case label and therefore reaches the switch's default branch.
        static final /* synthetic */ int[] $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM = new int[MLConstants.SUPERVISED_ALGORITHM.values().length];

        static {
            // Each assignment is guarded individually: if a constant is missing from the
            // MLConstants version loaded at runtime, the NoSuchFieldError is swallowed and only that
            // entry is skipped, instead of failing class initialisation.
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.LOGISTIC_REGRESSION.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.LOGISTIC_REGRESSION_LBFGS.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.DECISION_TREE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.RANDOM_FOREST.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.SVM.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.NAIVE_BAYES.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.LINEAR_REGRESSION.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.RIDGE_REGRESSION.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$wso2$carbon$ml$commons$constants$MLConstants$SUPERVISED_ALGORITHM[MLConstants.SUPERVISED_ALGORITHM.LASSO_REGRESSION.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
        }
    }

    /**
     * Creates a builder for supervised Spark MLlib models.
     *
     * @param context configuration context handed to the {@link MLModelBuilder} superclass
     */
    public SupervisedSparkModelBuilder(MLModelConfigurationContext context) {
        super(context);
    }

    /**
     * Converts the context's raw dataset lines into an RDD of {@link LabeledPoint}s.
     *
     * <p>The lines are passed, in order, through {@link HeaderFilter}, {@link LineToTokens},
     * {@link DiscardedRowsFilter}, {@link RemoveDiscardedFeatures}, {@link BasicEncoder},
     * {@link MeanImputation}, {@link StringArrayToDoubleArray} and {@link DoubleArrayToLabeledPoint}.
     * The returned RDD is lazy; no Spark action is run here.
     *
     * @return the labeled-point RDD built on top of the context's lines
     * @throws MLModelBuilderException as declared for callers
     */
    public JavaRDD<LabeledPoint> preProcess() throws MLModelBuilderException {
        JavaRDD lines = null;
        try {
            MLModelConfigurationContext context = getContext();

            HeaderFilter headerFilter = new HeaderFilter.Builder().init(context).build();
            LineToTokens tokenizer = new LineToTokens.Builder().init(context).build();
            DiscardedRowsFilter rowsFilter = new DiscardedRowsFilter.Builder().init(context).build();
            RemoveDiscardedFeatures featureRemover = new RemoveDiscardedFeatures.Builder().init(context).build();
            BasicEncoder encoder = new BasicEncoder.Builder().init(context).build();
            MeanImputation imputation = new MeanImputation.Builder().init(context).build();
            StringArrayToDoubleArray toDoubles = new StringArrayToDoubleArray.Builder().build();
            DoubleArrayToLabeledPoint toLabeledPoints = new DoubleArrayToLabeledPoint.Builder().build();

            // NOTE(review): the cache is released in the finally block before any action runs on the
            // returned (lazy) RDD, so this cache() is probably ineffective - confirm intent.
            lines = context.getLines().cache();
            JavaRDD<LabeledPoint> labeledPoints = lines
                    .filter(headerFilter)
                    .map(tokenizer)
                    .filter(rowsFilter)
                    .map(featureRemover)
                    .map(encoder)
                    .map(imputation)
                    .map(toDoubles)
                    .map(toLabeledPoints);
            return labeledPoints;
        } finally {
            if (lines != null) {
                lines.unpersist();
            }
        }
    }

    /**
     * Builds, evaluates and records a supervised model for the algorithm named in the workflow.
     *
     * <p>The pre-processed dataset is split into training and test sets using the workflow's train
     * fraction and {@link MLConstants#RANDOM_SEED}. The selected algorithm is trained on the training
     * split and evaluated on the test split. The resulting summary is then stored through the
     * {@link DatabaseService}.
     *
     * @return the populated {@link MLModel}
     * @throws MLModelBuilderException if validation, training, evaluation or persisting the summary
     *     fails; the original failure is attached as the cause
     */
    @Override // org.wso2.carbon.ml.core.interfaces.MLModelBuilder
    public MLModel build() throws MLModelBuilderException {
        MLModelConfigurationContext context = getContext();
        DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
        MLModel mlModel = new MLModel();
        try {
            JavaSparkContext sparkContext = context.getSparkContext();
            Workflow workflow = context.getFacts();
            long modelId = context.getModelId();

            String responseType = getTypeOfResponseVariable(workflow.getResponseVariable(), workflow.getFeatures());
            if (responseType == null) {
                throw new MLModelBuilderException("Type of response variable cannot be null for supervised learning algorithms.");
            }
            if (workflow.getAlgorithmClass().equals(AlgorithmType.NUMERICAL_PREDICTION.getValue()) && responseType.equals("CATEGORICAL")) {
                throw new MLModelBuilderException("Categorical attribute " + workflow.getResponseVariable() + " cannot be used as the response variable of the Numerical Prediction algorithm: " + workflow.getAlgorithmName());
            }

            int responseIndex = context.getResponseIndex();
            SortedMap<Integer, String> includedFeatures = MLUtils.getIncludedFeaturesAfterReordering(workflow, context.getNewToOldIndicesList(), responseIndex);

            // randomSplit is lazy. The pre-processed data therefore stays cached until BOTH splits
            // have been consumed. Releasing it right after the split meant the cache was never used:
            // the full pre-processing pipeline re-ran for each split.
            JavaRDD<LabeledPoint> preprocessed = preProcess().cache();
            JavaRDD<LabeledPoint> trainingData = null;
            JavaRDD<LabeledPoint> testData = null;
            try {
                JavaRDD<LabeledPoint>[] splits = preprocessed.randomSplit(
                        new double[]{workflow.getTrainDataFraction(), 1.0d - workflow.getTrainDataFraction()},
                        MLConstants.RANDOM_SEED.longValue());
                trainingData = splits[0].cache();
                testData = splits[1];

                mlModel.setAlgorithmName(workflow.getAlgorithmName());
                mlModel.setAlgorithmClass(workflow.getAlgorithmClass());
                mlModel.setFeatures(workflow.getIncludedFeatures());
                mlModel.setResponseVariable(workflow.getResponseVariable());
                mlModel.setEncodings(context.getEncodings());
                mlModel.setNewToOldIndicesList(context.getNewToOldIndicesList());
                mlModel.setResponseIndex(responseIndex);

                ModelSummary summary;
                switch (MLConstants.SUPERVISED_ALGORITHM.valueOf(workflow.getAlgorithmName())) {
                    case LOGISTIC_REGRESSION:
                        summary = buildLogisticRegressionModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures, true);
                        break;
                    case LOGISTIC_REGRESSION_LBFGS:
                        summary = buildLogisticRegressionModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures, false);
                        break;
                    case DECISION_TREE:
                        summary = buildDecisionTreeModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures, getCategoricalFeatureInfo(context.getEncodings()));
                        break;
                    case RANDOM_FOREST:
                        summary = buildRandomForestTreeModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures, getCategoricalFeatureInfo(context.getEncodings()));
                        break;
                    case SVM:
                        summary = buildSVMModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures);
                        break;
                    case NAIVE_BAYES:
                        summary = buildNaiveBayesModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures);
                        break;
                    case LINEAR_REGRESSION:
                        summary = buildLinearRegressionModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures);
                        break;
                    case RIDGE_REGRESSION:
                        summary = buildRidgeRegressionModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures);
                        break;
                    case LASSO_REGRESSION:
                        summary = buildLassoRegressionModel(sparkContext, modelId, trainingData, testData, workflow, mlModel, includedFeatures);
                        break;
                    default:
                        throw new AlgorithmNameException("Incorrect algorithm name");
                }
                databaseService.updateModelSummary(modelId, summary);
                return mlModel;
            } finally {
                // Release every cached RDD on success AND failure paths. Unpersisting an RDD that
                // is no longer (or never was) persisted is harmless.
                if (trainingData != null) {
                    trainingData.unpersist();
                }
                if (testData != null) {
                    testData.unpersist();
                }
                preprocessed.unpersist();
            }
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building supervised machine learning model: " + e.getMessage(), e);
        }
    }

    /**
     * Finds the declared type of the response variable among the given features.
     *
     * @param responseVariable name of the response variable
     * @param features features to search; every element is inspected
     * @return the type of the last feature whose name equals {@code responseVariable}, or
     *     {@code null} when no feature matches
     */
    private String getTypeOfResponseVariable(String responseVariable, List<Feature> features) {
        String responseType = null;
        for (Feature candidate : features) {
            String candidateName = candidate.getName();
            if (candidateName.equals(responseVariable)) {
                responseType = candidate.getType();
            }
        }
        return responseType;
    }

    /**
     * Builds the categorical-feature description passed to the tree-based trainers.
     *
     * <p>Every encoding entry except the last is examined. The last entry is treated as the response
     * variable, matching {@link #getNoOfClasses}. Features with a non-empty encoding map are
     * reported as categorical.
     *
     * @param list per-column encodings; may be {@code null}, and individual entries may be {@code null}
     * @return a mutable map of feature index to number of encoded categories; empty when
     *     {@code list} is {@code null}
     */
    private Map<Integer, Integer> getCategoricalFeatureInfo(List<Map<String, Integer>> list) {
        Map<Integer, Integer> categoricalFeatureInfo = new HashMap<>();
        if (list == null) {
            // Consistent with getNoOfClasses, which also tolerates missing encodings.
            return categoricalFeatureInfo;
        }
        for (int i = 0; i < list.size() - 1; i++) {
            Map<String, Integer> encoding = list.get(i);
            // A null or empty encoding means the feature is not categorical.
            if (encoding != null && !encoding.isEmpty()) {
                categoricalFeatureInfo.put(i, encoding.size());
            }
        }
        return categoricalFeatureInfo;
    }

    /**
     * Trains a logistic regression model, using either SGD or L-BFGS, and builds its evaluation summary.
     *
     * @param sparkContext Spark context used to build the summary
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying hyper-parameters and the dataset version
     * @param mlModel model that receives the trained classifier
     * @param includedFeatures included feature names, in model order
     * @param useSgd {@code true} for LOGISTIC_REGRESSION (SGD, binary only); {@code false} for
     *     LOGISTIC_REGRESSION_LBFGS
     * @return probabilistic classification summary
     * @throws MLModelBuilderException wrapping any validation, training or evaluation failure
     */
    private ModelSummary buildLogisticRegressionModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures, boolean useSgd) throws MLModelBuilderException {
        try {
            LogisticRegression logisticRegression = new LogisticRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            int noOfClasses = getNoOfClasses(mlModel);
            String algorithmName;
            LogisticRegressionModel model;
            if (useSgd) {
                algorithmName = MLConstants.SUPERVISED_ALGORITHM.LOGISTIC_REGRESSION.toString();
                if (noOfClasses > 2) {
                    throw new MLModelBuilderException("A binary classification algorithm cannot have more than two distinct values in response variable.");
                }
                model = logisticRegression.trainWithSGD(trainingData,
                        Double.parseDouble((String) hyperParameters.get("Learning_Rate")),
                        Integer.parseInt((String) hyperParameters.get("Iterations")),
                        (String) hyperParameters.get("Reg_Type"),
                        Double.parseDouble((String) hyperParameters.get("Reg_Parameter")),
                        Double.parseDouble((String) hyperParameters.get("SGD_Data_Fraction")));
            } else {
                algorithmName = MLConstants.SUPERVISED_ALGORITHM.LOGISTIC_REGRESSION_LBFGS.toString();
                model = logisticRegression.trainWithLBFGS(trainingData, (String) hyperParameters.get("Reg_Type"), noOfClasses);
            }
            trainingData.unpersist();
            testData.cache();

            Vector weights = model.weights();
            if (!isValidWeights(weights)) {
                throw new MLModelBuilderException("Weights of the model generated are null or infinity. [Weights] " + vectorToString(weights));
            }

            // Evaluate thresholded (class) predictions first. Accuracy is computed here, BEFORE
            // clearThreshold() mutates the model: MulticlassMetrics evaluates lazily, and a later
            // evaluation would presumably see raw scores instead of class predictions.
            MulticlassMetrics metrics = new MulticlassMetrics(JavaRDD.toRDD(logisticRegression.test(model, testData)));
            MulticlassConfusionMatrix confusionMatrix = getMulticlassConfusionMatrix(metrics, mlModel);
            double accuracy = getModelAccuracy(metrics).doubleValue();

            // With the threshold cleared, the model's predictions are scores, which feed the
            // probabilistic summary.
            model.clearThreshold();
            ProbabilisticClassificationModelSummary summary = SparkModelUtils.generateProbabilisticClassificationModelSummary(sparkContext, testData, logisticRegression.test(model, testData));
            mlModel.setModel(new MLClassificationModel(model));
            testData.unpersist();

            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, weights.toArray());
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setFeatureImportance(featureWeights);
            summary.setAlgorithm(algorithmName);
            summary.setMulticlassConfusionMatrix(confusionMatrix);
            summary.setModelAccuracy(accuracy);
            summary.setDatasetVersion(workflow.getDatasetVersion());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building logistic regression model: " + e.getMessage(), e);
        } finally {
            // Safety net: release cached splits even when training/evaluation fails.
            trainingData.unpersist();
            testData.unpersist();
        }
    }

    /**
     * Returns the number of distinct response values. This is the size of the response variable's
     * encoding map, which is the last entry of the model's encodings list.
     *
     * @param mLModel model whose encodings are inspected
     * @return the number of encoded response values, or -1 when the encodings list is {@code null}
     *     or empty, or the response encoding is {@code null}
     */
    private int getNoOfClasses(MLModel mLModel) {
        List<?> encodings = mLModel.getEncodings();
        // Guard the empty case too: size() - 1 would otherwise be -1 and get(-1) would throw.
        if (encodings == null || encodings.isEmpty()) {
            return -1;
        }
        Object responseEncoding = encodings.get(encodings.size() - 1);
        return responseEncoding == null ? -1 : ((Map<?, ?>) responseEncoding).size();
    }

    /**
     * Trains a decision tree and builds its classification summary.
     *
     * @param sparkContext Spark context used to build the summary and metrics
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying hyper-parameters and the dataset version
     * @param mlModel model that receives the trained tree
     * @param includedFeatures included feature names, in model order
     * @param categoricalFeatureInfo feature index to number of categories, for categorical features
     * @return classification summary including confusion matrix and accuracy
     * @throws MLModelBuilderException wrapping any training or evaluation failure
     */
    private ModelSummary buildDecisionTreeModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures, Map<Integer, Integer> categoricalFeatureInfo) throws MLModelBuilderException {
        JavaPairRDD<Double, Double> predictionsAndLabels = null;
        try {
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            DecisionTree decisionTree = new DecisionTree();
            DecisionTreeModel model = decisionTree.train(trainingData, getNoOfClasses(mlModel), categoricalFeatureInfo,
                    (String) hyperParameters.get("Impurity"),
                    Integer.parseInt((String) hyperParameters.get("Max_Depth")),
                    Integer.parseInt((String) hyperParameters.get("Max_Bins")));
            trainingData.unpersist();
            testData.cache();

            predictionsAndLabels = decisionTree.test(model, testData).cache();
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils.getClassClassificationModelSummary(sparkContext, testData, predictionsAndLabels);
            testData.unpersist();
            mlModel.setModel(new MLDecisionTreeModel(model));
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.DECISION_TREE.toString());

            MulticlassMetrics metrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
            predictionsAndLabels.unpersist();
            summary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(metrics, mlModel));
            summary.setModelAccuracy(getModelAccuracy(metrics).doubleValue());
            summary.setDatasetVersion(workflow.getDatasetVersion());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building decision tree model: " + e.getMessage(), e);
        } finally {
            // Safety net: release cached RDDs even when training/evaluation fails.
            // unpersist() on an already-released RDD is harmless.
            trainingData.unpersist();
            testData.unpersist();
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    }

    /**
     * Trains a random forest and builds its classification summary.
     *
     * @param sparkContext Spark context used to build the summary and metrics
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying hyper-parameters and the dataset version
     * @param mlModel model that receives the trained forest
     * @param includedFeatures included feature names, in model order
     * @param categoricalFeatureInfo feature index to number of categories, for categorical features
     * @return classification summary including confusion matrix and accuracy
     * @throws MLModelBuilderException wrapping any training or evaluation failure
     */
    private ModelSummary buildRandomForestTreeModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures, Map<Integer, Integer> categoricalFeatureInfo) throws MLModelBuilderException {
        JavaPairRDD<Double, Double> predictionsAndLabels = null;
        try {
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            RandomForest randomForest = new RandomForest();
            RandomForestModel model = randomForest.train(trainingData, getNoOfClasses(mlModel), categoricalFeatureInfo,
                    Integer.parseInt((String) hyperParameters.get("Num_Trees")),
                    (String) hyperParameters.get("Feature_Subset_Strategy"),
                    (String) hyperParameters.get("Impurity"),
                    Integer.parseInt((String) hyperParameters.get("Max_Depth")),
                    Integer.parseInt((String) hyperParameters.get("Max_Bins")),
                    Integer.parseInt((String) hyperParameters.get("Seed")));
            trainingData.unpersist();
            testData.cache();

            predictionsAndLabels = randomForest.test(model, testData).cache();
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils.getClassClassificationModelSummary(sparkContext, testData, predictionsAndLabels);
            testData.unpersist();
            mlModel.setModel(new MLRandomForestModel(model));
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.RANDOM_FOREST.toString());

            MulticlassMetrics metrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
            predictionsAndLabels.unpersist();
            summary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(metrics, mlModel));
            summary.setModelAccuracy(getModelAccuracy(metrics).doubleValue());
            summary.setDatasetVersion(workflow.getDatasetVersion());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building random forest classification model: " + e.getMessage(), e);
        } finally {
            // Safety net: release cached RDDs even when training/evaluation fails.
            // unpersist() on an already-released RDD is harmless.
            trainingData.unpersist();
            testData.unpersist();
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    }

    /**
     * Trains a binary linear SVM and builds its probabilistic classification summary.
     *
     * @param sparkContext Spark context used to build the summary
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying hyper-parameters and the dataset version
     * @param mlModel model that receives the trained classifier
     * @param includedFeatures included feature names, in model order
     * @return probabilistic classification summary
     * @throws MLModelBuilderException if the response has more than two classes (not wrapped), or
     *     wrapping any training or evaluation failure
     */
    private ModelSummary buildSVMModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        try {
            // Checked outside the wrapping catch so the message reaches the caller unprefixed.
            if (getNoOfClasses(mlModel) > 2) {
                throw new MLModelBuilderException("A binary classification algorithm cannot have more than two distinct values in response variable.");
            }
            try {
                SVM svm = new SVM();
                Map<?, ?> hyperParameters = workflow.getHyperParameters();
                SVMModel model = svm.train(trainingData,
                        Integer.parseInt((String) hyperParameters.get("Iterations")),
                        (String) hyperParameters.get("Reg_Type"),
                        Double.parseDouble((String) hyperParameters.get("Reg_Parameter")),
                        Double.parseDouble((String) hyperParameters.get("Learning_Rate")),
                        Double.parseDouble((String) hyperParameters.get("SGD_Data_Fraction")));
                trainingData.unpersist();
                testData.cache();

                Vector weights = model.weights();
                if (!isValidWeights(weights)) {
                    throw new MLModelBuilderException("Weights of the model generated are null or infinity. [Weights] " + vectorToString(weights));
                }

                // Accuracy is computed BEFORE clearThreshold() mutates the model. MulticlassMetrics
                // evaluates lazily, and a later evaluation would presumably see raw margins instead
                // of class predictions.
                MulticlassMetrics metrics = new MulticlassMetrics(JavaRDD.toRDD(svm.test(model, testData)));
                MulticlassConfusionMatrix confusionMatrix = getMulticlassConfusionMatrix(metrics, mlModel);
                double accuracy = getModelAccuracy(metrics).doubleValue();

                model.clearThreshold();
                ProbabilisticClassificationModelSummary summary = SparkModelUtils.generateProbabilisticClassificationModelSummary(sparkContext, testData, svm.test(model, testData));
                testData.unpersist();
                mlModel.setModel(new MLClassificationModel(model));

                List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, weights.toArray());
                summary.setFeatures(includedFeatures.values().toArray(new String[0]));
                summary.setFeatureImportance(featureWeights);
                summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.SVM.toString());
                summary.setMulticlassConfusionMatrix(confusionMatrix);
                summary.setModelAccuracy(accuracy);
                summary.setDatasetVersion(workflow.getDatasetVersion());
                return summary;
            } catch (Exception e) {
                throw new MLModelBuilderException("An error occurred while building SVM model: " + e.getMessage(), e);
            }
        } finally {
            // Safety net: release cached splits even when validation/training/evaluation fails.
            trainingData.unpersist();
            testData.unpersist();
        }
    }

    /**
     * Trains a linear regression model and builds its regression summary.
     *
     * @param sparkContext Spark context used to build the summary and metrics
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying hyper-parameters and the dataset version
     * @param mlModel model that receives the trained regressor
     * @param includedFeatures included feature names, in model order
     * @return regression summary including feature weights and mean squared error
     * @throws MLModelBuilderException wrapping any training or evaluation failure
     */
    private ModelSummary buildLinearRegressionModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        JavaRDD<Tuple2<Double, Double>> predictionsAndLabels = null;
        try {
            LinearRegression linearRegression = new LinearRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            LinearRegressionModel model = linearRegression.train(trainingData,
                    Integer.parseInt((String) hyperParameters.get("Iterations")),
                    Double.parseDouble((String) hyperParameters.get("Learning_Rate")),
                    Double.parseDouble((String) hyperParameters.get("SGD_Data_Fraction")));
            trainingData.unpersist();
            testData.cache();

            Vector weights = model.weights();
            if (!isValidWeights(weights)) {
                throw new MLModelBuilderException("Weights of the model generated are null or infinity. [Weights] " + vectorToString(weights));
            }

            predictionsAndLabels = linearRegression.test(model, testData).cache();
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils.generateRegressionModelSummary(sparkContext, testData, predictionsAndLabels);
            testData.unpersist();
            mlModel.setModel(new MLGeneralizedLinearModel(model));
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, weights.toArray());
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setFeatureImportance(featureWeights);
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.LINEAR_REGRESSION.toString());

            RegressionMetrics metrics = getRegressionMetrics(sparkContext, predictionsAndLabels);
            summary.setMeanSquaredError(metrics.meanSquaredError());
            summary.setDatasetVersion(workflow.getDatasetVersion());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building linear regression model: " + e.getMessage(), e);
        } finally {
            // Release cached RDDs on every path; unpersist() on an already-released RDD is harmless.
            trainingData.unpersist();
            testData.unpersist();
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    }

    /**
     * Trains a ridge regression model and builds its regression summary.
     *
     * @param sparkContext Spark context used to build the summary and metrics
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying hyper-parameters and the dataset version
     * @param mlModel model that receives the trained regressor
     * @param includedFeatures included feature names, in model order
     * @return regression summary including feature weights and mean squared error
     * @throws MLModelBuilderException wrapping any training or evaluation failure
     */
    private ModelSummary buildRidgeRegressionModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        JavaRDD<Tuple2<Double, Double>> predictionsAndLabels = null;
        try {
            RidgeRegression ridgeRegression = new RidgeRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            RidgeRegressionModel model = ridgeRegression.train(trainingData,
                    Integer.parseInt((String) hyperParameters.get("Iterations")),
                    Double.parseDouble((String) hyperParameters.get("Learning_Rate")),
                    Double.parseDouble((String) hyperParameters.get("Reg_Parameter")),
                    Double.parseDouble((String) hyperParameters.get("SGD_Data_Fraction")));
            trainingData.unpersist();
            testData.cache();

            Vector weights = model.weights();
            if (!isValidWeights(weights)) {
                throw new MLModelBuilderException("Weights of the model generated are null or infinity. [Weights] " + vectorToString(weights));
            }

            predictionsAndLabels = ridgeRegression.test(model, testData).cache();
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils.generateRegressionModelSummary(sparkContext, testData, predictionsAndLabels);
            testData.unpersist();
            mlModel.setModel(new MLGeneralizedLinearModel(model));
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, weights.toArray());
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.RIDGE_REGRESSION.toString());
            summary.setFeatureImportance(featureWeights);

            RegressionMetrics metrics = getRegressionMetrics(sparkContext, predictionsAndLabels);
            summary.setMeanSquaredError(metrics.meanSquaredError());
            summary.setDatasetVersion(workflow.getDatasetVersion());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building ridge regression model: " + e.getMessage(), e);
        } finally {
            // Release cached RDDs on every path; unpersist() on an already-released RDD is harmless.
            trainingData.unpersist();
            testData.unpersist();
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    }

    /**
     * Trains a lasso regression model and builds its regression summary.
     *
     * @param sparkContext Spark context used to build the summary and metrics
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying hyper-parameters and the dataset version
     * @param mlModel model that receives the trained regressor
     * @param includedFeatures included feature names, in model order
     * @return regression summary including feature weights and mean squared error
     * @throws MLModelBuilderException wrapping any training or evaluation failure
     */
    private ModelSummary buildLassoRegressionModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        JavaRDD<Tuple2<Double, Double>> predictionsAndLabels = null;
        try {
            LassoRegression lassoRegression = new LassoRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            LassoModel model = lassoRegression.train(trainingData,
                    Integer.parseInt((String) hyperParameters.get("Iterations")),
                    Double.parseDouble((String) hyperParameters.get("Learning_Rate")),
                    Double.parseDouble((String) hyperParameters.get("Reg_Parameter")),
                    Double.parseDouble((String) hyperParameters.get("SGD_Data_Fraction")));
            trainingData.unpersist();
            testData.cache();

            Vector weights = model.weights();
            if (!isValidWeights(weights)) {
                throw new MLModelBuilderException("Weights of the model generated are null or infinity. [Weights] " + vectorToString(weights));
            }

            predictionsAndLabels = lassoRegression.test(model, testData).cache();
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils.generateRegressionModelSummary(sparkContext, testData, predictionsAndLabels);
            testData.unpersist();
            mlModel.setModel(new MLGeneralizedLinearModel(model));
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, weights.toArray());
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.LASSO_REGRESSION.toString());
            summary.setFeatureImportance(featureWeights);

            RegressionMetrics metrics = getRegressionMetrics(sparkContext, predictionsAndLabels);
            summary.setMeanSquaredError(metrics.meanSquaredError());
            summary.setDatasetVersion(workflow.getDatasetVersion());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building lasso regression model: " + e.getMessage(), e);
        } finally {
            // Release cached RDDs on every path; unpersist() on an already-released RDD is harmless.
            trainingData.unpersist();
            testData.unpersist();
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    }

    /**
     * Trains a naive Bayes classifier and builds its classification summary.
     *
     * @param sparkContext Spark context used to build the summary and metrics
     * @param modelId id of the model being built (not used by this method)
     * @param trainingData training split; unpersisted once training finishes
     * @param testData test split; cached for evaluation and unpersisted afterwards
     * @param workflow workflow supplying the "Lambda" hyper-parameter and the dataset version
     * @param mlModel model that receives the trained classifier
     * @param includedFeatures included feature names, in model order
     * @return classification summary including confusion matrix and accuracy
     * @throws MLModelBuilderException wrapping any training or evaluation failure
     */
    private ModelSummary buildNaiveBayesModel(JavaSparkContext sparkContext, long modelId, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        JavaPairRDD<Double, Double> predictionsAndLabels = null;
        try {
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
            NaiveBayesModel model = naiveBayesClassifier.train(trainingData, Double.parseDouble((String) hyperParameters.get("Lambda")));
            trainingData.unpersist();
            testData.cache();

            predictionsAndLabels = naiveBayesClassifier.test(model, testData).cache();
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils.getClassClassificationModelSummary(sparkContext, testData, predictionsAndLabels);
            testData.unpersist();
            mlModel.setModel(new MLClassificationModel(model));
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

            MulticlassMetrics metrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
            predictionsAndLabels.unpersist();
            summary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(metrics, mlModel));
            summary.setModelAccuracy(getModelAccuracy(metrics).doubleValue());
            summary.setDatasetVersion(workflow.getDatasetVersion());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building naive bayes model: " + e.getMessage(), e);
        } finally {
            // Safety net: release cached RDDs even when training/evaluation fails.
            // unpersist() on an already-released RDD is harmless.
            trainingData.unpersist();
            testData.unpersist();
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    }

    /**
     * Pairs each feature name with its weight.
     * <p>
     * The n-th name produced by {@code sortedMap.values()} (ascending key order) is matched with
     * {@code dArr[n]}. An {@link ArrayIndexOutOfBoundsException} is thrown if {@code dArr} has
     * fewer entries than the map.
     *
     * @param sortedMap feature index to feature name
     * @param dArr      weights, positionally aligned with the map's values
     * @return one {@code FeatureImportance} per feature, in map order
     */
    private List<FeatureImportance> getFeatureWeights(SortedMap<Integer, String> sortedMap, double[] dArr) {
        List<FeatureImportance> weights = new ArrayList<FeatureImportance>();
        int position = 0;
        for (String featureName : sortedMap.values()) {
            FeatureImportance importance = new FeatureImportance();
            importance.setLabel(featureName);
            importance.setValue(dArr[position++]);
            weights.add(importance);
        }
        return weights;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds Spark {@link MulticlassMetrics} from the given pairs.
     * <p>
     * Spark interprets each pair as (prediction, label). The pairs are collected to the driver and
     * re-parallelized, so the returned metrics are backed by a driver-side copy rather than by
     * {@code javaPairRDD}'s lineage. This lets callers unpersist {@code javaPairRDD} right after this
     * call, as {@code buildNaiveBayesModel} does. Note that collecting pulls every pair into driver
     * memory.
     * <p>
     * The previous implementation called {@code cache()} and then {@code unpersist()} before any action
     * ran, which had no effect. That pair has been removed.
     *
     * @param javaSparkContext context used to re-parallelize the collected pairs
     * @param javaPairRDD      pairs passed to {@code MulticlassMetrics}
     * @return metrics over the collected pairs
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public MulticlassMetrics getMulticlassMetrics(JavaSparkContext javaSparkContext, JavaPairRDD<Double, Double> javaPairRDD) {
        List<Tuple2<Double, Double>> pairs = new ArrayList<Tuple2<Double, Double>>(javaPairRDD.collect());
        // Raw RDD: MulticlassMetrics is declared over RDD<Tuple2<Object, Object>> when seen from Java.
        RDD detachedPairs = javaSparkContext.parallelize(pairs).rdd();
        return new MulticlassMetrics(detachedPairs);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Converts Spark's confusion matrix into a {@link MulticlassConfusionMatrix}.
     * <p>
     * The flat array from {@code confusionMatrix().toArray()} is read in column-major order into a
     * row-indexed {@code double[row][col]}. Labels are decoded through the last encoding map of
     * {@code mLModel}, when one exists. A label with no entry in that map keeps its numeric string
     * form, so the label list always has one entry per matrix row/column. Without encodings, all
     * labels use their numeric string form.
     *
     * @param multiclassMetrics metrics to convert; when {@code null} an empty matrix object is returned
     * @param mLModel           model whose encodings are used to decode numeric labels
     * @return the populated confusion matrix
     */
    public MulticlassConfusionMatrix getMulticlassConfusionMatrix(MulticlassMetrics multiclassMetrics, MLModel mLModel) {
        MulticlassConfusionMatrix result = new MulticlassConfusionMatrix();
        if (multiclassMetrics == null) {
            return result;
        }
        int size = multiclassMetrics.confusionMatrix().numCols();
        double[] columnMajor = multiclassMetrics.confusionMatrix().toArray();
        double[][] matrix = new double[size][size];
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                matrix[row][col] = columnMajor[(col * size) + row];
            }
        }
        result.setMatrix(matrix);

        List encodings = mLModel.getEncodings();
        // Guard against an empty list too: get(size() - 1) would otherwise throw.
        if (encodings != null && !encodings.isEmpty()) {
            Map responseEncoding = (Map) encodings.get(encodings.size() - 1);
            double[] labels = multiclassMetrics.labels();
            List<String> decodedLabels = new ArrayList<String>(labels.length);
            for (double label : labels) {
                String decoded = (String) MLUtils.getKeyByValue(responseEncoding, Integer.valueOf((int) label));
                // Never drop a label: that would misalign the label list with the matrix rows/columns.
                decodedLabels.add(decoded != null ? decoded : String.valueOf(label));
            }
            result.setLabels(decodedLabels);
        } else {
            result.setLabels(toStringList(multiclassMetrics.labels()));
        }
        result.setSize(size);
        return result;
    }

    /**
     * Builds Spark {@link RegressionMetrics} from the given pairs.
     * <p>
     * Spark interprets each pair as (prediction, observation). The pairs are collected to the driver
     * and re-parallelized, so the metrics are backed by a driver-side copy rather than by
     * {@code javaRDD}'s lineage. This lets callers unpersist {@code javaRDD} right after this call,
     * as the lasso builder does. Collecting pulls every pair into driver memory.
     * <p>
     * The previous {@code cache()} immediately followed by {@code unpersist()}, with no action in
     * between, had no effect and has been removed.
     *
     * @param javaSparkContext context used to re-parallelize the collected pairs
     * @param javaRDD          pairs passed to {@code RegressionMetrics}
     * @return metrics over the collected pairs
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private RegressionMetrics getRegressionMetrics(JavaSparkContext javaSparkContext, JavaRDD<Tuple2<Double, Double>> javaRDD) {
        List<Tuple2<Double, Double>> pairs = new ArrayList<Tuple2<Double, Double>>(javaRDD.collect());
        // Raw RDD: RegressionMetrics is declared over RDD<Tuple2<Object, Object>> when seen from Java.
        RDD detachedPairs = javaSparkContext.parallelize(pairs).rdd();
        return new RegressionMetrics(detachedPairs);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the overall accuracy of a classifier as a percentage.
     * <p>
     * Accuracy is the sum of the confusion-matrix diagonal (correct predictions) divided by the sum of
     * all of its entries. The result is multiplied by 100 and rounded to two decimal places.
     *
     * @param multiclassMetrics metrics holding the confusion matrix of the evaluated model
     * @return accuracy percentage in {@code [0, 100]}, or {@code 0.0} when the confusion matrix sums to 0
     */
    public Double getModelAccuracy(MulticlassMetrics multiclassMetrics) {
        int numClasses = multiclassMetrics.confusionMatrix().numCols();
        // Fetch the entries once rather than rebuilding the array on every loop iteration.
        double[] entries = multiclassMetrics.confusionMatrix().toArray();
        long total = arraySum(entries);
        long correct = 0L;
        for (int k = 0; k < numClasses; k++) {
            // toArray() is column-major, so diagonal entry (k, k) sits at index k * numClasses + k.
            correct += (long) entries[(k * numClasses) + k];
        }
        // Divide in floating point: long/long division would truncate any accuracy below 100% to 0.
        double ratio = total > 0 ? (double) correct / total : 0.0d;
        // Use locale-independent symbols so Double.parseDouble can read the formatted value back;
        // the default locale may use ',' as the decimal separator.
        DecimalFormat decimalFormat =
                new DecimalFormat("#.00", java.text.DecimalFormatSymbols.getInstance(java.util.Locale.ROOT));
        return Double.valueOf(Double.parseDouble(decimalFormat.format(ratio * 100.0d)));
    }

    /**
     * Sums the entries of {@code dArr} into a {@code long}.
     * <p>
     * Each partial sum is truncated toward zero after adding the next element, so fractional parts
     * are discarded step by step. Intended for arrays of whole-number counts.
     *
     * @param dArr values to add up
     * @return the truncated running total, or 0 for an empty array
     */
    protected long arraySum(double[] dArr) {
        long total = 0L;
        int idx = 0;
        while (idx < dArr.length) {
            total = (long) (total + dArr[idx]);
            idx++;
        }
        return total;
    }

    /**
     * Checks that every component of a weight vector is a finite number.
     *
     * @param vector weights to check
     * @return {@code false} as soon as a NaN or infinite component is found, {@code true} otherwise
     */
    private boolean isValidWeights(Vector vector) {
        int length = vector.size();
        int idx = 0;
        while (idx < length) {
            double weight = vector.apply(idx);
            boolean finite = !Double.isNaN(weight) && !Double.isInfinite(weight);
            if (!finite) {
                return false;
            }
            idx++;
        }
        return true;
    }

    /**
     * Renders a vector as its components joined by commas, with no spaces and no brackets.
     *
     * @param vector vector to render
     * @return comma-separated components, or an empty string for a zero-length vector
     */
    private String vectorToString(Vector vector) {
        StringBuilder joined = new StringBuilder();
        String separator = "";
        for (int idx = 0; idx < vector.size(); idx++) {
            joined.append(separator).append(vector.apply(idx));
            separator = ",";
        }
        return joined.toString();
    }

    /**
     * Converts each value to its {@link String#valueOf(double)} form, preserving order.
     *
     * @param dArr values to convert
     * @return a new mutable list with one string per input value
     */
    private List<String> toStringList(double[] dArr) {
        List<String> strings = new ArrayList<String>(dArr.length);
        for (int idx = 0; idx < dArr.length; idx++) {
            strings.add(String.valueOf(dArr[idx]));
        }
        return strings;
    }
}
