package org.wso2.carbon.ml.core.impl;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.GeneralizedLinearModel;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;

/**
 * Runs predictions for a batch of feature rows against a trained {@link MLModel}.
 *
 * <p>The concrete Spark model type is selected from the model's algorithm class
 * ({@code "Classification"}, {@code "Numerical_Prediction"} or {@code "Clustering"}) and,
 * where needed, from its algorithm name.
 */
public class Predictor {
    private static final Log log = LogFactory.getLog(Predictor.class);

    /** Id of the model; only used to give context in error messages. */
    private final long id;
    /** Trained model together with its algorithm metadata. */
    private final MLModel model;
    /** Input rows wrapped as Spark dense vectors, in the order they were supplied. */
    private final List<Vector> dataToBePredicted;

    /**
     * Creates a predictor for the given model and input rows.
     *
     * @param id    id of the model, used in error messages
     * @param model trained model to predict with
     * @param data  feature rows to predict; each array is wrapped in a {@link DenseVector}
     */
    public Predictor(long id, MLModel model, List<double[]> data) {
        this.id = id;
        this.model = model;
        this.dataToBePredicted = getVectors(data);
    }

    /**
     * Predicts a value for every input row.
     *
     * @return one prediction per input row, in input order: {@code Double}s for classification
     *         and numerical prediction, {@code Integer}s (the value returned by
     *         {@link KMeansModel#predict}) for clustering
     * @throws MLModelBuilderException if the model's algorithm class is not supported
     * @throws AlgorithmNameException  if the algorithm name is missing or not valid for the
     *                                 model's algorithm class
     */
    public List<?> predict() throws MLModelBuilderException {
        String algorithmClass = model.getAlgorithmClass();
        if ("Classification".equals(algorithmClass)) {
            return predictClassification();
        }
        if ("Numerical_Prediction".equals(algorithmClass)) {
            return predictNumerical();
        }
        if ("Clustering".equals(algorithmClass)) {
            return predictClustering();
        }
        throw new MLModelBuilderException(String.format(
                "Failed to build the model [id] %s . Invalid algorithm type: %s", Long.valueOf(id), algorithmClass));
    }

    /** Predicts class labels with either a decision tree or a generic MLlib classification model. */
    private List<Double> predictClassification() throws MLModelBuilderException {
        MLConstants.SUPERVISED_ALGORITHM algorithm =
                parseAlgorithm(MLConstants.SUPERVISED_ALGORITHM.class, model.getAlgorithmName());
        List<Double> predictions = new ArrayList<Double>(dataToBePredicted.size());
        switch (algorithm) {
            case DECISION_TREE:
                // DecisionTreeModel is not a ClassificationModel in MLlib, so it must not fall
                // through to the generic path below (that would duplicate rows or fail the cast).
                DecisionTreeModel treeModel = (DecisionTreeModel) model.getModel();
                for (Vector vector : dataToBePredicted) {
                    Double prediction = Double.valueOf(treeModel.predict(vector));
                    predictions.add(prediction);
                    logPrediction(prediction);
                }
                break;
            default:
                ClassificationModel classificationModel = (ClassificationModel) model.getModel();
                for (Vector vector : dataToBePredicted) {
                    Double prediction = Double.valueOf(classificationModel.predict(vector));
                    predictions.add(prediction);
                    logPrediction(prediction);
                }
                break;
        }
        return predictions;
    }

    /** Predicts numeric values with a generalized linear model. */
    private List<Double> predictNumerical() {
        GeneralizedLinearModel linearModel = (GeneralizedLinearModel) model.getModel();
        List<Double> predictions = new ArrayList<Double>(dataToBePredicted.size());
        for (Vector vector : dataToBePredicted) {
            Double prediction = Double.valueOf(linearModel.predict(vector));
            predictions.add(prediction);
            logPrediction(prediction);
        }
        return predictions;
    }

    /** Assigns each row to a cluster; only K-means is supported. */
    private List<Integer> predictClustering() throws MLModelBuilderException {
        String algorithmName = model.getAlgorithmName();
        MLConstants.UNSUPERVISED_ALGORITHM algorithm =
                parseAlgorithm(MLConstants.UNSUPERVISED_ALGORITHM.class, algorithmName);
        switch (algorithm) {
            case K_MEANS:
                KMeansModel kMeansModel = (KMeansModel) model.getModel();
                List<Integer> clusters = new ArrayList<Integer>(dataToBePredicted.size());
                for (Vector vector : dataToBePredicted) {
                    Integer cluster = Integer.valueOf(kMeansModel.predict(vector));
                    clusters.add(cluster);
                    logPrediction(cluster);
                }
                return clusters;
            default:
                throw incorrectAlgorithmName(algorithmName);
        }
    }

    /**
     * Resolves an algorithm name to a constant of the given enum type.
     *
     * @throws AlgorithmNameException if the name is null or is not a constant of {@code type}
     */
    private <E extends Enum<E>> E parseAlgorithm(Class<E> type, String algorithmName)
            throws MLModelBuilderException {
        if (algorithmName != null) {
            try {
                return Enum.valueOf(type, algorithmName);
            } catch (IllegalArgumentException ignored) {
                // Unknown name: reported below with the same exception as unsupported algorithms.
            }
        }
        throw incorrectAlgorithmName(algorithmName);
    }

    /** Builds the exception used for algorithm names this predictor cannot handle. */
    private AlgorithmNameException incorrectAlgorithmName(String algorithmName) {
        return new AlgorithmNameException("Incorrect algorithm name: " + algorithmName + " for model id: " + id);
    }

    /** Logs one prediction at debug level; per-row INFO logging floods the logs on large batches. */
    private static void logPrediction(Object prediction) {
        if (log.isDebugEnabled()) {
            log.debug("Prediction: " + prediction);
        }
    }

    /** Wraps each feature row in a {@link DenseVector}, preserving order. */
    private static List<Vector> getVectors(List<double[]> rows) {
        List<Vector> vectors = new ArrayList<Vector>(rows.size());
        for (double[] row : rows) {
            vectors.add(new DenseVector(row));
        }
        return vectors;
    }
}
