package org.wso2.carbon.ml.core.spark.recommendation;

import java.util.Map;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.interfaces.MLModelBuilder;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.models.MLMatrixFactorizationModel;
import org.wso2.carbon.ml.core.spark.summary.RecommendationModelSummary;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.database.DatabaseService;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/recommendation/RecommendationModelBuilder.class */
/**
 * Builds collaborative-filtering recommendation models from a model configuration context.
 * <p>
 * Supports the explicit-feedback ({@code COLLABORATIVE_FILTERING}) and implicit-feedback
 * ({@code COLLABORATIVE_FILTERING_IMPLICIT}) variants. Each is trained through
 * {@link CollaborativeFiltering} and wrapped in an {@link MLMatrixFactorizationModel}.
 */
public class RecommendationModelBuilder extends MLModelBuilder {

    /** Prefix shared by every wrapped error raised by this builder. */
    private static final String BUILD_ERROR_PREFIX = "An error occurred while building recommendation model: ";

    public RecommendationModelBuilder(MLModelConfigurationContext context) {
        super(context);
    }

    /**
     * Trains the recommendation model described by the context's workflow and stores its summary.
     * <p>
     * Steps:
     * <ol>
     *   <li>Copy the algorithm name, algorithm class and features from the workflow onto a new {@link MLModel}.</li>
     *   <li>Pre-process the data via {@code RecommendationUtils.preProcess}.</li>
     *   <li>Train the model and attach it to the {@link MLModel}.</li>
     *   <li>Persist the resulting {@link ModelSummary} with {@link DatabaseService#updateModelSummary}.</li>
     * </ol>
     *
     * @return the populated model
     * @throws MLModelBuilderException if the algorithm name is not a known recommendation algorithm, a
     *         hyper-parameter is missing or malformed, or training/persistence fails
     */
    @Override // org.wso2.carbon.ml.core.interfaces.MLModelBuilder
    public MLModel build() throws MLModelBuilderException {
        MLModelConfigurationContext context = getContext();
        DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
        try {
            Workflow facts = context.getFacts();
            long modelId = context.getModelId();

            MLModel mlModel = new MLModel();
            mlModel.setAlgorithmName(facts.getAlgorithmName());
            mlModel.setAlgorithmClass(facts.getAlgorithmClass());
            mlModel.setFeatures(facts.getFeatures());

            ModelSummary summary;
            switch (resolveAlgorithm(facts.getAlgorithmName(), modelId)) {
                case COLLABORATIVE_FILTERING:
                    summary = buildCollaborativeFilteringModel(RecommendationUtils.preProcess(context, false), facts,
                            mlModel, false);
                    break;
                case COLLABORATIVE_FILTERING_IMPLICIT:
                    summary = buildCollaborativeFilteringModel(RecommendationUtils.preProcess(context, true), facts,
                            mlModel, true);
                    break;
                default:
                    // Reached only if the enum gains constants this builder does not handle yet.
                    throw new AlgorithmNameException("Incorrect algorithm name: " + facts.getAlgorithmName()
                            + " for model id: " + modelId);
            }
            databaseService.updateModelSummary(modelId, summary);
            return mlModel;
        } catch (MLModelBuilderException e) {
            // Already carries the build-error context; wrapping again would duplicate the message prefix.
            throw e;
        } catch (Exception e) {
            throw new MLModelBuilderException(BUILD_ERROR_PREFIX + e.getMessage(), e);
        }
    }

    /**
     * Maps a workflow algorithm name onto a recommendation algorithm constant.
     * <p>
     * Unknown or null names are reported as {@link AlgorithmNameException}. Without this mapping,
     * {@code Enum.valueOf} would throw first and the descriptive message would be lost.
     */
    private static MLConstants.RECOMMENDATION_ALGORITHM resolveAlgorithm(String algorithmName, long modelId)
            throws AlgorithmNameException {
        String message = "Incorrect algorithm name: " + algorithmName + " for model id: " + modelId;
        if (algorithmName == null) {
            throw new AlgorithmNameException(message);
        }
        try {
            return MLConstants.RECOMMENDATION_ALGORITHM.valueOf(algorithmName);
        } catch (IllegalArgumentException e) {
            throw new AlgorithmNameException(message);
        }
    }

    /**
     * Trains a matrix-factorization model on the given ratings and summarises it.
     *
     * @param ratings          pre-processed ratings, used both for training and for the error estimate
     * @param workflow         supplies the hyper-parameters and the dataset version
     * @param mlModel          receives the trained model via {@link MLModel#setModel}
     * @param implicitFeedback {@code true} to train with implicit feedback (which also requires the
     *                         {@code ALPHA} hyper-parameter); {@code false} for explicit ratings
     * @return a {@link RecommendationModelSummary} holding the algorithm, dataset version and error
     * @throws MLModelBuilderException wrapping any failure, including missing or malformed hyper-parameters
     */
    private ModelSummary buildCollaborativeFilteringModel(JavaRDD<Rating> ratings, Workflow workflow, MLModel mlModel,
            boolean implicitFeedback) throws MLModelBuilderException {
        try {
            SparkContext sparkContext = ratings.rdd().sparkContext();
            sparkContext.setCheckpointDir(MLConstants.CHECKPOINTING_DIR);
            try {
                Map<String, String> hyperParameters = workflow.getHyperParameters();
                int rank = Integer.parseInt(requireHyperParameter(hyperParameters, MLConstants.RANK));
                int iterations = Integer.parseInt(requireHyperParameter(hyperParameters, MLConstants.ITERATIONS));
                double lambda = Double.parseDouble(requireHyperParameter(hyperParameters, MLConstants.LAMBDA));
                int blocks = Integer.parseInt(requireHyperParameter(hyperParameters, MLConstants.BLOCKS));

                CollaborativeFiltering collaborativeFiltering = new CollaborativeFiltering();
                MatrixFactorizationModel model;
                MLConstants.RECOMMENDATION_ALGORITHM algorithm;
                if (implicitFeedback) {
                    // ALPHA is only meaningful (and only required) for implicit feedback.
                    double alpha = Double.parseDouble(requireHyperParameter(hyperParameters, MLConstants.ALPHA));
                    model = collaborativeFiltering.trainImplicit(ratings, rank, iterations, lambda, alpha, blocks);
                    algorithm = MLConstants.RECOMMENDATION_ALGORITHM.COLLABORATIVE_FILTERING_IMPLICIT;
                } else {
                    model = collaborativeFiltering.trainExplicit(ratings, rank, iterations, lambda, blocks);
                    algorithm = MLConstants.RECOMMENDATION_ALGORITHM.COLLABORATIVE_FILTERING;
                }
                mlModel.setModel(new MLMatrixFactorizationModel(model));

                RecommendationModelSummary summary = new RecommendationModelSummary();
                summary.setAlgorithm(algorithm.toString());
                summary.setDatasetVersion(workflow.getDatasetVersion());
                // Evaluated on the same ratings used for training, so this is an in-sample estimate.
                summary.setMeanSquaredError(collaborativeFiltering.test(model, ratings).mean().doubleValue());
                return summary;
            } finally {
                // Clear the checkpoint directory set above even when training fails, so the shared
                // context is not left pointing at it.
                sparkContext.setCheckpointDir(null);
            }
        } catch (Exception e) {
            throw new MLModelBuilderException(BUILD_ERROR_PREFIX + e.getMessage(), e);
        }
    }

    /**
     * Returns the trimmed value of a required hyper-parameter.
     *
     * @throws IllegalArgumentException naming the key when the map or the value is absent
     */
    private static String requireHyperParameter(Map<String, String> hyperParameters, String key) {
        String value = hyperParameters == null ? null : hyperParameters.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing hyper-parameter: " + key);
        }
        return value.trim();
    }
}
