package org.wso2.carbon.ml.core.spark.algorithms;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.RidgeRegressionModel;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.summary.ClassClassificationAndRegressionModelSummary;
import org.wso2.carbon.ml.core.spark.summary.FeatureImportance;
import org.wso2.carbon.ml.core.spark.summary.ProbabilisticClassificationModelSummary;
import org.wso2.carbon.ml.core.spark.transformations.DoubleArrayToLabeledPoint;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.core.utils.MLUtils;
import org.wso2.carbon.ml.database.DatabaseService;

/**
 * Builds supervised Spark MLlib models (classification and regression) from a
 * {@link MLModelConfigurationContext}. The training data is split into a training sample and a
 * held-out test set; the model is evaluated on the held-out set and the resulting
 * {@link ModelSummary} is persisted through the {@link DatabaseService}.
 */
public class SupervisedModel {
    private static final Log logger = LogFactory.getLog(SupervisedModel.class);

    // Hyper-parameter keys read from Workflow#getHyperParameters().
    private static final String LEARNING_RATE = "Learning_Rate";
    private static final String ITERATIONS = "Iterations";
    private static final String REG_TYPE = "Reg_Type";
    private static final String REG_PARAMETER = "Reg_Parameter";
    private static final String SGD_DATA_FRACTION = "SGD_Data_Fraction";
    private static final String NUM_CLASSES = "Num_Classes";
    private static final String IMPURITY = "Impurity";
    private static final String MAX_DEPTH = "Max_Depth";
    private static final String MAX_BINS = "Max_Bins";
    private static final String LAMBDA = "Lambda";

    /**
     * Builds, evaluates and records a supervised model described by the given context.
     *
     * @param context model configuration: Spark context, workflow ("facts"), input lines, header
     *                row, column separator and model id
     * @return the populated {@link MLModel} (algorithm name/class, features, response variable and
     *         the trained Spark model)
     * @throws MLModelBuilderException if anything fails while pre-processing, training,
     *                                 evaluating or persisting the summary
     */
    public MLModel buildModel(MLModelConfigurationContext context) throws MLModelBuilderException {
        JavaSparkContext sparkContext = null;
        DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
        MLModel mlModel = new MLModel();
        try {
            sparkContext = context.getSparkContext();
            Workflow workflow = context.getFacts();
            String headerRow = context.getHeaderRow();
            String columnSeparator = context.getColumnSeparator();
            long modelId = context.getModelId();

            JavaRDD<double[]> rows = SparkModelUtils.preProcess(sparkContext, workflow, context.getLines(),
                    headerRow, columnSeparator);
            // Index of the response variable within the header row.
            int responseIndex = MLUtils.getFeatureIndex(workflow.getResponseVariable(), headerRow, columnSeparator);
            SortedMap<Integer, String> includedFeatures = MLUtils.getIncludedFeatures(workflow, responseIndex);
            JavaRDD<LabeledPoint> labeledPoints = rows.map(new DoubleArrayToLabeledPoint(includedFeatures,
                    responseIndex));
            // Deterministic split: sample the training fraction with a fixed seed, test on the remainder.
            JavaRDD<LabeledPoint> trainingData = labeledPoints.sample(false, workflow.getTrainDataFraction(),
                    MLConstants.RANDOM_SEED.longValue());
            JavaRDD<LabeledPoint> testData = labeledPoints.subtract(trainingData);

            mlModel.setAlgorithmName(workflow.getAlgorithmName());
            mlModel.setAlgorithmClass(workflow.getAlgorithmClass());
            mlModel.setFeatures(workflow.getIncludedFeatures());
            mlModel.setResponseVariable(workflow.getResponseVariable());

            ModelSummary summary;
            switch (MLConstants.SUPERVISED_ALGORITHM.valueOf(workflow.getAlgorithmName())) {
                case LOGISTIC_REGRESSION:
                    summary = buildLogisticRegressionModel(trainingData, testData, workflow, mlModel, includedFeatures);
                    break;
                case DECISION_TREE:
                    summary = buildDecisionTreeModel(trainingData, testData, workflow, mlModel, includedFeatures);
                    break;
                case SVM:
                    summary = buildSVMModel(trainingData, testData, workflow, mlModel, includedFeatures);
                    break;
                case NAIVE_BAYES:
                    summary = buildNaiveBayesModel(trainingData, testData, workflow, mlModel, includedFeatures);
                    break;
                case LINEAR_REGRESSION:
                    summary = buildLinearRegressionModel(trainingData, testData, workflow, mlModel, includedFeatures);
                    break;
                case RIDGE_REGRESSION:
                    summary = buildRidgeRegressionModel(trainingData, testData, workflow, mlModel, includedFeatures);
                    break;
                case LASSO_REGRESSION:
                    summary = buildLassoRegressionModel(trainingData, testData, workflow, mlModel, includedFeatures);
                    break;
                default:
                    throw new AlgorithmNameException("Incorrect algorithm name");
            }
            databaseService.updateModelSummary(modelId, summary);
            return mlModel;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building supervised machine learning model: "
                    + e.getMessage(), e);
        } finally {
            // Stop the Spark context on failure as well as on success, so a failed build does not leak it.
            if (sparkContext != null) {
                sparkContext.stop();
            }
        }
    }

    /**
     * Trains a logistic regression model with SGD and evaluates it on the test split.
     * The threshold is cleared so the model emits raw scores for the probabilistic summary.
     */
    private ModelSummary buildLogisticRegressionModel(JavaRDD<LabeledPoint> trainingData,
            JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel,
            SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        try {
            LogisticRegression logisticRegression = new LogisticRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            LogisticRegressionModel model = logisticRegression.trainWithSGD(trainingData,
                    doubleParam(hyperParameters, LEARNING_RATE),
                    intParam(hyperParameters, ITERATIONS),
                    stringParam(hyperParameters, REG_TYPE),
                    doubleParam(hyperParameters, REG_PARAMETER),
                    doubleParam(hyperParameters, SGD_DATA_FRACTION));
            model.clearThreshold();
            ProbabilisticClassificationModelSummary summary = SparkModelUtils
                    .generateProbabilisticClassificationModelSummary(logisticRegression.test(model, testData));
            mlModel.setModel(model);
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, model.weights().toArray());
            summary.setFeatures(featureNames(includedFeatures));
            summary.setFeatureImportance(featureWeights);
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.LOGISTIC_REGRESSION.toString());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building logistic regression model: "
                    + e.getMessage(), e);
        }
    }

    /** Trains a decision tree classifier and evaluates it on the test split. */
    private ModelSummary buildDecisionTreeModel(JavaRDD<LabeledPoint> trainingData,
            JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel,
            SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        try {
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            DecisionTree decisionTree = new DecisionTree();
            // Empty map passed in the third position; presumably categorical-feature info
            // (i.e. all features treated as continuous) — verify against DecisionTree#train.
            DecisionTreeModel model = decisionTree.train(trainingData,
                    intParam(hyperParameters, NUM_CLASSES),
                    new HashMap<Integer, Integer>(),
                    stringParam(hyperParameters, IMPURITY),
                    intParam(hyperParameters, MAX_DEPTH),
                    intParam(hyperParameters, MAX_BINS));
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils
                    .getClassClassificationModelSummary(decisionTree.test(model, testData));
            mlModel.setModel(model);
            summary.setFeatures(featureNames(includedFeatures));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.DECISION_TREE.toString());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building decision tree model: "
                    + e.getMessage(), e);
        }
    }

    /**
     * Trains a linear SVM and evaluates it on the test split.
     * The threshold is cleared so the model emits raw scores for the probabilistic summary.
     */
    private ModelSummary buildSVMModel(JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData,
            Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures)
            throws MLModelBuilderException {
        try {
            SVM svm = new SVM();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            SVMModel model = svm.train(trainingData,
                    intParam(hyperParameters, ITERATIONS),
                    stringParam(hyperParameters, REG_TYPE),
                    doubleParam(hyperParameters, REG_PARAMETER),
                    doubleParam(hyperParameters, LEARNING_RATE),
                    doubleParam(hyperParameters, SGD_DATA_FRACTION));
            model.clearThreshold();
            ProbabilisticClassificationModelSummary summary = SparkModelUtils
                    .generateProbabilisticClassificationModelSummary(svm.test(model, testData));
            mlModel.setModel(model);
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, model.weights().toArray());
            summary.setFeatures(featureNames(includedFeatures));
            summary.setFeatureImportance(featureWeights);
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.SVM.toString());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building SVM model: " + e.getMessage(), e);
        }
    }

    /** Trains an ordinary least-squares linear regression model and evaluates it on the test split. */
    private ModelSummary buildLinearRegressionModel(JavaRDD<LabeledPoint> trainingData,
            JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel,
            SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        try {
            LinearRegression linearRegression = new LinearRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            LinearRegressionModel model = linearRegression.train(trainingData,
                    intParam(hyperParameters, ITERATIONS),
                    doubleParam(hyperParameters, LEARNING_RATE),
                    doubleParam(hyperParameters, SGD_DATA_FRACTION));
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils
                    .generateRegressionModelSummary(linearRegression.test(model, testData));
            mlModel.setModel(model);
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, model.weights().toArray());
            summary.setFeatures(featureNames(includedFeatures));
            summary.setFeatureImportance(featureWeights);
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.LINEAR_REGRESSION.toString());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building linear regression model: "
                    + e.getMessage(), e);
        }
    }

    /** Trains a ridge (L2-regularized) regression model and evaluates it on the test split. */
    private ModelSummary buildRidgeRegressionModel(JavaRDD<LabeledPoint> trainingData,
            JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel,
            SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        try {
            RidgeRegression ridgeRegression = new RidgeRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            RidgeRegressionModel model = ridgeRegression.train(trainingData,
                    intParam(hyperParameters, ITERATIONS),
                    doubleParam(hyperParameters, LEARNING_RATE),
                    doubleParam(hyperParameters, REG_PARAMETER),
                    doubleParam(hyperParameters, SGD_DATA_FRACTION));
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils
                    .generateRegressionModelSummary(ridgeRegression.test(model, testData));
            mlModel.setModel(model);
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, model.weights().toArray());
            summary.setFeatures(featureNames(includedFeatures));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.RIDGE_REGRESSION.toString());
            summary.setFeatureImportance(featureWeights);
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building ridge regression model: "
                    + e.getMessage(), e);
        }
    }

    /** Trains a lasso (L1-regularized) regression model and evaluates it on the test split. */
    private ModelSummary buildLassoRegressionModel(JavaRDD<LabeledPoint> trainingData,
            JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel,
            SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        try {
            LassoRegression lassoRegression = new LassoRegression();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            LassoModel model = lassoRegression.train(trainingData,
                    intParam(hyperParameters, ITERATIONS),
                    doubleParam(hyperParameters, LEARNING_RATE),
                    doubleParam(hyperParameters, REG_PARAMETER),
                    doubleParam(hyperParameters, SGD_DATA_FRACTION));
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils
                    .generateRegressionModelSummary(lassoRegression.test(model, testData));
            mlModel.setModel(model);
            List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, model.weights().toArray());
            summary.setFeatures(featureNames(includedFeatures));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.LASSO_REGRESSION.toString());
            summary.setFeatureImportance(featureWeights);
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building lasso regression model: "
                    + e.getMessage(), e);
        }
    }

    /**
     * Trains a naive Bayes classifier and evaluates it on the held-out test split. It is evaluated
     * on the test split, like every other builder, so its metrics are not inflated by scoring on
     * training data.
     */
    private ModelSummary buildNaiveBayesModel(JavaRDD<LabeledPoint> trainingData,
            JavaRDD<LabeledPoint> testData, Workflow workflow, MLModel mlModel,
            SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        try {
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
            NaiveBayesModel model = naiveBayesClassifier.train(trainingData, doubleParam(hyperParameters, LAMBDA));
            ClassClassificationAndRegressionModelSummary summary = SparkModelUtils
                    .getClassClassificationModelSummary(naiveBayesClassifier.test(model, testData));
            mlModel.setModel(model);
            summary.setFeatures(featureNames(includedFeatures));
            summary.setAlgorithm(MLConstants.SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building naive bayes model: "
                    + e.getMessage(), e);
        }
    }

    /**
     * Pairs each included feature name (in key order of the sorted map) with the model weight at the
     * same position.
     *
     * @param includedFeatures feature index to feature name, iterated in ascending key order
     * @param weights          model weights; assumed to have at least one entry per included feature
     * @return one {@link FeatureImportance} per included feature
     */
    private List<FeatureImportance> getFeatureWeights(SortedMap<Integer, String> includedFeatures, double[] weights) {
        List<FeatureImportance> importances = new ArrayList<FeatureImportance>(includedFeatures.size());
        int position = 0;
        for (String featureName : includedFeatures.values()) {
            FeatureImportance importance = new FeatureImportance();
            importance.setLabel(featureName);
            importance.setValue(weights[position++]);
            importances.add(importance);
        }
        return importances;
    }

    /** Returns the included feature names in ascending key order. */
    private static String[] featureNames(SortedMap<Integer, String> includedFeatures) {
        return includedFeatures.values().toArray(new String[0]);
    }

    /**
     * Returns the hyper-parameter stored under {@code key}, or throws an exception naming the missing
     * key. Without this check, a missing key would surface as an uninformative
     * NullPointerException or NumberFormatException.
     */
    private static String stringParam(Map<?, ?> hyperParameters, String key) {
        Object value = hyperParameters.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing hyper-parameter: " + key);
        }
        return (String) value;
    }

    /** Parses the hyper-parameter stored under {@code key} as an {@code int}. */
    private static int intParam(Map<?, ?> hyperParameters, String key) {
        return Integer.parseInt(stringParam(hyperParameters, key));
    }

    /** Parses the hyper-parameter stored under {@code key} as a {@code double}. */
    private static double doubleParam(Map<?, ?> hyperParameters, String key) {
        return Double.parseDouble(stringParam(hyperParameters, key));
    }
}
