package org.wso2.carbon.ml.core.spark.algorithms;

import hex.deeplearning.DeepLearningModel;
import java.util.Map;
import java.util.SortedMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.models.MLDeeplearningModel;
import org.wso2.carbon.ml.core.spark.summary.DeeplearningModelSummary;
import org.wso2.carbon.ml.core.utils.DeeplearningModelUtils;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.core.utils.MLUtils;
import org.wso2.carbon.ml.database.DatabaseService;

/**
 * Model builder for deep-learning algorithms (currently only Stacked Autoencoders).
 *
 * <p>Pre-processes the data set held by the {@link MLModelConfigurationContext}, splits it into training
 * and testing sets according to the workflow's train-data fraction, trains the model, and stores the
 * resulting model summary through the {@link DatabaseService}.
 */
public class DeeplearningModelBuilder extends SupervisedSparkModelBuilder {
    private static final Log log = LogFactory.getLog(DeeplearningModelBuilder.class);

    /**
     * Creates a builder for the given model configuration context.
     *
     * @param mLModelConfigurationContext context holding the Spark context, workflow and data
     */
    public DeeplearningModelBuilder(MLModelConfigurationContext mLModelConfigurationContext) {
        super(mLModelConfigurationContext);
    }

    /**
     * Builds the deep-learning model described by the context's workflow and persists its summary.
     *
     * @return the populated {@link MLModel} (metadata plus trained model)
     * @throws MLModelBuilderException if the algorithm name is not a supported deep-learning algorithm,
     *         or if pre-processing, training, evaluation or persisting the summary fails
     */
    @Override
    public MLModel build() throws MLModelBuilderException {
        if (log.isDebugEnabled()) {
            log.debug("Start building the Stacked Autoencoders...");
        }
        MLModelConfigurationContext context = getContext();
        DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
        MLModel mlModel = new MLModel();
        JavaRDD<LabeledPoint> preprocessedData = null;
        try {
            JavaSparkContext sparkContext = context.getSparkContext();
            Workflow workflow = context.getFacts();
            long modelId = context.getModelId();
            int responseIndex = context.getResponseIndex();
            SortedMap<Integer, String> includedFeatures = MLUtils.getIncludedFeaturesAfterReordering(workflow,
                    context.getNewToOldIndicesList(), responseIndex);

            preprocessedData = preProcess().cache();
            double trainFraction = workflow.getTrainDataFraction();
            JavaRDD<LabeledPoint>[] dataSplit = preprocessedData.randomSplit(
                    new double[] { trainFraction, 1.0d - trainFraction }, MLConstants.RANDOM_SEED.longValue());
            JavaRDD<LabeledPoint> trainingData = dataSplit[0];
            JavaRDD<LabeledPoint> testingData = dataSplit[1];
            // Force evaluation of both splits (populating the cache of the pre-processed data) without
            // pulling the whole data set into driver memory, which collect() would do.
            trainingData.count();
            testingData.count();

            mlModel.setAlgorithmName(workflow.getAlgorithmName());
            mlModel.setAlgorithmClass(workflow.getAlgorithmClass());
            mlModel.setFeatures(workflow.getIncludedFeatures());
            mlModel.setResponseVariable(workflow.getResponseVariable());
            mlModel.setEncodings(context.getEncodings());
            mlModel.setNewToOldIndicesList(context.getNewToOldIndicesList());
            mlModel.setResponseIndex(responseIndex);

            // valueOf throws IllegalArgumentException for unknown names; it is wrapped below.
            MLConstants.DEEPLEARNING_ALGORITHM algorithm =
                    MLConstants.DEEPLEARNING_ALGORITHM.valueOf(workflow.getAlgorithmName());
            switch (algorithm) {
                case STACKED_AUTOENCODERS:
                    log.info("Building summary model for SAE");
                    ModelSummary summary = buildStackedAutoencodersModel(sparkContext, modelId, trainingData,
                            testingData, workflow, mlModel, includedFeatures);
                    log.info("Successful building summary model for SAE");
                    databaseService.updateModelSummary(modelId, summary);
                    return mlModel;
                default:
                    throw new AlgorithmNameException("Incorrect algorithm name");
            }
        } catch (MLModelBuilderException e) {
            // Already carries a meaningful message; do not wrap it a second time.
            throw e;
        } catch (Exception e) {
            throw new MLModelBuilderException(
                    "An error occurred while building supervised machine learning model: " + e.getMessage(), e);
        } finally {
            // Release the cached pre-processed data; nothing outside this method holds a reference to it.
            if (preprocessedData != null) {
                preprocessedData.unpersist();
            }
        }
    }

    /**
     * Parses a comma-separated list of integers, ignoring whitespace around each entry.
     *
     * @param str comma-separated integers, e.g. {@code "100, 50, 20"}
     * @return the parsed values in input order
     * @throws NumberFormatException if any entry is not a valid integer
     */
    private int[] stringArrToIntArr(String str) {
        String[] tokens = str.split(",");
        int[] values = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            values[i] = Integer.parseInt(tokens[i].trim());
        }
        return values;
    }

    /**
     * Returns the trimmed string value of a required hyper-parameter.
     *
     * @param hyperParameters the workflow's hyper-parameter map (may be {@code null})
     * @param key the hyper-parameter name
     * @return the trimmed value
     * @throws MLModelBuilderException if the map is {@code null} or has no value for {@code key}
     */
    private static String getRequiredHyperParameter(Map<?, ?> hyperParameters, String key)
            throws MLModelBuilderException {
        Object value = hyperParameters == null ? null : hyperParameters.get(key);
        if (value == null) {
            throw new MLModelBuilderException("Required hyper-parameter '" + key + "' is missing.");
        }
        return String.valueOf(value).trim();
    }

    /**
     * Trains a Stacked Autoencoders model, evaluates it on the testing data and builds its summary.
     *
     * <p>As a side effect the trained model is stored in {@code mlModel}.
     *
     * @param sparkContext Spark context used for evaluation
     * @param modelId id of the model being built
     * @param trainingData training split
     * @param testingData testing split
     * @param workflow workflow supplying the hyper-parameters {@code Batch_Size}, {@code Layer_Sizes},
     *        {@code Activation_Type}, {@code Epochs} and the response variable
     * @param mlModel model object that receives the trained model
     * @param includedFeatures included feature names after reordering, keyed by index
     * @return the populated model summary
     * @throws MLModelBuilderException if a hyper-parameter is missing or invalid, or training/evaluation fails
     */
    private ModelSummary buildStackedAutoencodersModel(JavaSparkContext sparkContext, long modelId,
            JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow,
            MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
        JavaPairRDD<Double, Double> predictionsAndLabels = null;
        try {
            StackedAutoencodersClassifier classifier = new StackedAutoencodersClassifier();
            Map<?, ?> hyperParameters = workflow.getHyperParameters();
            int batchSize = Integer.parseInt(getRequiredHyperParameter(hyperParameters, "Batch_Size"));
            int[] layerSizes = stringArrToIntArr(getRequiredHyperParameter(hyperParameters, "Layer_Sizes"));
            String activationType = getRequiredHyperParameter(hyperParameters, "Activation_Type");
            int epochs = Integer.parseInt(getRequiredHyperParameter(hyperParameters, "Epochs"));

            DeepLearningModel deeplearningModel = classifier.train(trainingData, batchSize, layerSizes,
                    activationType, epochs, workflow.getResponseVariable(), modelId);
            if (deeplearningModel == null) {
                throw new MLModelBuilderException("DeeplearningModel is Null.");
            }
            trainingData.unpersist();
            testingData.cache();
            predictionsAndLabels = classifier.test(sparkContext, deeplearningModel, testingData).cache();

            DeeplearningModelSummary summary = DeeplearningModelUtils.getDeeplearningModelSummary(sparkContext,
                    testingData, predictionsAndLabels);
            mlModel.setModel(new MLDeeplearningModel(deeplearningModel));
            summary.setFeatures(includedFeatures.values().toArray(new String[0]));
            summary.setAlgorithm(MLConstants.DEEPLEARNING_ALGORITHM.STACKED_AUTOENCODERS.toString());

            // MulticlassMetrics computes its statistics lazily from the prediction RDD, so the RDD is
            // only unpersisted (in finally) after the confusion matrix and accuracy have been read.
            MulticlassMetrics metrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
            summary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(metrics, mlModel));
            summary.setModelAccuracy(getModelAccuracy(metrics).doubleValue());
            return summary;
        } catch (MLModelBuilderException e) {
            throw e;
        } catch (Exception e) {
            throw new MLModelBuilderException(
                    "An error occurred while building stacked autoencoders model: " + e.getMessage(), e);
        } finally {
            // Release cached RDDs on both success and failure paths.
            testingData.unpersist();
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    }
}
