package org.wso2.extension.siddhi.gpl.execution.pmml;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.OutputField;
import org.jpmml.evaluator.TargetField;
import org.wso2.extension.siddhi.gpl.execution.pmml.util.PMMLUtil;
import org.wso2.siddhi.annotation.Example;
import org.wso2.siddhi.annotation.Extension;
import org.wso2.siddhi.annotation.Parameter;
import org.wso2.siddhi.annotation.ReturnAttribute;
import org.wso2.siddhi.annotation.util.DataType;
import org.wso2.siddhi.core.config.SiddhiAppContext;
import org.wso2.siddhi.core.event.ComplexEventChunk;
import org.wso2.siddhi.core.event.stream.StreamEvent;
import org.wso2.siddhi.core.event.stream.StreamEventCloner;
import org.wso2.siddhi.core.event.stream.populater.ComplexEventPopulater;
import org.wso2.siddhi.core.exception.SiddhiAppCreationException;
import org.wso2.siddhi.core.exception.SiddhiAppRuntimeException;
import org.wso2.siddhi.core.executor.ConstantExpressionExecutor;
import org.wso2.siddhi.core.executor.ExpressionExecutor;
import org.wso2.siddhi.core.executor.VariableExpressionExecutor;
import org.wso2.siddhi.core.query.processor.Processor;
import org.wso2.siddhi.core.query.processor.stream.StreamProcessor;
import org.wso2.siddhi.core.util.config.ConfigReader;
import org.wso2.siddhi.query.api.definition.AbstractDefinition;
import org.wso2.siddhi.query.api.definition.Attribute;
import org.wso2.siddhi.query.api.exception.SiddhiAppValidationException;

@Extension(
        name = "predict",
        namespace = "pmml",
        description = "This extension processes the input stream attributes according to the defined PMML standard model and outputs the processed results together with the input stream attributes.",
        parameters = {
                @Parameter(
                        name = "path.to.pmml.file",
                        description = "The path to the PMML model file.\n",
                        type = {DataType.STRING}),
                @Parameter(
                        name = "input",
                        description = "An attribute of the input stream that is sent to the PMML standard model as a value to based on which the prediction is made. The predict function does not accept any constant values as input parameters. You can have multiple input parameters according to the input stream definition.",
                        type = {DataType.STRING},
                        optional = true,
                        defaultValue = "Empty Array")
        },
        returnAttributes = {
                @ReturnAttribute(
                        name = "output",
                        description = "All the processed outputs defined in the query. The number of outputs can vary depending on the query definition.",
                        type = {DataType.STRING, DataType.INT, DataType.DOUBLE, DataType.FLOAT, DataType.BOOL})
        },
        examples = {
                @Example(
                        syntax = "predict('<SP HOME>/samples/artifacts/0301/decision-tree.pmml', root_shell, su_attempted, num_root, num_file_creations, num_shells, num_access_files, num_outbound_cmds, is_host_login, is_guest_login , count, srv_count, serror_rate, srv_serror_rate)",
                        description = "This model is implemented to detect network intruders. The input event stream is processed by the execution plan that uses the pmml predictive model to detect whether a particular user is an intruder to the network or not. The output stream contains the processed query results that include the predicted responses.")
        })
/**
 * Stream processor that evaluates a PMML model against every event in a chunk and appends the
 * model's results to the event.
 *
 * <p>The first query parameter is a constant string. It is handed to {@code PMMLUtil.unmarshal}
 * to load the model. The model's features are then fed in one of two ways:
 * <ul>
 *   <li>If further parameters are given, they must be stream attributes. Each attribute is fed to
 *       the model feature with the same name.</li>
 *   <li>Otherwise every attribute of the input stream is fed to the model feature with the same
 *       name.</li>
 * </ul>
 * In both cases an attribute without a matching model feature fails at {@link #start()}.
 *
 * <p>The appended attributes are the model's output fields, or its target fields when it declares
 * no output fields. They appear in the model's declaration order.
 */
public class PmmlModelProcessor extends StreamProcessor {

    private static final Logger logger = Logger.getLogger(PmmlModelProcessor.class);

    // Layout of the int[] returned by VariableExpressionExecutor#getPosition() (mirrors Siddhi's
    // SiddhiConstants): element 2 selects the StreamEvent data array, element 3 is the index within it.
    private static final int ATTRIBUTE_TYPE_INDEX = 2;
    private static final int ATTRIBUTE_INDEX_IN_TYPE = 3;
    private static final int BEFORE_WINDOW_DATA = 0;
    private static final int ON_AFTER_WINDOW_DATA = 1;
    private static final int OUTPUT_DATA = 2;

    /** Value of the first query parameter, as passed to {@code PMMLUtil.unmarshal}. */
    private String pmmlDefinition;
    /** True when the query lists explicit input attributes after the model parameter. */
    private boolean attributeSelectionAvailable;
    /** Model feature -> position of the stream attribute that feeds it; built in {@link #start()}. */
    private Map<InputField, int[]> attributeIndexMap;
    /** Active (input) fields of the loaded model. */
    private List<InputField> inputFields;
    /** Attributes appended to each event, in model declaration order. */
    private final Map<FieldName, org.dmg.pmml.DataType> outputFields = new LinkedHashMap<>();
    private Evaluator evaluator;

    /**
     * Validates the query parameters, loads the PMML model, and derives the output attributes.
     *
     * @return the attributes appended to each event, one per model output (or target) field
     * @throws SiddhiAppValidationException if the model parameter is missing, is not a constant
     *         string, or a later parameter is not a stream attribute
     */
    @Override
    protected List<Attribute> init(AbstractDefinition abstractDefinition, ExpressionExecutor[] expressionExecutorArr,
                                   ConfigReader configReader, SiddhiAppContext siddhiAppContext) {
        if (attributeExpressionExecutors.length == 0) {
            throw new SiddhiAppValidationException("PMML model definition not available.");
        }
        ExpressionExecutor modelExecutor = attributeExpressionExecutors[0];
        if (!(modelExecutor instanceof ConstantExpressionExecutor)) {
            throw new SiddhiAppValidationException("PMML model definition has not been set as the first parameter");
        }
        if (modelExecutor.getReturnType() != Attribute.Type.STRING) {
            throw new SiddhiAppValidationException("PMML model definition (first parameter) must be a string, but found "
                    + modelExecutor.getReturnType());
        }
        // Documented contract: model inputs must be stream attributes. Reject anything else here
        // instead of silently ignoring it later.
        for (int i = 1; i < attributeExpressionExecutors.length; i++) {
            if (!(attributeExpressionExecutors[i] instanceof VariableExpressionExecutor)) {
                throw new SiddhiAppValidationException("PMML model input parameter " + (i + 1)
                        + " must be a stream attribute; constant values and expressions are not supported");
            }
        }
        attributeSelectionAvailable = attributeExpressionExecutors.length > 1;
        pmmlDefinition = (String) ((ConstantExpressionExecutor) modelExecutor).getValue();

        evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(PMMLUtil.unmarshal(pmmlDefinition));
        inputFields = evaluator.getActiveFields();

        outputFields.clear();
        if (evaluator.getOutputFields().isEmpty()) {
            // The model declares no explicit outputs, so expose its target fields instead.
            for (TargetField targetField : evaluator.getTargetFields()) {
                outputFields.put(targetField.getName(), targetField.getDataType());
            }
        } else {
            for (OutputField outputField : evaluator.getOutputFields()) {
                outputFields.put(outputField.getName(), outputField.getDataType());
            }
        }
        return generateOutputAttributes();
    }

    /**
     * Evaluates the model for every event in the chunk. It appends the results to each event and
     * then forwards the chunk to the next processor.
     *
     * <p>The chunk is dropped when no model feature is mapped to an attribute, and also when the
     * chunk is empty.
     *
     * @throws SiddhiAppRuntimeException if the model evaluation fails for any event
     */
    @Override
    protected void process(ComplexEventChunk<StreamEvent> complexEventChunk, Processor processor,
                           StreamEventCloner streamEventCloner, ComplexEventPopulater complexEventPopulater) {
        if (attributeIndexMap == null || attributeIndexMap.isEmpty() || complexEventChunk.getFirst() == null) {
            return;
        }
        try {
            complexEventChunk.reset();
            while (complexEventChunk.hasNext()) {
                StreamEvent event = complexEventChunk.next();
                Map<FieldName, ?> result = evaluator.evaluate(buildModelInput(event));
                complexEventPopulater.populateComplexEvent(event, buildOutputData(result));
            }
        } catch (Exception e) {
            logger.error("Error while predicting", e);
            throw new SiddhiAppRuntimeException("Error while predicting", e);
        }
        complexEventChunk.reset();
        // Forward outside the try so downstream failures are not reported as prediction errors.
        processor.process(complexEventChunk);
    }

    /** Reads the mapped attributes of {@code event} and prepares them as model arguments. */
    private Map<FieldName, Object> buildModelInput(StreamEvent event) {
        Map<FieldName, Object> input = new HashMap<>();
        for (Map.Entry<InputField, int[]> entry : attributeIndexMap.entrySet()) {
            InputField field = entry.getKey();
            Object value = readAttribute(event, entry.getValue());
            // Pass a null attribute as null rather than as the literal string "null".
            // NOTE(review): this assumes JPMML treats a null argument as a missing value — confirm
            // against the jpmml-evaluator version in use.
            input.put(field.getName(), field.prepare(value == null ? null : String.valueOf(value)));
        }
        return input;
    }

    /** Returns the attribute at {@code position} in {@code event}, or null for an unknown data-array type. */
    private static Object readAttribute(StreamEvent event, int[] position) {
        int index = position[ATTRIBUTE_INDEX_IN_TYPE];
        switch (position[ATTRIBUTE_TYPE_INDEX]) {
            case BEFORE_WINDOW_DATA:
                return event.getBeforeWindowData()[index];
            case ON_AFTER_WINDOW_DATA:
                return event.getOnAfterWindowData()[index];
            case OUTPUT_DATA:
                return event.getOutputData()[index];
            default:
                return null;
        }
    }

    /**
     * Builds the appended-attribute array. It holds one slot per output field, in declaration
     * order; a field that is absent from {@code result} is set to null.
     */
    private Object[] buildOutputData(Map<FieldName, ?> result) {
        Object[] output = new Object[outputFields.size()];
        int i = 0;
        for (FieldName fieldName : outputFields.keySet()) {
            // Always advance the slot index so a missing field cannot shift later values.
            output[i] = result.containsKey(fieldName) ? EvaluatorUtil.decode(result.get(fieldName)) : null;
            i++;
        }
        return output;
    }

    /**
     * Maps model features to stream attribute positions.
     *
     * @throws SiddhiAppCreationException if an attribute has no matching model feature
     */
    @Override
    public void start() {
        try {
            populateFeatureAttributeMapping();
        } catch (RuntimeException e) {
            String message = "Error while mapping attributes with pmml model features : " + pmmlDefinition;
            logger.error(message, e);
            throw new SiddhiAppCreationException(message + "\n" + e.getMessage(), e);
        }
    }

    /**
     * Builds {@link #attributeIndexMap}. It uses the explicitly selected attributes when there are
     * any; otherwise it uses all input-stream attributes, read from the event's output data.
     */
    private void populateFeatureAttributeMapping() {
        Map<String, InputField> featuresByName = new HashMap<>();
        for (InputField inputField : inputFields) {
            featuresByName.put(inputField.getName().getValue(), inputField);
        }

        Map<InputField, int[]> mapping = new HashMap<>();
        if (attributeSelectionAvailable) {
            for (ExpressionExecutor executor : attributeExpressionExecutors) {
                // The first (constant) model parameter is skipped by this check.
                if (executor instanceof VariableExpressionExecutor) {
                    VariableExpressionExecutor variable = (VariableExpressionExecutor) executor;
                    mapping.put(findFeature(featuresByName, variable.getAttribute().getName()),
                            variable.getPosition());
                }
            }
        } else {
            for (String attributeName : inputDefinition.getAttributeNameArray()) {
                mapping.put(findFeature(featuresByName, attributeName),
                        new int[]{0, 0, OUTPUT_DATA, inputDefinition.getAttributePosition(attributeName)});
            }
        }
        attributeIndexMap = mapping;
    }

    /** Looks up the model feature named {@code attributeName}, failing if there is none. */
    private static InputField findFeature(Map<String, InputField> featuresByName, String attributeName) {
        InputField feature = featuresByName.get(attributeName);
        if (feature == null) {
            throw new SiddhiAppCreationException(
                    "No matching feature name found in the model for the attribute : " + attributeName);
        }
        return feature;
    }

    /** Converts {@link #outputFields} to Siddhi attributes; a missing PMML type is treated as STRING. */
    private List<Attribute> generateOutputAttributes() {
        List<Attribute> attributes = new ArrayList<>(outputFields.size());
        for (Map.Entry<FieldName, org.dmg.pmml.DataType> entry : outputFields.entrySet()) {
            org.dmg.pmml.DataType dataType = entry.getValue();
            if (dataType == null) {
                dataType = org.dmg.pmml.DataType.STRING;
            }
            attributes.add(new Attribute(entry.getKey().getValue(), mapOutputAttributes(dataType)));
        }
        return attributes;
    }

    /**
     * Maps a PMML data type to a Siddhi attribute type.
     *
     * <p>Types without a direct Siddhi counterpart (for example, the date/time types) map to
     * {@code OBJECT}. Previously they mapped to {@code null}, which produced an attribute with no type.
     */
    private Attribute.Type mapOutputAttributes(org.dmg.pmml.DataType dataType) {
        switch (dataType) {
            case DOUBLE:
                return Attribute.Type.DOUBLE;
            case FLOAT:
                return Attribute.Type.FLOAT;
            case INTEGER:
                return Attribute.Type.INT;
            case BOOLEAN:
                return Attribute.Type.BOOL;
            case STRING:
                return Attribute.Type.STRING;
            default:
                return Attribute.Type.OBJECT;
        }
    }

    @Override
    public void stop() {
        // Nothing to release.
    }

    /** This processor is stateless; the snapshot is always empty. */
    @Override
    public Map<String, Object> currentState() {
        return new HashMap<>();
    }

    @Override
    public void restoreState(Map<String, Object> map) {
        // Stateless: nothing to restore.
    }
}
