package org.wso2.carbon.ml.core.impl;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.wso2.carbon.context.PrivilegedCarbonContext;
import org.wso2.carbon.ml.commons.domain.ClusterPoint;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.commons.domain.MLModelNew;
import org.wso2.carbon.ml.commons.domain.MLStorage;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.commons.domain.config.ModelStorage;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.exceptions.MLModelHandlerException;
import org.wso2.carbon.ml.core.interfaces.MLOutputAdapter;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.algorithms.KMeans;
import org.wso2.carbon.ml.core.spark.algorithms.SupervisedModel;
import org.wso2.carbon.ml.core.spark.algorithms.UnsupervisedModel;
import org.wso2.carbon.ml.core.spark.transformations.HeaderFilter;
import org.wso2.carbon.ml.core.spark.transformations.LineToTokens;
import org.wso2.carbon.ml.core.spark.transformations.MissingValuesFilter;
import org.wso2.carbon.ml.core.spark.transformations.TokensToVectors;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.core.utils.MLUtils;
import org.wso2.carbon.ml.core.utils.ThreadExecutor;
import org.wso2.carbon.ml.database.DatabaseService;
import org.wso2.carbon.ml.database.exceptions.DatabaseHandlerException;
import scala.Tuple2;

/* loaded from: input_file:org/wso2/carbon/ml/core/impl/MLModelHandler.class */
/**
 * Handles the lifecycle of ML models.
 *
 * <p>Responsibilities:
 * <ul>
 *   <li>creating, querying and deleting model records through the {@link DatabaseService};</li>
 *   <li>submitting model-building jobs to Spark, which run asynchronously on a {@link ThreadExecutor};</li>
 *   <li>persisting and retrieving Java-serialized {@link MLModel} instances;</li>
 *   <li>running predictions;</li>
 *   <li>computing K-Means cluster points for a dataset.</li>
 * </ul>
 */
public class MLModelHandler {
    private static final Log log = LogFactory.getLog(MLModelHandler.class);

    /** Model status written when a model record is first created. */
    private static final String STATUS_NOT_STARTED = "Not Started";
    /** Model status written when a build job is submitted. */
    private static final String STATUS_IN_PROGRESS = "In Progress";
    /** Model status written when a build job fails. */
    private static final String STATUS_FAILED = "Failed";

    /** Maximum number of K-Means iterations used by {@link #getClusterPoints}. */
    private static final int CLUSTER_POINTS_MAX_ITERATIONS = 100;

    private final DatabaseService databaseService;
    private final Properties mlProperties;
    private final ThreadExecutor threadExecutor;

    /**
     * Background task that trains one model, persists it and sends an email notification about the outcome.
     *
     * <p>The tenant id, tenant domain and username are captured from the submitting thread's carbon context
     * when the task is constructed. They are re-established on the worker thread inside a tenant flow.
     */
    class ModelBuilder implements Runnable {
        private final long id;
        private final MLModelConfigurationContext ctxt;
        private final int tenantId;
        private final String tenantDomain;
        private final String username;
        private final String emailNotificationEndpoint =
                MLCoreServiceValueHolder.getInstance().getEmailNotificationEndpoint();

        /**
         * @param id   id of the model to build
         * @param ctxt fully populated configuration (Spark context, input lines, workflow facts, model record)
         */
        public ModelBuilder(long id, MLModelConfigurationContext ctxt) {
            this.id = id;
            this.ctxt = ctxt;
            // Capture the caller's tenant identity now; the worker thread will not have it otherwise.
            PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
            this.tenantId = carbonContext.getTenantId();
            this.tenantDomain = carbonContext.getTenantDomain();
            this.username = carbonContext.getUsername();
        }

        /**
         * Trains and persists the model and sends a completion email.
         *
         * <p>On any failure, the error is logged, the model record is marked as failed and a failure email is
         * sent. The tenant flow is always ended.
         */
        @Override
        public void run() {
            // NOTE(review): passed as-is to EmailNotificationSender; presumably the template parameters
            // (the requesting user) — confirm against EmailNotificationSender.
            String[] emailParameters = {username};
            try {
                PrivilegedCarbonContext.startTenantFlow();
                PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
                carbonContext.setTenantId(tenantId);
                carbonContext.setTenantDomain(tenantDomain);
                // Spark classes must be resolvable through this worker thread's context class loader.
                Thread.currentThread().setContextClassLoader(JavaSparkContext.class.getClassLoader());

                MLModel model = trainModel();
                persistModel(id, ctxt.getModel().getName(), model);
                EmailNotificationSender.sendModelBuildingCompleteNotification(emailNotificationEndpoint,
                        emailParameters);
            } catch (Exception e) {
                log.error(String.format("Failed to build the model [id] %s ", id), e);
                markModelFailed(e);
                EmailNotificationSender.sendModelBuildingFailedNotification(emailNotificationEndpoint,
                        emailParameters);
            } finally {
                PrivilegedCarbonContext.endTenantFlow();
            }
        }

        /**
         * Dispatches to the supervised or unsupervised trainer based on the workflow's algorithm class.
         *
         * @throws Exception an {@link MLModelBuilderException} for an unknown algorithm class, or whatever the
         *                   trainers throw (their checked exception types are not visible from this file)
         */
        private MLModel trainModel() throws Exception {
            String algorithmClass = ctxt.getFacts().getAlgorithmClass();
            if ("Classification".equals(algorithmClass) || "Numerical_Prediction".equals(algorithmClass)) {
                return new SupervisedModel().buildModel(ctxt);
            }
            if ("Clustering".equals(algorithmClass)) {
                return new UnsupervisedModel().buildModel(ctxt);
            }
            throw new MLModelBuilderException(String.format(
                    "Failed to build the model [id] %s . Invalid algorithm type: %s", id, algorithmClass));
        }

        /**
         * Records the failure status and error message on the model row.
         *
         * <p>A database error while doing so is logged rather than propagated, so the failure notification is
         * still sent.
         */
        private void markModelFailed(Exception cause) {
            try {
                databaseService.updateModelStatus(id, STATUS_FAILED);
                databaseService.updateModelError(id, cause.getMessage());
            } catch (DatabaseHandlerException dbError) {
                // Log the database error itself; the original build error has already been logged.
                log.error(String.format("Failed to update the status of model [id] %s ", id), dbError);
            }
        }
    }

    /**
     * Creates a handler wired to the shared {@link MLCoreServiceValueHolder} services.
     *
     * <p>The handler owns its own {@link ThreadExecutor}, configured from the ML properties.
     */
    public MLModelHandler() {
        MLCoreServiceValueHolder valueHolder = MLCoreServiceValueHolder.getInstance();
        this.databaseService = valueHolder.getDatabaseService();
        this.mlProperties = valueHolder.getMlProperties();
        this.threadExecutor = new ThreadExecutor(this.mlProperties);
    }

    /**
     * Fills in storage settings, a generated name and the initial status, then inserts the model record.
     *
     * <p>The generated name has the form {@code <analysis name>.Model.<date>}.
     *
     * @param model model to create; mutated in place
     * @return the same {@code model} instance
     * @throws MLModelHandlerException if a database operation fails
     */
    public MLModelNew createModel(MLModelNew model) throws MLModelHandlerException {
        try {
            ModelStorage modelStorage = MLCoreServiceValueHolder.getInstance().getModelStorage();
            model.setStorageType(modelStorage.getStorageType());
            model.setStorageDirectory(modelStorage.getStorageDirectory());
            model.setName(databaseService.getAnalysis(model.getTenantId(), model.getUserName(),
                    model.getAnalysisId()).getName() + ".Model." + MLUtils.getDate());
            model.setStatus(STATUS_NOT_STARTED);
            databaseService.insertModel(model);
            log.info(String.format("[Created] %s", model));
            return model;
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Deletes a model record.
     *
     * @throws MLModelHandlerException if the database operation fails
     */
    public void deleteModel(int tenantId, String userName, long modelId) throws MLModelHandlerException {
        try {
            databaseService.deleteModel(tenantId, userName, modelId);
            log.info(String.format("[Deleted] Model [id] %s", modelId));
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Looks up a model id by name for the given tenant and user.
     *
     * @throws MLModelHandlerException if the database operation fails
     */
    public long getModelId(int tenantId, String userName, String modelName) throws MLModelHandlerException {
        try {
            return databaseService.getModelId(tenantId, userName, modelName);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Fetches a model record by name for the given tenant and user.
     *
     * @throws MLModelHandlerException if the database operation fails
     */
    public MLModelNew getModel(int tenantId, String userName, String modelName) throws MLModelHandlerException {
        try {
            return databaseService.getModel(tenantId, userName, modelName);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Lists all model records of the given tenant and user.
     *
     * @throws MLModelHandlerException if the database operation fails
     */
    public List<MLModelNew> getAllModels(int tenantId, String userName) throws MLModelHandlerException {
        try {
            return databaseService.getAllModels(tenantId, userName);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Checks whether {@code modelId} is a valid model for the given tenant and user.
     *
     * @throws MLModelHandlerException if the database operation fails
     */
    public boolean isValidModelId(int tenantId, String userName, long modelId) throws MLModelHandlerException {
        try {
            return databaseService.isValidModelId(tenantId, userName, modelId);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Updates the storage type and location recorded for a model.
     *
     * @throws MLModelHandlerException if the database operation fails
     */
    public void addStorage(long modelId, MLStorage storage) throws MLModelHandlerException {
        try {
            databaseService.updateModelStorage(modelId, storage.getType(), storage.getLocation());
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Fetches the stored summary of a model.
     *
     * @throws MLModelHandlerException if the database operation fails
     */
    public ModelSummary getModelSummary(long modelId) throws MLModelHandlerException {
        try {
            return databaseService.getModelSummary(modelId);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Prepares a Spark context for the model's dataset version and submits an asynchronous build job.
     *
     * <p>The model status is set to "In Progress" <em>before</em> the job is submitted. Otherwise a quickly
     * failing job could write "Failed" and then be overwritten. If preparation or submission fails, the
     * Spark context created here is stopped. The caller's context class loader is always restored.
     *
     * @throws MLModelHandlerException if the model id is not valid for this tenant/user
     * @throws MLModelBuilderException if a database operation fails while preparing the job
     */
    public void buildModel(int tenantId, String userName, long modelId)
            throws MLModelHandlerException, MLModelBuilderException {
        if (!isValidModelId(tenantId, userName, modelId)) {
            throw new MLModelHandlerException(String.format(
                    "Failed to build the model. Invalid model id: %s for tenant: %s and user: %s",
                    modelId, tenantId, userName));
        }
        ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
        JavaSparkContext sparkContext = null;
        boolean submitted = false;
        try {
            MLModelNew model = databaseService.getModel(tenantId, userName, modelId);
            Thread.currentThread().setContextClassLoader(JavaSparkContext.class.getClassLoader());

            // Do all database lookups before creating the (expensive, JVM-singleton) Spark context.
            long datasetVersionId = databaseService.getDatasetVersionIdOfModel(modelId);
            String columnSeparator =
                    MLUtils.ColumnSeparatorFactory.getColumnSeparator(databaseService.getDataTypeOfModel(modelId));
            String datasetVersionUri = databaseService.getDatasetVersionUri(datasetVersionId);
            String headerRow =
                    databaseService.getFeatureNamesInOrderUsingDatasetVersion(datasetVersionId, columnSeparator);
            Workflow workflow = databaseService.getWorkflow(model.getAnalysisId());

            MLModelConfigurationContext context = new MLModelConfigurationContext();
            context.setModelId(modelId);
            context.setColumnSeparator(columnSeparator);
            context.setFacts(workflow);
            context.setModel(model);
            context.setHeaderRow(headerRow);

            SparkConf sparkConf = MLCoreServiceValueHolder.getInstance().getSparkConf();
            sparkConf.setAppName(String.valueOf(modelId));
            sparkContext = new JavaSparkContext(sparkConf);
            context.setSparkContext(sparkContext);
            context.setLines(sparkContext.textFile(datasetVersionUri));

            databaseService.updateModelStatus(modelId, STATUS_IN_PROGRESS);
            threadExecutor.execute(new ModelBuilder(modelId, context));
            // From here on, the build job owns the Spark context.
            submitted = true;
            log.info(String.format("Build model [id] %s job is successfully submitted to Spark.", modelId));
        } catch (DatabaseHandlerException e) {
            throw new MLModelBuilderException("An error occurred while saving model to database: " + e.getMessage(),
                    e);
        } finally {
            if (!submitted && sparkContext != null) {
                sparkContext.stop();
            }
            Thread.currentThread().setContextClassLoader(originalClassLoader);
        }
    }

    /**
     * Runs the stored model against the given feature rows.
     *
     * @param data one {@code double[]} of feature values per row to predict
     * @return predictions produced by {@link Predictor}
     * @throws MLModelHandlerException if the model id is not valid for this tenant/user
     * @throws MLModelBuilderException if the model cannot be retrieved
     */
    public List<?> predict(int tenantId, String userName, long modelId, List<double[]> data)
            throws MLModelHandlerException, MLModelBuilderException {
        if (!isValidModelId(tenantId, userName, modelId)) {
            throw new MLModelHandlerException(String.format(
                    "Failed to predict from the model. Invalid model id: %s for tenant: %s and user: %s",
                    modelId, tenantId, userName));
        }
        List<?> predictions = new Predictor(modelId, retrieveModel(modelId), data).predict();
        log.info(String.format("Prediction from model [id] %s was successful.", modelId));
        return predictions;
    }

    /**
     * Convenience overload that predicts a single row given as strings.
     *
     * <p>The row is converted with {@link MLUtils#toDoubleArray}.
     */
    public List<?> predict(int tenantId, String userName, long modelId, String[] data)
            throws MLModelHandlerException, MLModelBuilderException {
        List<double[]> rows = new ArrayList<double[]>();
        rows.add(MLUtils.toDoubleArray(data));
        return predict(tenantId, userName, modelId, rows);
    }

    /**
     * Serializes {@code model} with Java serialization and writes it to the model's configured storage.
     *
     * <p>The target path is {@code <storage location>/<modelName>}. The stored location is then updated to
     * that path. This method is an internal helper of the build job; it is kept public only for
     * compatibility with existing callers.
     *
     * @throws MLModelBuilderException if serialization, writing or the database update fails
     */
    public void persistModel(long modelId, String modelName, MLModel model) throws MLModelBuilderException {
        try {
            MLStorage storage = databaseService.getModelStorage(modelId);
            String storageType = storage.getType();
            String targetPath = storage.getLocation() + File.separator + modelName;
            MLOutputAdapter outputAdapter = new MLIOFactory(mlProperties).getOutputAdapter(storageType + ".out");

            ByteArrayOutputStream serialized = new ByteArrayOutputStream();
            try (ObjectOutputStream objectOut = new ObjectOutputStream(serialized)) {
                objectOut.writeObject(model);
            }
            outputAdapter.write(targetPath, new ByteArrayInputStream(serialized.toByteArray()));
            databaseService.updateModelStorage(modelId, storageType, targetPath);
        } catch (Exception e) {
            throw new MLModelBuilderException("Failed to persist the model [id] " + modelId, e);
        }
    }

    /**
     * Reads and deserializes a model from its recorded storage location.
     *
     * <p>Streams are closed on a best-effort basis: close failures are logged, not thrown.
     *
     * @throws MLModelBuilderException if the storage lookup, read or deserialization fails
     */
    public MLModel retrieveModel(long modelId) throws MLModelBuilderException {
        InputStream inputStream = null;
        ObjectInputStream objectInputStream = null;
        try {
            MLStorage storage = databaseService.getModelStorage(modelId);
            inputStream = new MLIOFactory(mlProperties).getInputAdapter(storage.getType() + ".in")
                    .readDataset(new URI(storage.getLocation()));
            // SECURITY NOTE(review): native Java deserialization. This is safe only if the storage location
            // cannot be written by untrusted parties; consider an ObjectInputFilter.
            objectInputStream = new ObjectInputStream(inputStream);
            return (MLModel) objectInputStream.readObject();
        } catch (Exception e) {
            throw new MLModelBuilderException("Failed to retrieve the model [id] " + modelId, e);
        } finally {
            closeQuietly(inputStream);
            closeQuietly(objectInputStream);
        }
    }

    /** Closes {@code stream} if non-null, logging (not throwing) any close failure. */
    private static void closeQuietly(InputStream stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        } catch (IOException e) {
            log.debug("Failed to close model input stream", e);
        }
    }

    /**
     * Runs K-Means over (a sample of) a dataset and returns each point with its assigned cluster.
     *
     * <p>If the dataset has more data rows than the configured summary-statistics sample size, a random
     * fraction of rows is used. A temporary Spark context is created and is always stopped, even on failure.
     * The caller's context class loader is always restored.
     *
     * @param tenantId           not used by this method
     * @param userName           not used by this method
     * @param datasetId          dataset whose URI and header row are looked up
     * @param featureListString  comma-separated feature names (whitespace around commas is ignored)
     * @param noOfClusters       number of K-Means clusters
     * @return one {@link ClusterPoint} per (sampled) row with complete values
     * @throws DatabaseHandlerException if a database lookup fails
     */
    public List<ClusterPoint> getClusterPoints(int tenantId, String userName, long datasetId,
            String featureListString, int noOfClusters) throws DatabaseHandlerException {
        ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
        List<String> features = Arrays.asList(featureListString.split("\\s*,\\s*"));
        JavaSparkContext sparkContext = null;
        try {
            String datasetUri = databaseService.getDatasetUri(datasetId);
            Thread.currentThread().setContextClassLoader(JavaSparkContext.class.getClassLoader());
            SparkConf sparkConf = MLCoreServiceValueHolder.getInstance().getSparkConf();
            sparkConf.setAppName(String.valueOf(datasetId));
            sparkContext = new JavaSparkContext(sparkConf);

            JavaRDD<String> lines = sparkContext.textFile(datasetUri);
            String columnSeparator = MLUtils.ColumnSeparatorFactory.getColumnSeparator(datasetUri);
            String headerRow = databaseService.getFeatureNamesInOrder(datasetId, columnSeparator);
            Pattern separatorPattern = Pattern.compile(columnSeparator);

            List<Integer> featureIndices = new ArrayList<Integer>(features.size());
            for (String feature : features) {
                featureIndices.add(MLUtils.getFeatureIndex(feature, headerRow, columnSeparator));
            }

            // Floating-point division: integer division truncated the fraction to 0 whenever the dataset
            // was larger than the sample size, producing an empty sample.
            long dataRowCount = lines.count() - 1;
            double sampleFraction =
                    (double) MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()
                            / dataRowCount;
            JavaRDD<String> dataRows = lines.filter(new HeaderFilter(headerRow));
            if (dataRowCount > 0 && sampleFraction < 1.0d) {
                dataRows = dataRows.sample(false, sampleFraction);
            }
            JavaRDD<Vector> featureVectors = dataRows.map(new LineToTokens(separatorPattern))
                    .filter(new MissingValuesFilter())
                    .map(new TokensToVectors(featureIndices));

            List<Tuple2<Integer, Vector>> predictions = new KMeans()
                    .train(featureVectors, noOfClusters, CLUSTER_POINTS_MAX_ITERATIONS)
                    .predict(featureVectors)
                    .zip(featureVectors)
                    .collect();

            List<ClusterPoint> clusterPoints = new ArrayList<ClusterPoint>(predictions.size());
            for (Tuple2<Integer, Vector> prediction : predictions) {
                ClusterPoint clusterPoint = new ClusterPoint();
                clusterPoint.setCluster(prediction._1());
                clusterPoint.setFeatures(prediction._2().toArray());
                clusterPoints.add(clusterPoint);
            }
            return clusterPoints;
        } catch (DatabaseHandlerException e) {
            throw new DatabaseHandlerException("An error occurred while generating cluster points: " + e.getMessage(),
                    e);
        } finally {
            // Spark allows a single active context per JVM; never leak it on failure.
            if (sparkContext != null) {
                sparkContext.stop();
            }
            Thread.currentThread().setContextClassLoader(originalClassLoader);
        }
    }
}
