package org.wso2.carbon.ml.siddhi.extension;

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ObjectUtils;
import org.wso2.carbon.ml.core.exceptions.MLInputAdapterException;
import org.wso2.carbon.ml.core.exceptions.MLModelHandlerException;
import org.wso2.carbon.ml.core.factories.AlgorithmType;
import org.wso2.carbon.ml.core.h2o.POJOPredictor;
import org.wso2.siddhi.core.config.ExecutionPlanContext;
import org.wso2.siddhi.core.event.ComplexEventChunk;
import org.wso2.siddhi.core.event.stream.StreamEvent;
import org.wso2.siddhi.core.event.stream.StreamEventCloner;
import org.wso2.siddhi.core.event.stream.populater.ComplexEventPopulater;
import org.wso2.siddhi.core.exception.ExecutionPlanCreationException;
import org.wso2.siddhi.core.exception.ExecutionPlanRuntimeException;
import org.wso2.siddhi.core.executor.ConstantExpressionExecutor;
import org.wso2.siddhi.core.executor.ExpressionExecutor;
import org.wso2.siddhi.core.executor.VariableExpressionExecutor;
import org.wso2.siddhi.core.query.processor.Processor;
import org.wso2.siddhi.core.query.processor.stream.StreamProcessor;
import org.wso2.siddhi.query.api.definition.AbstractDefinition;
import org.wso2.siddhi.query.api.definition.Attribute;
import org.wso2.siddhi.query.api.exception.ExecutionPlanValidationException;

/* loaded from: input_file:org/wso2/carbon/ml/siddhi/extension/PredictStreamProcessor.class */
/**
 * Siddhi stream processor that runs each incoming event through one or more ML models and appends a
 * single prediction attribute to the event.
 *
 * <p>Query parameters, in order:
 * <ol>
 *   <li>comma-separated model storage locations (constant string);</li>
 *   <li>type of the prediction attribute: {@code double}, {@code float}, {@code int}/{@code integer},
 *       {@code long}, {@code string} or {@code bool}/{@code boolean} (constant string);</li>
 *   <li>for anomaly-detection models only: the percentile value (constant number);</li>
 *   <li>optionally, the stream attributes to use as model features. When none are given, every attribute
 *       of the input stream must match a model feature.</li>
 * </ol>
 *
 * <p>All models must share the same algorithm class and feature set. Except for anomaly detection, they
 * must also share the same response variable. For numerical prediction, the per-model predictions are
 * averaged. For every other algorithm class, the most frequent prediction is chosen by
 * {@link ObjectUtils#mode(Object...)}, which yields {@code null} when there is a tie.
 */
public class PredictStreamProcessor extends StreamProcessor {
    /** Name of the output attribute produced for anomaly-detection models. */
    private static final String ANOMALY_PREDICTION_ATTRIBUTE = "prediction";

    // Layout of the int[] attribute positions stored in attributeIndexMap:
    // slot 2 holds the data-array type and slot 3 holds the index within that array.
    private static final int STREAM_ATTRIBUTE_TYPE_INDEX = 2;
    private static final int STREAM_ATTRIBUTE_INDEX = 3;
    // Data-array types found in slot STREAM_ATTRIBUTE_TYPE_INDEX.
    private static final int BEFORE_WINDOW_DATA = 0;
    private static final int ON_AFTER_WINDOW_DATA = 1;
    private static final int OUTPUT_DATA = 2;

    private ModelHandler[] modelHandlers;
    private String[] modelStorageLocations;
    private String responseVariable;
    private String algorithmClass;
    private String outputType;
    private double percentileValue;
    private boolean isAnomalyDetection;
    private boolean attributeSelectionAvailable;
    /** Maps a model feature index to the position of the stream attribute that supplies it. */
    private Map<Integer, int[]> attributeIndexMap;
    private POJOPredictor[] pojoPredictor;
    private boolean deeplearningWithoutH2O;

    /**
     * Predicts a value for every event in the chunk, appends it to the event, and forwards the chunk.
     *
     * @throws ExecutionPlanRuntimeException if any model fails to predict
     */
    protected void process(ComplexEventChunk<StreamEvent> complexEventChunk, Processor processor, StreamEventCloner streamEventCloner, ComplexEventPopulater complexEventPopulater) {
        while (complexEventChunk.hasNext()) {
            StreamEvent event = complexEventChunk.next();
            String[] featureValues = extractFeatureValues(event);
            try {
                Object prediction = predict(featureValues);
                complexEventPopulater.populateComplexEvent(event, new Object[]{prediction});
            } catch (Exception e) {
                log.error("Error while predicting", e);
                throw new ExecutionPlanRuntimeException("Error while predicting", e);
            }
        }
        processor.process(complexEventChunk);
    }

    /** Builds the model input vector, ordered by feature index, from the event's attribute values. */
    private String[] extractFeatureValues(StreamEvent event) {
        String[] featureValues = new String[this.attributeIndexMap.size()];
        for (Map.Entry<Integer, int[]> entry : this.attributeIndexMap.entrySet()) {
            int[] position = entry.getValue();
            Object value;
            switch (position[STREAM_ATTRIBUTE_TYPE_INDEX]) {
                case BEFORE_WINDOW_DATA:
                    value = event.getBeforeWindowData()[position[STREAM_ATTRIBUTE_INDEX]];
                    break;
                case ON_AFTER_WINDOW_DATA:
                    value = event.getOnAfterWindowData()[position[STREAM_ATTRIBUTE_INDEX]];
                    break;
                case OUTPUT_DATA:
                    value = event.getOutputData()[position[STREAM_ATTRIBUTE_INDEX]];
                    break;
                default:
                    value = null;
                    break;
            }
            // Missing values become the string "null", as in the original implementation.
            featureValues[entry.getKey()] = String.valueOf(value);
        }
        return featureValues;
    }

    /**
     * Runs every model on the feature vector and combines the results.
     *
     * @return the mean prediction for numerical prediction; otherwise the most frequent prediction
     *         ({@code null} on a tie)
     * @throws Exception whatever the model handlers throw, or ExecutionPlanRuntimeException for an
     *         unsupported algorithm class
     */
    private Object predict(String[] featureValues) throws Exception {
        Object[] predictions = new Object[this.modelHandlers.length];
        if (AlgorithmType.CLASSIFICATION.getValue().equals(this.algorithmClass)) {
            for (int i = 0; i < this.modelHandlers.length; i++) {
                predictions[i] = this.modelHandlers[i].predict(featureValues, this.outputType);
            }
        } else if (AlgorithmType.NUMERICAL_PREDICTION.getValue().equals(this.algorithmClass)) {
            double sum = 0.0d;
            for (ModelHandler handler : this.modelHandlers) {
                sum += Double.parseDouble(handler.predict(featureValues, this.outputType).toString());
            }
            return Double.valueOf(sum / this.modelHandlers.length);
        } else if (AlgorithmType.ANOMALY_DETECTION.getValue().equals(this.algorithmClass)) {
            for (int i = 0; i < this.modelHandlers.length; i++) {
                predictions[i] = this.modelHandlers[i].predict(featureValues, this.outputType, this.percentileValue);
            }
        } else if (AlgorithmType.DEEPLEARNING.getValue().equals(this.algorithmClass)) {
            for (int i = 0; i < this.modelHandlers.length; i++) {
                predictions[i] = this.deeplearningWithoutH2O
                        ? this.modelHandlers[i].predict(featureValues, this.outputType, this.pojoPredictor[i])
                        : this.modelHandlers[i].predict(featureValues, this.outputType);
            }
        } else {
            throw new ExecutionPlanRuntimeException(String.format("Error while predicting. Prediction is not supported for the algorithm class %s. ", this.algorithmClass));
        }
        return ObjectUtils.mode(predictions);
    }

    /**
     * Validates the query parameters, loads the models, and checks that they are mutually consistent.
     *
     * @return a single output attribute: "prediction" for anomaly detection, otherwise the models'
     *         response variable, typed according to the second parameter
     * @throws ExecutionPlanValidationException if a required parameter is missing or not a constant
     * @throws ExecutionPlanRuntimeException if the models disagree on algorithm class or features, or a
     *         POJO predictor cannot be created
     * @throws ExecutionPlanCreationException if a model cannot be loaded or the response variables differ
     */
    protected List<Attribute> init(AbstractDefinition abstractDefinition, ExpressionExecutor[] expressionExecutorArr, ExecutionPlanContext executionPlanContext) {
        if (expressionExecutorArr.length < 2) {
            throw new ExecutionPlanValidationException("ML model storage locations and response variable type have not been defined as the first two parameters");
        }
        // Any parameter after the first two selects feature attributes.
        this.attributeSelectionAvailable = expressionExecutorArr.length > 2;
        if (!(expressionExecutorArr[0] instanceof ConstantExpressionExecutor)) {
            throw new ExecutionPlanValidationException("ML model storage-location has not been defined as the first parameter");
        }
        this.modelStorageLocations = ((String) ((ConstantExpressionExecutor) expressionExecutorArr[0]).getValue()).split(",");
        if (!(expressionExecutorArr[1] instanceof ConstantExpressionExecutor)) {
            throw new ExecutionPlanValidationException("Response variable type has not been defined as the second parameter");
        }
        this.outputType = (String) ((ConstantExpressionExecutor) expressionExecutorArr[1]).getValue();
        Attribute.Type outputAttributeType = getOutputAttributeType(this.outputType);

        loadModelHandlers();

        Set<String> algorithmClasses = new HashSet<String>();
        Set<Map<String, Integer>> featureSets = new HashSet<Map<String, Integer>>();
        for (ModelHandler handler : this.modelHandlers) {
            algorithmClasses.add(handler.getAlgorithmClass());
            featureSets.add(handler.getFeatures());
        }
        if (algorithmClasses.size() > 1) {
            throw new ExecutionPlanRuntimeException("Algorithm classes are not equal");
        }
        this.algorithmClass = this.modelHandlers[0].getAlgorithmClass();
        if (featureSets.size() > 1) {
            throw new ExecutionPlanRuntimeException("Features in models are not equal");
        }

        this.isAnomalyDetection = AlgorithmType.ANOMALY_DETECTION.getValue().equals(this.algorithmClass);
        if (this.isAnomalyDetection) {
            // Guard the length first so a missing percentile is reported as a validation error.
            if (expressionExecutorArr.length < 3 || !(expressionExecutorArr[2] instanceof ConstantExpressionExecutor)) {
                throw new ExecutionPlanValidationException("percentile value has not been defined as the third parameter");
            }
            // The percentile occupies the third slot, so feature attributes start after it.
            this.attributeSelectionAvailable = expressionExecutorArr.length > 3;
            // Accept any numeric literal (e.g. 95 as well as 95.0).
            this.percentileValue = ((Number) ((ConstantExpressionExecutor) expressionExecutorArr[2]).getValue()).doubleValue();
            return Arrays.asList(new Attribute(ANOMALY_PREDICTION_ATTRIBUTE, outputAttributeType));
        }

        Set<String> responseVariables = new HashSet<String>();
        for (ModelHandler handler : this.modelHandlers) {
            responseVariables.add(handler.getResponseVariable());
        }
        if (responseVariables.size() > 1) {
            throw new ExecutionPlanCreationException("Response variables of models are not equal");
        }
        this.responseVariable = this.modelHandlers[0].getResponseVariable();

        if (AlgorithmType.DEEPLEARNING.getValue().equals(this.algorithmClass)) {
            initDeepLearningPredictors();
        }
        return Arrays.asList(new Attribute(this.responseVariable, outputAttributeType));
    }

    /** Creates a ModelHandler for each storage location; any load failure aborts via logError. */
    private void loadModelHandlers() {
        this.modelHandlers = new ModelHandler[this.modelStorageLocations.length];
        for (int i = 0; i < this.modelStorageLocations.length; i++) {
            try {
                this.modelHandlers[i] = new ModelHandler(this.modelStorageLocations[i]);
            } catch (MLInputAdapterException e) {
                logError(i, e);
            } catch (IOException e) {
                logError(i, e);
            } catch (ClassNotFoundException e) {
                logError(i, e);
            } catch (URISyntaxException e) {
                logError(i, e);
            }
        }
    }

    /**
     * Decides whether deep-learning predictions go through H2O, and creates POJO predictors when they
     * do not. H2O is used when the first model's getMlModel().getModel() is non-null.
     */
    private void initDeepLearningPredictors() {
        if (this.modelHandlers[0].getMlModel().getModel() != null) {
            this.deeplearningWithoutH2O = false;
            return;
        }
        this.deeplearningWithoutH2O = true;
        this.pojoPredictor = new POJOPredictor[this.modelHandlers.length];
        for (int i = 0; i < this.modelHandlers.length; i++) {
            try {
                this.pojoPredictor[i] = new POJOPredictor(this.modelHandlers[i].getMlModel(), this.modelStorageLocations[i]);
            } catch (MLModelHandlerException e) {
                throw new ExecutionPlanRuntimeException("Failed to initialize the POJO predictor of the model " + this.modelStorageLocations[i], e);
            }
        }
    }

    /**
     * Maps stream attributes to model features.
     *
     * @throws ExecutionPlanCreationException if an attribute has no matching model feature; the original
     *         exception is kept as the cause
     */
    public void start() {
        try {
            populateFeatureAttributeMapping();
        } catch (ExecutionPlanCreationException e) {
            log.error("Error while retrieving ML-models", e);
            throw new ExecutionPlanCreationException("Error while retrieving ML-models\n" + e.getMessage(), e);
        }
    }

    /**
     * Fills attributeIndexMap. With no attribute selection, every input-stream attribute is mapped and read
     * from output data. Otherwise only the variable parameters of the query are mapped, at their own
     * positions.
     */
    private void populateFeatureAttributeMapping() {
        this.attributeIndexMap = new HashMap<Integer, int[]>();
        Map<String, Integer> features = this.modelHandlers[0].getFeatures();
        List<Integer> newToOldIndicesList = this.modelHandlers[0].getNewToOldIndicesList();
        if (!this.attributeSelectionAvailable) {
            for (String attributeName : this.inputDefinition.getAttributeNameArray()) {
                requireFeature(attributeName, features);
                int[] position = new int[]{0, 0, OUTPUT_DATA, this.inputDefinition.getAttributePosition(attributeName)};
                mapFeature(attributeName, features, newToOldIndicesList, position);
            }
            return;
        }
        for (ExpressionExecutor executor : this.attributeExpressionExecutors) {
            // Constant parameters (model locations, output type, percentile) are not features.
            if (executor instanceof VariableExpressionExecutor) {
                VariableExpressionExecutor variable = (VariableExpressionExecutor) executor;
                String attributeName = variable.getAttribute().getName();
                requireFeature(attributeName, features);
                mapFeature(attributeName, features, newToOldIndicesList, variable.getPosition());
            }
        }
    }

    /** Fails when the attribute name is not a feature of the models. */
    private void requireFeature(String attributeName, Map<String, Integer> features) {
        if (features.get(attributeName) == null) {
            throw new ExecutionPlanCreationException("No matching feature name found in the models for the attribute : " + attributeName);
        }
    }

    /**
     * Records which stream position supplies a feature. The feature's index in the model input is its
     * position within newToOldIndices.
     */
    private void mapFeature(String attributeName, Map<String, Integer> features, List<Integer> newToOldIndices, int[] position) {
        int featureIndex = newToOldIndices.indexOf(features.get(attributeName));
        if (featureIndex < 0) {
            // Without this check every later event would fail with an out-of-bounds index.
            throw new ExecutionPlanCreationException("No feature index found in the models for the attribute : " + attributeName);
        }
        this.attributeIndexMap.put(Integer.valueOf(featureIndex), position);
    }

    /**
     * Maps the user-supplied output type name (case-insensitive) to a Siddhi attribute type.
     *
     * @throws ExecutionPlanValidationException for an unrecognised type name
     */
    private Attribute.Type getOutputAttributeType(String str) {
        if (str.equalsIgnoreCase("double")) {
            return Attribute.Type.DOUBLE;
        }
        if (str.equalsIgnoreCase("float")) {
            return Attribute.Type.FLOAT;
        }
        if (str.equalsIgnoreCase("integer") || str.equalsIgnoreCase("int")) {
            return Attribute.Type.INT;
        }
        if (str.equalsIgnoreCase("long")) {
            return Attribute.Type.LONG;
        }
        if (str.equalsIgnoreCase("string")) {
            return Attribute.Type.STRING;
        }
        if (str.equalsIgnoreCase("boolean") || str.equalsIgnoreCase("bool")) {
            return Attribute.Type.BOOL;
        }
        throw new ExecutionPlanValidationException("Invalid data-type defined for response variable.");
    }

    /**
     * Logs a model-loading failure and aborts initialisation. This method always throws.
     *
     * @param i index of the failing model in modelStorageLocations
     * @throws ExecutionPlanCreationException always, with {@code exc} as its cause
     */
    private void logError(int i, Exception exc) {
        log.error("Error while retrieving ML-model : " + this.modelStorageLocations[i], exc);
        throw new ExecutionPlanCreationException("Error while retrieving ML-model : " + this.modelStorageLocations[i] + "\n" + exc.getMessage(), exc);
    }

    /** No resources to release. */
    public void stop() {
    }

    /** This processor keeps no snapshot state. */
    public Object[] currentState() {
        return new Object[0];
    }

    /** This processor keeps no snapshot state, so there is nothing to restore. */
    public void restoreState(Object[] objArr) {
    }
}
