package org.wso2.carbon.ml.core.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.GeneralizedLinearModel;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.wso2.carbon.metrics.manager.Level;
import org.wso2.carbon.metrics.manager.MetricManager;
import org.wso2.carbon.metrics.manager.Timer;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelHandlerException;
import org.wso2.carbon.ml.core.factories.AlgorithmType;
import org.wso2.carbon.ml.core.spark.models.MLAnomalyDetectionModel;
import org.wso2.carbon.ml.core.spark.models.MLClassificationModel;
import org.wso2.carbon.ml.core.spark.models.MLDecisionTreeModel;
import org.wso2.carbon.ml.core.spark.models.MLDeeplearningModel;
import org.wso2.carbon.ml.core.spark.models.MLGeneralizedLinearModel;
import org.wso2.carbon.ml.core.spark.models.MLRandomForestModel;
import org.wso2.carbon.ml.core.spark.transformations.BasicEncoder;
import org.wso2.carbon.ml.core.spark.transformations.Normalization;
import org.wso2.carbon.ml.core.utils.DeeplearningModelUtils;
import org.wso2.carbon.ml.core.utils.MLUtils;

/**
 * Produces predictions for a batch of feature rows using a trained {@link MLModel}.
 *
 * <p>The raw string rows are encoded with the model's encodings and converted to MLlib
 * {@link Vector}s once, at construction time. {@link #predict()} then dispatches on the model's
 * algorithm type. Per-row prediction latency is recorded with a metrics {@link Timer} whenever
 * one can be obtained; otherwise predictions run untimed.
 */
public class Predictor {
    private static final Log log = LogFactory.getLog(Predictor.class);

    /** Metric name prefix; the model's algorithm name is appended to form the timer name. */
    private static final String PREDICTION_TIME_METRIC_PREFIX = "org.wso2.carbon.ml.prediction-time.";

    private final long id;
    private final MLModel model;
    private final List<Vector> dataToBePredicted;
    /** Read only by the anomaly-detection branch; 0.0 when the 3-argument constructor is used. */
    private final double percentileValue;

    /**
     * Gives the unrelated MLlib model classes (decision tree, random forest, classification,
     * generalized linear) a common single-row prediction call. Timing and logging can then live in
     * {@link #predictEach(Timer, RowPredictor)}.
     */
    private interface RowPredictor {
        double predict(Vector row);
    }

    /**
     * Creates a predictor whose percentile value is {@code 0.0}.
     *
     * @param id    model id; only used in error messages
     * @param model trained model to predict with
     * @param data  raw feature rows; each row is encoded and converted to a dense vector
     */
    public Predictor(long id, MLModel model, List<String[]> data) {
        this(id, model, data, 0.0);
    }

    /**
     * Creates a predictor.
     *
     * @param id              model id; only used in error messages
     * @param model           trained model to predict with
     * @param data            raw feature rows; each row is encoded and converted to a dense vector
     * @param percentileValue value passed to the anomaly-detection model's {@code predict}
     */
    public Predictor(long id, MLModel model, List<String[]> data, double percentileValue) {
        this.id = id;
        // Must be assigned before getVectors(), which reads the model's encodings.
        this.model = model;
        this.dataToBePredicted = getVectors(data);
        this.percentileValue = percentileValue;
    }

    /**
     * Predicts every row supplied at construction time.
     *
     * <p>Element types by algorithm type:
     * <ul>
     * <li>classification and numerical prediction: {@code Double}, or {@code String} labels when a
     * response encoding exists;</li>
     * <li>clustering: {@code Integer} cluster ids, or {@code String} labels when a response
     * encoding exists;</li>
     * <li>anomaly detection: {@code String} as returned by the anomaly model;</li>
     * <li>deep learning: {@code Double}.</li>
     * </ul>
     *
     * @return one prediction per input row, in input order
     * @throws MLModelHandlerException if the model's algorithm type is not supported, or (as an
     *         {@link AlgorithmNameException}) if the algorithm name is not supported for its type
     */
    public List<?> predict() throws MLModelHandlerException {
        String algorithmClass = this.model.getAlgorithmClass();
        AlgorithmType algorithmType = AlgorithmType.getAlgorithmType(algorithmClass);
        Timer timer = getTimer(this.model.getAlgorithmName());
        if (AlgorithmType.CLASSIFICATION == algorithmType) {
            return predictClassification(timer);
        }
        if (AlgorithmType.NUMERICAL_PREDICTION == algorithmType) {
            return predictNumerical(timer);
        }
        if (AlgorithmType.CLUSTERING == algorithmType) {
            return predictClusters(timer);
        }
        if (AlgorithmType.ANOMALY_DETECTION == algorithmType) {
            return predictAnomalies(timer);
        }
        if (AlgorithmType.DEEPLEARNING == algorithmType) {
            return predictDeeplearning();
        }
        throw new MLModelHandlerException(String.format(
                "Failed to build the model [id] %s . Invalid algorithm type: %s", this.id, algorithmClass));
    }

    /** Classification: decision tree, random forest, or any other {@link MLClassificationModel}. */
    private List<?> predictClassification(Timer timer) {
        MLConstants.SUPERVISED_ALGORITHM algorithm =
                MLConstants.SUPERVISED_ALGORITHM.valueOf(this.model.getAlgorithmName());
        RowPredictor predictor;
        switch (algorithm) {
            case DECISION_TREE: {
                final DecisionTreeModel treeModel = ((MLDecisionTreeModel) this.model.getModel()).getModel();
                predictor = new RowPredictor() {
                    @Override
                    public double predict(Vector row) {
                        return treeModel.predict(row);
                    }
                };
                break;
            }
            case RANDOM_FOREST: {
                final RandomForestModel forestModel = ((MLRandomForestModel) this.model.getModel()).getModel();
                predictor = new RowPredictor() {
                    @Override
                    public double predict(Vector row) {
                        return forestModel.predict(row);
                    }
                };
                break;
            }
            default: {
                final ClassificationModel classificationModel =
                        ((MLClassificationModel) this.model.getModel()).getModel();
                predictor = new RowPredictor() {
                    @Override
                    public double predict(Vector row) {
                        return classificationModel.predict(row);
                    }
                };
                break;
            }
        }
        return decodePredictedValues(predictEach(timer, predictor));
    }

    /** Numerical prediction with a {@link GeneralizedLinearModel}. */
    private List<?> predictNumerical(Timer timer) {
        final GeneralizedLinearModel linearModel = ((MLGeneralizedLinearModel) this.model.getModel()).getModel();
        return decodePredictedValues(predictEach(timer, new RowPredictor() {
            @Override
            public double predict(Vector row) {
                return linearModel.predict(row);
            }
        }));
    }

    /** Clustering; only K-means is supported. */
    private List<?> predictClusters(Timer timer) throws MLModelHandlerException {
        switch (MLConstants.UNSUPERVISED_ALGORITHM.valueOf(this.model.getAlgorithmName())) {
            case K_MEANS: {
                KMeansModel kMeansModel = (KMeansModel) this.model.getModel();
                List<Integer> clusters = new ArrayList<>(this.dataToBePredicted.size());
                for (Vector row : this.dataToBePredicted) {
                    Timer.Context context = startTimer(timer);
                    int cluster;
                    try {
                        cluster = kMeansModel.predict(row);
                    } finally {
                        stopTimer(context);
                    }
                    clusters.add(cluster);
                    logRawPrediction(cluster);
                }
                return decodePredictedValues(clusters);
            }
            default:
                throw incorrectAlgorithmName();
        }
    }

    /** Anomaly detection; both K-means variants share the same prediction path. */
    private List<?> predictAnomalies(Timer timer) throws MLModelHandlerException {
        switch (MLConstants.ANOMALY_DETECTION_ALGORITHM.valueOf(this.model.getAlgorithmName())) {
            case K_MEANS_ANOMALY_DETECTION_WITH_UNLABELED_DATA:
            case K_MEANS_ANOMALY_DETECTION_WITH_LABELED_DATA:
                return detectAnomalies(timer);
            default:
                throw incorrectAlgorithmName();
        }
    }

    /**
     * Runs the anomaly-detection model over every row. When the model asks for normalization, each
     * row is first min-max normalized with the model's feature summary statistics.
     */
    private List<String> detectAnomalies(Timer timer) {
        MLAnomalyDetectionModel anomalyModel = (MLAnomalyDetectionModel) this.model.getModel();
        boolean normalize = this.model.getNormalization();
        Normalization normalization = normalize
                ? new Normalization.Builder().minMax(this.model.getFeatures(), this.model.getSummaryStatsOfFeatures())
                        .build()
                : null;
        List<String> predictions = new ArrayList<>(this.dataToBePredicted.size());
        for (Vector row : this.dataToBePredicted) {
            // Normalization is timed together with the prediction itself.
            Timer.Context context = startTimer(timer);
            String prediction;
            try {
                Vector input = normalize ? normalizeRow(normalization, row) : row;
                prediction = anomalyModel.getModel().predict(input, this.percentileValue);
            } finally {
                stopTimer(context);
            }
            predictions.add(prediction);
            logRawPrediction(prediction);
        }
        return predictions;
    }

    /**
     * Normalizes one row. If normalization fails, a warning is logged and the row's original
     * values are used.
     */
    private static Vector normalizeRow(Normalization normalization, Vector row) {
        double[] values = row.toArray();
        double[] normalized;
        try {
            normalized = normalization.call(values);
        } catch (Exception e) {
            log.warn("Data normalization failed for data: " + Arrays.toString(values) + " Cause: " + e.getMessage(), e);
            normalized = values;
        }
        return new DenseVector(normalized);
    }

    /** Deep learning; only stacked autoencoders are supported. Predictions are not timed. */
    private List<Double> predictDeeplearning() throws MLModelHandlerException {
        switch (MLConstants.DEEPLEARNING_ALGORITHM.valueOf(this.model.getAlgorithmName())) {
            case STACKED_AUTOENCODERS: {
                MLDeeplearningModel deeplearningModel = (MLDeeplearningModel) this.model.getModel();
                ArrayList<double[]> rows = new ArrayList<>(this.dataToBePredicted.size());
                for (Vector row : this.dataToBePredicted) {
                    rows.add(row.toArray());
                }
                List<Double> predictions = new ArrayList<>(rows.size());
                for (double value : deeplearningModel.predict(DeeplearningModelUtils.doubleArrayListToFrame(rows))) {
                    predictions.add(value);
                }
                return predictions;
            }
            default:
                throw incorrectAlgorithmName();
        }
    }

    /**
     * Runs {@code predictor} over every row. Each call is timed individually and the raw output is
     * logged.
     */
    private List<Double> predictEach(Timer timer, RowPredictor predictor) {
        List<Double> predictions = new ArrayList<>(this.dataToBePredicted.size());
        for (Vector row : this.dataToBePredicted) {
            Timer.Context context = startTimer(timer);
            double prediction;
            try {
                prediction = predictor.predict(row);
            } finally {
                // Always stop a started context, even when the model throws.
                stopTimer(context);
            }
            predictions.add(prediction);
            logRawPrediction(prediction);
        }
        return predictions;
    }

    /** Debug-logs a prediction before any label decoding. */
    private static void logRawPrediction(Object prediction) {
        if (log.isDebugEnabled()) {
            log.debug("Predicted value before decoding: " + prediction);
        }
    }

    /** Builds the exception thrown when the algorithm name is not supported for its type. */
    private AlgorithmNameException incorrectAlgorithmName() {
        return new AlgorithmNameException("Incorrect algorithm name: " + this.model.getAlgorithmName()
                + " for model id: " + this.id);
    }

    /** Stops {@code context}; a {@code null} context (no timer available) is ignored. */
    private static void stopTimer(Timer.Context context) {
        if (context != null) {
            context.stop();
        }
    }

    /** Starts {@code timer}, or returns {@code null} when no timer is available. */
    private static Timer.Context startTimer(Timer timer) {
        return timer != null ? timer.start() : null;
    }

    /**
     * Looks up the prediction-time timer for {@code algorithmName}.
     *
     * @return the timer, or {@code null} if the metric manager throws {@link IllegalStateException}
     */
    private Timer getTimer(String algorithmName) {
        try {
            return MetricManager.timer(Level.INFO, PREDICTION_TIME_METRIC_PREFIX + algorithmName);
        } catch (IllegalStateException e) {
            // Metrics are optional: predictions still run, just untimed.
            log.warn(e.getMessage());
            return null;
        }
    }

    /**
     * Encodes each raw row with the model's encodings and converts it to a dense vector. A row that
     * fails to encode is logged and converted as-is.
     */
    private List<Vector> getVectors(List<String[]> rows) {
        BasicEncoder encoder = new BasicEncoder.Builder().encodings(this.model.getEncodings()).build();
        List<Vector> vectors = new ArrayList<>(rows.size());
        for (String[] row : rows) {
            String[] encoded;
            try {
                encoded = encoder.call(row);
            } catch (Exception e) {
                log.warn("Data encoding failed. Cause: " + e.getMessage(), e);
                encoded = row;
            }
            vectors.add(new DenseVector(MLUtils.toDoubleArray(encoded)));
        }
        return vectors;
    }

    /**
     * Maps encoded numeric predictions back to category labels using the response encoding.
     *
     * <p>The response encoding is taken from the last entry of the model's encodings.
     * NOTE(review): this assumes the encoder stores the response column last; confirm against
     * the encoder.
     *
     * @param predictions raw predictions ({@code Double} or {@code Integer})
     * @return decoded {@code String} labels. {@code predictions} is returned unchanged when the
     *         model has no response column ({@code responseIndex == -1}), when there is no
     *         non-empty response encoding, or when any element is neither {@code Double} nor
     *         {@code Integer}.
     */
    private List<?> decodePredictedValues(List<?> predictions) {
        if (this.model.getResponseIndex() == -1) {
            return predictions;
        }
        List<?> encodings = this.model.getEncodings();
        if (encodings == null || encodings.isEmpty()) {
            return predictions;
        }
        @SuppressWarnings("unchecked")
        Map<String, Integer> responseEncoding = (Map<String, Integer>) encodings.get(encodings.size() - 1);
        if (responseEncoding == null || responseEncoding.isEmpty()) {
            return predictions;
        }
        List<String> decoded = new ArrayList<>(predictions.size());
        for (Object prediction : predictions) {
            int encodedValue;
            if (prediction instanceof Double) {
                // Doubles are rounded to the nearest integer code.
                encodedValue = (int) Math.round(((Double) prediction).doubleValue());
            } else if (prediction instanceof Integer) {
                encodedValue = ((Integer) prediction).intValue();
            } else {
                return predictions;
            }
            String label = decode(responseEncoding, encodedValue);
            if (log.isDebugEnabled()) {
                log.debug("Predicted value after decoding: " + label);
            }
            decoded.add(label);
        }
        return decoded;
    }

    /**
     * Returns the label whose code is {@code encodedValue}. If no label has exactly that code,
     * returns the label of the nearest known code. Falls back to the nearest code itself, as a
     * string, only if that code also has no label.
     */
    private String decode(Map<String, Integer> encoding, int encodedValue) {
        String label = findClass(encoding, encodedValue);
        if (label != null) {
            return label;
        }
        int nearest = closest(encodedValue, encoding.values());
        String nearestLabel = findClass(encoding, nearest);
        return nearestLabel != null ? nearestLabel : String.valueOf(nearest);
    }

    /** Returns the first key mapped to {@code code}, or {@code null} if none is. */
    private String findClass(Map<String, Integer> encoding, int code) {
        for (Map.Entry<String, Integer> entry : encoding.entrySet()) {
            Integer value = entry.getValue();
            if (value != null && value.intValue() == code) {
                return entry.getKey();
            }
        }
        return null;
    }

    /**
     * Returns the element of {@code collection} closest to {@code i}. Ties go to the element seen
     * first. {@code null} elements are skipped.
     *
     * @return the closest element, or {@code i} itself when there is no non-null element
     */
    public int closest(int i, Collection<Integer> collection) {
        // Distances are computed in long so that far-apart ints cannot overflow.
        long bestDistance = Long.MAX_VALUE;
        int best = i;
        for (Integer candidate : collection) {
            if (candidate == null) {
                continue;
            }
            long distance = Math.abs((long) candidate.intValue() - i);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = candidate.intValue();
            }
        }
        return best;
    }
}
