package org.wso2.extension.siddhi.execution.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.wso2.extension.siddhi.execution.tensorflow.util.CoreUtils;
import org.wso2.siddhi.annotation.Example;
import org.wso2.siddhi.annotation.Extension;
import org.wso2.siddhi.annotation.Parameter;
import org.wso2.siddhi.annotation.ReturnAttribute;
import org.wso2.siddhi.annotation.util.DataType;
import org.wso2.siddhi.core.config.SiddhiAppContext;
import org.wso2.siddhi.core.event.ComplexEventChunk;
import org.wso2.siddhi.core.event.stream.StreamEvent;
import org.wso2.siddhi.core.event.stream.StreamEventCloner;
import org.wso2.siddhi.core.event.stream.populater.ComplexEventPopulater;
import org.wso2.siddhi.core.exception.SiddhiAppCreationException;
import org.wso2.siddhi.core.executor.ConstantExpressionExecutor;
import org.wso2.siddhi.core.executor.ExpressionExecutor;
import org.wso2.siddhi.core.executor.VariableExpressionExecutor;
import org.wso2.siddhi.core.query.processor.Processor;
import org.wso2.siddhi.core.query.processor.stream.StreamProcessor;
import org.wso2.siddhi.core.util.config.ConfigReader;
import org.wso2.siddhi.query.api.definition.AbstractDefinition;
import org.wso2.siddhi.query.api.definition.Attribute;

@Extension(
        name = "predict",
        namespace = "tensorFlow",
        description = "Performs inferences (prediction) from an already built TensorFlow machine learning model. "
                + "The types of models are unlimited (including image classifiers, deep learning models) as long "
                + "as they satisfy the following conditions.\n"
                + "1. They are saved with the tag 'serve' in SavedModel format "
                + "(See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md)\n"
                + "2. Model is initially trained and ready for inferences\n"
                + "3. Inference logic is written and saved in the model\n"
                + "4. signature_def is properly included in the metaGraphDef (a protocol buffer file which has "
                + "information about the graph) and the key for prediction signature def is 'serving_default'\n"
                + "\n"
                + "Also the prerequisites for inference are as follows.\n"
                + "1. User knows the names of the input and output nodes\n"
                + "2. Has a preprocessed data set of Java primitive types or their multidimensional arrays\n"
                + "\n"
                + "Since each input is directly used to create a Tensor they should be of compatible shape and data "
                + "type with the model.\n"
                + "The information related to input and output nodes can be retrieved from saved model signature "
                + "def. signature_def can be read by using the saved_model_cli commands found at "
                + "https://www.tensorflow.org/programmers_guide/saved_model\n"
                + "signature_def can be read in Python as follows\n"
                + "with tf.Session() as sess:\n"
                + "  md = tf.saved_model.loader.load(sess, ['serve'], export_dir)\n"
                + "  sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n"
                + "  print(sig)\n"
                + "\n"
                + "Or you can read signature def from Java as follows,\n"
                + "final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = \"serving_default\"; \n"
                + "\n"
                + "final SignatureDef sig =\n"
                + "      MetaGraphDef.parseFrom(model.metaGraphDef())\n"
                + "          .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);\n"
                + "\n"
                + "You will have to import the following in Java.\n"
                + "import org.tensorflow.framework.MetaGraphDef;\n"
                + "import org.tensorflow.framework.SignatureDef;",
        parameters = {
                @Parameter(name = "absolute.path.to.model",
                        description = "This is the absolute path to the model folder in the local machine.",
                        type = {DataType.STRING}),
                @Parameter(name = "input.node.names",
                        description = "This is a variable length parameter. The names of the input nodes as comma "
                                + "separated strings.",
                        type = {DataType.STRING}),
                @Parameter(name = "output.node.names",
                        description = "This is a variable length parameter. The names of the output nodes as comma "
                                + "separated strings.",
                        type = {DataType.STRING}),
                @Parameter(name = "attributes",
                        description = "This is a variable length parameter. These are the attributes coming with "
                                + "events. Note that arrays should be cast to objects and sent.",
                        type = {DataType.INT, DataType.STRING, DataType.DOUBLE, DataType.LONG, DataType.FLOAT,
                                DataType.BOOL, DataType.OBJECT})
        },
        returnAttributes = {
                @ReturnAttribute(name = "outputs",
                        description = "This is a variable length return attribute. The output tensors from the "
                                + "inference will be flattened out and sent in their primitive values. User is "
                                + "expected to know the shape of the output tensors if he/she wishes to reconstruct "
                                + "it. The shape and data type information can be retrieved from TensorFlow saved "
                                + "model signature_def. See the description of this extension for instructions on "
                                + "how to read signature_def",
                        type = {DataType.INT, DataType.STRING, DataType.DOUBLE, DataType.LONG, DataType.FLOAT,
                                DataType.BOOL})
        },
        examples = {
                @Example(
                        syntax = "define stream InputStream (x Object, y Object);\n"
                                + "@info(name = 'query1') \n"
                                + "from InputStream#tensorFlow:predict('home/MNIST', 'inputPoint', 'dropout', "
                                + "'outputPoint', x, y) \n"
                                + "select outputPoint0, outputPoint1, outputPoint2, outputPoint3, outputPoint4, "
                                + "outputPoint5, outputPoint6, outputPoint7, outputPoint8, outputPoint9 \n"
                                + "insert into OutputStream;\n",
                        description = "This is a query to get inferences from a MNIST model. This model takes in 2 "
                                + "inputs. One being the image as float array and other is keep probability array "
                                + "and sends out a Tensor with 10 elements. Our stream processor flattens the tensor "
                                + "and sends 10 floats each representing the probability of image being 0,1,...,9")
        })
/* loaded from: input_file:org/wso2/extension/siddhi/execution/tensorflow/TensorFlowExtension.class */
public class TensorFlowExtension extends StreamProcessor {
    private static final Logger logger = Logger.getLogger(TensorFlowExtension.class);
    private String[] inputVariableNamesArray;
    private String[] outputVariableNamesArray;
    private int noOfInputs;
    private int noOfOutputs;
    private VariableExpressionExecutor[] inputVariableExpressionExecutors;
    private Session tensorFlowSession;
    private SignatureDef signatureDef;

    protected List<Attribute> init(AbstractDefinition abstractDefinition, ExpressionExecutor[] expressionExecutorArr, ConfigReader configReader, SiddhiAppContext siddhiAppContext) {
        if (this.attributeExpressionLength < 3) {
            String str = "Insufficient number of parameters. Query should have at least 5 constant parameters and appropriate number of variable parameters but given " + this.attributeExpressionLength;
            logger.error(siddhiAppContext.getName() + str);
            throw new SiddhiAppCreationException(str);
        }
        if (!(this.attributeExpressionExecutors[0] instanceof ConstantExpressionExecutor)) {
            throw new SiddhiAppCreationException("1st query parameter is the absolute path to model which has to be constant but found " + this.attributeExpressionExecutors[0].getClass().getCanonicalName());
        }
        if (this.attributeExpressionExecutors[0].getReturnType() != Attribute.Type.STRING) {
            throw new SiddhiAppCreationException("1st query parameter is the absolute path to model which has to be of type String but found " + this.attributeExpressionExecutors[0].getReturnType());
        }
        SavedModelBundle load = SavedModelBundle.load((String) this.attributeExpressionExecutors[0].getValue(), "serve");
        this.tensorFlowSession = load.session();
        try {
            this.signatureDef = MetaGraphDef.parseFrom(load.metaGraphDef()).getSignatureDefOrThrow("serving_default");
            this.noOfInputs = this.signatureDef.getInputsCount();
            this.noOfOutputs = this.signatureDef.getOutputsCount();
            int i = 1 + (2 * this.noOfInputs) + this.noOfOutputs;
            if (this.attributeExpressionLength != i) {
                String str2 = "Invalid number of query parameters. Number of inputs and number of outputs are specified as " + this.noOfInputs + " and " + this.noOfOutputs + " respectively. So the total number of query parameters should be " + i + " but " + this.attributeExpressionLength + " given.";
                logger.error(siddhiAppContext.getName() + str2);
                throw new SiddhiAppCreationException(str2);
            }
            this.inputVariableNamesArray = new String[this.noOfInputs];
            this.outputVariableNamesArray = new String[this.noOfOutputs];
            for (int i2 = 0; i2 < this.noOfInputs; i2++) {
                int i3 = i2 + 1;
                if (!(this.attributeExpressionExecutors[i3] instanceof ConstantExpressionExecutor)) {
                    throw new SiddhiAppCreationException("The query parameter of index " + (i3 + 1) + " is a input name which has to be a constant but found " + this.attributeExpressionExecutors[i3].getClass().getCanonicalName());
                }
                if (this.attributeExpressionExecutors[i3].getReturnType() != Attribute.Type.STRING) {
                    throw new SiddhiAppCreationException("The query parameter of index " + (i3 + 1) + " is a input name which has to be a String but found " + this.attributeExpressionExecutors[i3].getReturnType());
                }
                this.inputVariableNamesArray[i2] = (String) this.attributeExpressionExecutors[i3].getValue();
            }
            for (int i4 = 0; i4 < this.noOfOutputs; i4++) {
                int i5 = i4 + 1 + this.noOfInputs;
                if (!(this.attributeExpressionExecutors[i5] instanceof ConstantExpressionExecutor)) {
                    throw new SiddhiAppCreationException("The query parameter of index " + (i5 + 1) + " is a output name which has to be a constant but found " + this.attributeExpressionExecutors[i5].getClass().getCanonicalName());
                }
                if (this.attributeExpressionExecutors[i5].getReturnType() != Attribute.Type.STRING) {
                    throw new SiddhiAppCreationException("The query parameter of index " + (i5 + 1) + " is a output name which has to be a String but found " + this.attributeExpressionExecutors[i5].getReturnType());
                }
                this.outputVariableNamesArray[i4] = (String) this.attributeExpressionExecutors[i5].getValue();
            }
            for (String str3 : this.inputVariableNamesArray) {
                if (!this.signatureDef.getInputsMap().containsKey(str3)) {
                    throw new SiddhiAppCreationException(str3 + " not present in the signature def. Please check the input node names");
                }
            }
            for (String str4 : this.outputVariableNamesArray) {
                if (!this.signatureDef.getOutputsMap().containsKey(str4)) {
                    throw new SiddhiAppCreationException(str4 + " not present in the signature def. Please check the output node names.");
                }
            }
            this.inputVariableExpressionExecutors = CoreUtils.extractAndValidateTensorFlowInputs(this.attributeExpressionExecutors, 1 + this.noOfInputs + this.noOfOutputs, this.noOfInputs);
            return CoreUtils.getReturnAttributeList(this.signatureDef, this.noOfOutputs, load, this.outputVariableNamesArray);
        } catch (InvalidProtocolBufferException e) {
            throw new SiddhiAppCreationException("Error while reading signature def." + e.getMessage(), e);
        }
    }

    protected void process(ComplexEventChunk<StreamEvent> complexEventChunk, Processor processor, StreamEventCloner streamEventCloner, ComplexEventPopulater complexEventPopulater) {
        while (complexEventChunk.hasNext()) {
            StreamEvent next = complexEventChunk.next();
            Session.Runner runner = this.tensorFlowSession.runner();
            LinkedList linkedList = new LinkedList();
            for (int i = 0; i < this.noOfInputs; i++) {
                try {
                    Tensor createTensor = CoreUtils.createTensor((String) this.inputVariableExpressionExecutors[i].execute(next));
                    linkedList.add(createTensor);
                    runner = runner.feed(this.signatureDef.getInputsMap().get(this.inputVariableNamesArray[i]).getName(), createTensor);
                } catch (Throwable th) {
                    logger.error("Error while feeding input " + this.inputVariableNamesArray[i] + ". " + th.getMessage());
                }
            }
            for (int i2 = 0; i2 < this.noOfOutputs; i2++) {
                runner = runner.fetch(this.signatureDef.getOutputsMap().get(this.outputVariableNamesArray[i2]).getName());
            }
            List<Tensor> run = runner.run();
            Iterator it = linkedList.iterator();
            while (it.hasNext()) {
                ((Tensor) it.next()).close();
            }
            complexEventPopulater.populateComplexEvent(next, CoreUtils.getOutputObjectArray(run));
        }
        this.nextProcessor.process(complexEventChunk);
    }

    public void start() {
    }

    public void stop() {
    }

    public Map<String, Object> currentState() {
        return null;
    }

    public void restoreState(Map<String, Object> map) {
    }
}
