package org.wso2.carbon.ml.core.spark.algorithms;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import org.apache.spark.api.java.AbstractJavaRDDLike;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.ClusterPoint;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.interfaces.MLModelBuilder;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.MulticlassConfusionMatrix;
import org.wso2.carbon.ml.core.spark.MulticlassMetrics;
import org.wso2.carbon.ml.core.spark.models.MLAnomalyDetectionModel;
import org.wso2.carbon.ml.core.spark.models.ext.AnomalyDetectionModel;
import org.wso2.carbon.ml.core.spark.summary.AnomalyDetectionModelSummary;
import org.wso2.carbon.ml.core.spark.transformations.AnomalyRowsFilter;
import org.wso2.carbon.ml.core.spark.transformations.DiscardedRowsFilter;
import org.wso2.carbon.ml.core.spark.transformations.DoubleArrayToVector;
import org.wso2.carbon.ml.core.spark.transformations.HeaderFilter;
import org.wso2.carbon.ml.core.spark.transformations.LineToTokens;
import org.wso2.carbon.ml.core.spark.transformations.MeanImputation;
import org.wso2.carbon.ml.core.spark.transformations.NormalRowsFilter;
import org.wso2.carbon.ml.core.spark.transformations.Normalization;
import org.wso2.carbon.ml.core.spark.transformations.RemoveDiscardedFeatures;
import org.wso2.carbon.ml.core.spark.transformations.RemoveResponseColumn;
import org.wso2.carbon.ml.core.spark.transformations.StringArrayToDoubleArray;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.core.utils.MLUtils;
import org.wso2.carbon.ml.database.DatabaseService;
import org.wso2.carbon.ml.database.exceptions.DatabaseHandlerException;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/AnomalyDetectionModelBuilder.class */
/**
 * Builds k-means based anomaly detection models from the dataset described by the
 * {@link MLModelConfigurationContext} handed to the constructor.
 *
 * <p>Two algorithms are handled, selected by the workflow's algorithm name:
 * <ul>
 *   <li>{@code K_MEANS_ANOMALY_DETECTION_WITH_UNLABELED_DATA} - trains on every pre-processed row.</li>
 *   <li>{@code K_MEANS_ANOMALY_DETECTION_WITH_LABELED_DATA} - trains on a sample of the rows selected
 *       for the NORMAL data type and evaluates the model, for each percentile in a configurable
 *       range, against held-out normal rows and a sample of anomalous rows.</li>
 * </ul>
 * The resulting {@link ModelSummary} is stored through {@link DatabaseService#updateModelSummary}.
 */
public class AnomalyDetectionModelBuilder extends MLModelBuilder {

    /** Upper percentile bound used when the {@code MLConstants.MAX_PERCENTILE_CONF} system property is unset. */
    private static final int DEFAULT_MAX_PERCENTILE = 100;

    /** Lower percentile bound used when the {@code MLConstants.MIN_PERCENTILE_CONF} system property is unset. */
    private static final int DEFAULT_MIN_PERCENTILE = 80;

    /**
     * Creates a builder bound to the given model configuration context.
     *
     * @param mLModelConfigurationContext configuration context passed on to {@link MLModelBuilder}
     */
    public AnomalyDetectionModelBuilder(MLModelConfigurationContext mLModelConfigurationContext) {
        super(mLModelConfigurationContext);
    }

    /**
     * Turns the raw dataset lines of the context into feature vectors.
     *
     * <p>Pipeline: header filter, tokenising, discarded-rows filter, optional row filter chosen by
     * {@code dataType}, removal of discarded features, removal of the response column (labeled
     * algorithm only), mean imputation, conversion to {@code double[]}, optional normalization and
     * finally conversion to {@link Vector}.
     *
     * @param algorithm anomaly detection algorithm the data is prepared for
     * @param dataType  which rows to keep; {@code null} applies no row filter
     * @return the pre-processed feature vectors
     * @throws MLModelBuilderException declared for callers; this method itself only throws
     *                                 {@link AlgorithmNameException} for an unsupported data type
     */
    // Raw JavaRDD is kept deliberately: the generic signatures of the transformation classes are
    // declared elsewhere in the project and are not visible from this file.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private JavaRDD<Vector> preProcess(MLConstants.ANOMALY_DETECTION_ALGORITHM algorithm,
            MLConstants.ANOMALY_DETECTION_DATA_TYPE dataType) throws MLModelBuilderException {
        MLModelConfigurationContext context = getContext();
        Workflow workflow = context.getFacts();

        HeaderFilter headerFilter = new HeaderFilter.Builder().init(context).build();
        LineToTokens lineToTokens = new LineToTokens.Builder().init(context).build();
        DiscardedRowsFilter discardedRowsFilter = new DiscardedRowsFilter.Builder().init(context).build();
        RemoveDiscardedFeatures removeDiscardedFeatures = new RemoveDiscardedFeatures.Builder().init(context).build();
        MeanImputation meanImputation = new MeanImputation.Builder().init(context).build();
        StringArrayToDoubleArray stringArrayToDoubleArray = new StringArrayToDoubleArray.Builder().build();
        DoubleArrayToVector doubleArrayToVector = new DoubleArrayToVector.Builder().build();
        RemoveResponseColumn removeResponseColumn = new RemoveResponseColumn();

        JavaRDD rows = context.getLines().cache().filter(headerFilter).map(lineToTokens).filter(discardedRowsFilter);

        if (dataType != null) {
            switch (dataType) {
                case NORMAL:
                    // NOTE(review): presumably AnomalyRowsFilter drops the anomalous rows - confirm.
                    rows = rows.filter(new AnomalyRowsFilter.Builder().init(context).build());
                    break;
                case ANOMALOUS:
                    // NOTE(review): presumably NormalRowsFilter drops the normal rows - confirm.
                    rows = rows.filter(new NormalRowsFilter.Builder().init(context).build());
                    break;
                default:
                    // Report the offending data type (the previous message printed the algorithm name).
                    throw new AlgorithmNameException("Incorrect data type: " + dataType);
            }
        }

        JavaRDD tokens = rows.map(removeDiscardedFeatures);
        if (algorithm == MLConstants.ANOMALY_DETECTION_ALGORITHM.K_MEANS_ANOMALY_DETECTION_WITH_LABELED_DATA) {
            tokens = tokens.map(removeResponseColumn);
        }

        JavaRDD features = tokens.map(meanImputation).map(stringArrayToDoubleArray);
        if (workflow.getNormalization()) {
            // RDD transformations return a new RDD; the result has to be kept, otherwise the
            // normalization step is silently dropped from the pipeline.
            features = features.map(new Normalization.Builder().init(context).build());
        }
        return features.map(doubleArrayToVector);
    }

    /**
     * Builds the anomaly detection model selected by the workflow's algorithm name and stores its
     * summary through the database service.
     *
     * @return the populated {@link MLModel}
     * @throws MLModelBuilderException if the summary cannot be stored or model building fails;
     *                                 {@link AlgorithmNameException} is thrown for an algorithm that
     *                                 is not handled here
     */
    @Override // org.wso2.carbon.ml.core.interfaces.MLModelBuilder
    public MLModel build() throws MLModelBuilderException {
        ModelSummary summary;
        MLModelConfigurationContext context = getContext();
        DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
        try {
            Workflow workflow = context.getFacts();
            long modelId = context.getModelId();

            MLModel mlModel = new MLModel();
            mlModel.setNormalization(workflow.getNormalization());
            mlModel.setNormalLabels(workflow.getNormalLabels());
            mlModel.setAlgorithmName(workflow.getAlgorithmName());
            mlModel.setAlgorithmClass(workflow.getAlgorithmClass());
            mlModel.setFeatures(workflow.getIncludedFeatures());
            mlModel.setResponseVariable(workflow.getResponseVariable());
            mlModel.setEncodings(context.getEncodings());
            mlModel.setNewToOldIndicesList(context.getNewToOldIndicesList());
            mlModel.setSummaryStatsOfFeatures(context.getSummaryStatsOfFeatures());

            SortedMap<Integer, String> includedFeatures = MLUtils.getIncludedFeaturesAfterReordering(
                    workflow, context.getNewToOldIndicesList(), context.getResponseIndex());

            MLConstants.ANOMALY_DETECTION_ALGORITHM algorithm =
                    MLConstants.ANOMALY_DETECTION_ALGORITHM.valueOf(workflow.getAlgorithmName());
            switch (algorithm) {
                case K_MEANS_ANOMALY_DETECTION_WITH_UNLABELED_DATA:
                    // No response column is used for unlabeled data.
                    mlModel.setResponseIndex(-1);
                    JavaRDD<Vector> unlabeledData = preProcess(algorithm, null).cache();
                    summary = buildUnlabeledDataAnomalyDetectionModel(modelId, unlabeledData, workflow, mlModel,
                            includedFeatures);
                    break;
                case K_MEANS_ANOMALY_DETECTION_WITH_LABELED_DATA:
                    mlModel.setResponseIndex(context.getResponseIndex());

                    // Normal rows are split into a training sample and the remaining test rows.
                    JavaRDD<Vector> normalData = preProcess(algorithm, MLConstants.ANOMALY_DETECTION_DATA_TYPE.NORMAL)
                            .cache();
                    JavaRDD<Vector> normalTrainData = normalData
                            .sample(false, workflow.getTrainDataFraction(), MLConstants.RANDOM_SEED.longValue())
                            .cache();
                    JavaRDD<Vector> normalTestData = normalData.subtract(normalTrainData).cache();
                    // NOTE(review): no Spark action has run on normalData at this point, so the cache()
                    // above likely never takes effect - confirm whether this ordering is intended.
                    normalData.unpersist();

                    // Anomalous rows are only used for evaluation.
                    JavaRDD<Vector> anomalyData = preProcess(algorithm,
                            MLConstants.ANOMALY_DETECTION_DATA_TYPE.ANOMALOUS).cache();
                    JavaRDD<Vector> anomalyTestData = anomalyData
                            .sample(false, 1.0d - workflow.getTrainDataFraction(), MLConstants.RANDOM_SEED.longValue())
                            .cache();
                    anomalyData.unpersist();

                    summary = buildLabeledDataAnomalyDetectionModel(modelId, normalTrainData, normalTestData,
                            anomalyTestData, workflow, mlModel, includedFeatures);
                    break;
                default:
                    throw new AlgorithmNameException("Incorrect algorithm name: " + workflow.getAlgorithmName() + " for model id: " + modelId);
            }

            databaseService.updateModelSummary(modelId, summary);
            return mlModel;
        } catch (DatabaseHandlerException e) {
            throw new MLModelBuilderException("An error occurred while building anomaly detection machine learning model: " + e.getMessage(), e);
        }
    }

    /**
     * Trains a k-means anomaly detection model on unlabeled data and fills in the model summary.
     *
     * @param j         model id (not used by this method)
     * @param javaRDD   training vectors; unpersisted once training has finished
     * @param workflow  workflow supplying hyper parameters, labels and dataset version
     * @param mLModel   model object that receives the trained model
     * @param sortedMap included features; its values become the summary's feature names
     * @return summary describing the trained model
     * @throws MLModelBuilderException wrapping any failure raised while training
     */
    private ModelSummary buildUnlabeledDataAnomalyDetectionModel(long j, JavaRDD<Vector> javaRDD, Workflow workflow, MLModel mLModel, SortedMap<Integer, String> sortedMap) throws MLModelBuilderException {
        try {
            Map<String, String> hyperParameters = workflow.getHyperParameters();
            int normalClusters = Integer.parseInt(hyperParameters.get(MLConstants.NUM_OF_NORMAL_CLUSTERS));
            int maxIterations = Integer.parseInt(hyperParameters.get(MLConstants.MAX_ITERATIONS));

            AnomalyDetectionModel trainedModel = new AnomalyDetection().train(javaRDD, normalClusters, maxIterations,
                    workflow.getNewNormalLabel(), workflow.getNewAnomalyLabel());
            javaRDD.unpersist();

            mLModel.setModel(new MLAnomalyDetectionModel(trainedModel));

            AnomalyDetectionModelSummary summary = new AnomalyDetectionModelSummary();
            summary.setAlgorithm(MLConstants.ANOMALY_DETECTION_ALGORITHM.K_MEANS_ANOMALY_DETECTION_WITH_UNLABELED_DATA.toString());
            summary.setDatasetVersion(workflow.getDatasetVersion());
            summary.setFeatures(sortedMap.values().toArray(new String[0]));
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building k-means anomaly detection with unlabeled data model: " + e.getMessage(), e);
        }
    }

    /**
     * Trains a k-means anomaly detection model on labeled data, evaluates it for every percentile in
     * the configured range and fills in the model summary, including the percentile with the
     * highest F1 score and a sample of clustered training points.
     *
     * @param j         model id (not used by this method)
     * @param javaRDD   normal training vectors; unpersisted once no longer needed
     * @param javaRDD2  normal test vectors; unpersisted after evaluation
     * @param javaRDD3  anomalous test vectors; unpersisted after evaluation
     * @param workflow  workflow supplying hyper parameters, labels and dataset version
     * @param mLModel   model object that receives the trained model
     * @param sortedMap included features; its size fixes the number of values copied per cluster
     *                  point and its values become the summary's feature names
     * @return summary describing the trained model and its evaluation
     * @throws MLModelBuilderException wrapping any failure raised while training or evaluating
     */
    private ModelSummary buildLabeledDataAnomalyDetectionModel(long j, JavaRDD<Vector> javaRDD, JavaRDD<Vector> javaRDD2, JavaRDD<Vector> javaRDD3, Workflow workflow, MLModel mLModel, SortedMap<Integer, String> sortedMap) throws MLModelBuilderException {
        try {
            Map<String, String> hyperParameters = workflow.getHyperParameters();
            String newNormalLabel = workflow.getNewNormalLabel();
            String newAnomalyLabel = workflow.getNewAnomalyLabel();
            int normalClusters = Integer.parseInt(hyperParameters.get(MLConstants.NUM_OF_NORMAL_CLUSTERS));
            int maxIterations = Integer.parseInt(hyperParameters.get(MLConstants.MAX_ITERATIONS));

            AnomalyDetectionModel trainedModel = new AnomalyDetection().train(javaRDD, normalClusters, maxIterations,
                    newNormalLabel, newAnomalyLabel);
            mLModel.setModel(new MLAnomalyDetectionModel(trainedModel));

            // The cluster-point sample reads the training data again, so it is taken before the
            // training data is unpersisted; otherwise Spark has to recompute the RDD.
            List<ClusterPoint> clusterPoints = collectClusterPoints(trainedModel, javaRDD, sortedMap.size());
            javaRDD.unpersist();

            int maxPercentile = readPercentile(MLConstants.MAX_PERCENTILE_CONF, DEFAULT_MAX_PERCENTILE);
            int minPercentile = readPercentile(MLConstants.MIN_PERCENTILE_CONF, DEFAULT_MIN_PERCENTILE);
            Map<Integer, MulticlassMetrics> evaluationResults = getEvaluationResults(trainedModel, javaRDD2, javaRDD3,
                    minPercentile, maxPercentile, newNormalLabel, newAnomalyLabel);
            javaRDD2.unpersist();
            javaRDD3.unpersist();

            // Pick the percentile with the strictly highest F1 score; ties keep the lowest percentile.
            double bestF1Score = 0.0d;
            int bestPercentile = minPercentile;
            for (int percentile = minPercentile; percentile <= maxPercentile; percentile++) {
                double f1Score = evaluationResults.get(Integer.valueOf(percentile)).getF1Score();
                if (f1Score > bestF1Score) {
                    bestF1Score = f1Score;
                    bestPercentile = percentile;
                }
            }

            AnomalyDetectionModelSummary summary = new AnomalyDetectionModelSummary();
            summary.setAlgorithm(MLConstants.ANOMALY_DETECTION_ALGORITHM.K_MEANS_ANOMALY_DETECTION_WITH_LABELED_DATA.toString());
            summary.setPercentileToMulticlassMetricsMap(evaluationResults);
            summary.setClusterPoints(clusterPoints);
            summary.setMinPercentile(minPercentile);
            summary.setMaxPercentile(maxPercentile);
            summary.setBestPercentile(bestPercentile);
            summary.setDatasetVersion(workflow.getDatasetVersion());
            summary.setFeatures(sortedMap.values().toArray(new String[0]));
            return summary;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while building k-means anomaly detection with labeled data model: " + e.getMessage(), e);
        }
    }

    /**
     * Reads an integer percentile from a system property.
     *
     * @param propertyName name of the system property
     * @param defaultValue value returned when the property is not set
     * @return the parsed property value, or {@code defaultValue} when the property is absent
     * @throws NumberFormatException if the property is set but is not a valid integer
     */
    private static int readPercentile(String propertyName, int defaultValue) {
        String configured = System.getProperty(propertyName);
        return configured == null ? defaultValue : Integer.parseInt(configured);
    }

    /**
     * Samples the training data, assigns each sampled vector to its k-means cluster and returns the
     * results as cluster points carrying the first {@code featureCount} values of each vector.
     *
     * @param model        trained model whose k-means model performs the cluster prediction
     * @param trainingData vectors to sample from
     * @param featureCount number of leading vector values copied into each cluster point
     * @return one cluster point per sampled vector
     */
    private static List<ClusterPoint> collectClusterPoints(AnomalyDetectionModel model, JavaRDD<Vector> trainingData,
            int featureCount) {
        // Widen to double before dividing: if getSampleSize() returns an integral type, dividing it
        // directly by the (long) row count would truncate the fraction.
        double sampleSize = MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize();
        double sampleFraction = sampleSize / (trainingData.count() - 1);
        JavaRDD<Vector> sample = sampleFraction >= 1.0d ? trainingData : trainingData.sample(false, sampleFraction);

        List<Tuple2<Integer, Vector>> predictions = model.getkMeansModel().predict(sample).zip(sample).collect();
        List<ClusterPoint> clusterPoints = new ArrayList<ClusterPoint>(predictions.size());
        for (Tuple2<Integer, Vector> prediction : predictions) {
            ClusterPoint clusterPoint = new ClusterPoint();
            clusterPoint.setCluster(prediction._1().intValue());
            // Convert the vector once per point instead of once per feature.
            double[] values = prediction._2().toArray();
            double[] features = new double[featureCount];
            System.arraycopy(values, 0, features, 0, featureCount);
            clusterPoint.setFeatures(features);
            clusterPoints.add(clusterPoint);
        }
        return clusterPoints;
    }

    /**
     * Evaluates the model for every percentile in {@code [i, i2]} and returns one set of metrics per
     * percentile.
     *
     * <p>For each percentile a 2x2 confusion matrix is built with the anomaly label at index 0 and
     * the normal label at index 1: row 0 counts the predictions for {@code javaRDD2} (equal to the
     * anomaly label, then everything else), row 1 counts the predictions for {@code javaRDD}
     * (everything else, then equal to the normal label).
     *
     * @param anomalyDetectionModel model used for prediction
     * @param javaRDD               normal test vectors
     * @param javaRDD2              anomalous test vectors
     * @param i                     lowest percentile to evaluate (inclusive)
     * @param i2                    highest percentile to evaluate (inclusive)
     * @param str                   label the model emits for normal points
     * @param str2                  label the model emits for anomalous points
     * @return metrics keyed by percentile
     */
    public Map<Integer, MulticlassMetrics> getEvaluationResults(AnomalyDetectionModel anomalyDetectionModel, JavaRDD<Vector> javaRDD, JavaRDD<Vector> javaRDD2, int i, int i2, String str, String str2) {
        Map<Integer, MulticlassMetrics> metricsByPercentile = new HashMap<Integer, MulticlassMetrics>();
        Map<Integer, List<String>> normalPredictions = anomalyDetectionModel.predict(javaRDD, i, i2);
        Map<Integer, List<String>> anomalyPredictions = anomalyDetectionModel.predict(javaRDD2, i, i2);

        for (int percentile = i; percentile <= i2; percentile++) {
            Integer key = Integer.valueOf(percentile);

            List<String> predictedForNormal = normalPredictions.get(key);
            double trueNegatives = countMatches(predictedForNormal, str);
            double falsePositives = predictedForNormal.size() - trueNegatives;

            List<String> predictedForAnomalies = anomalyPredictions.get(key);
            double truePositives = countMatches(predictedForAnomalies, str2);
            double falseNegatives = predictedForAnomalies.size() - truePositives;

            double[][] matrix = new double[][] {
                {truePositives, falseNegatives},
                {falsePositives, trueNegatives}
            };

            List<String> labels = new ArrayList<String>(2);
            labels.add(str2);
            labels.add(str);

            MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
            confusionMatrix.setMatrix(matrix);
            confusionMatrix.setLabels(labels);
            confusionMatrix.setSize(2);
            metricsByPercentile.put(key, new MulticlassMetrics(confusionMatrix));
        }
        return metricsByPercentile;
    }

    /**
     * Counts how many predictions are equal to the given label.
     *
     * @param predictions predicted labels
     * @param label       label to count
     * @return number of predictions equal to {@code label}
     */
    private static double countMatches(List<String> predictions, String label) {
        double matches = 0.0d;
        for (String prediction : predictions) {
            if (prediction.equals(label)) {
                matches += 1.0d;
            }
        }
        return matches;
    }
}
