PerceptronClassifierUpdaterStreamProcessorExtension.java

/*
 * Copyright (c) 2017, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
 *
 * WSO2 Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.wso2.extension.siddhi.execution.streamingml.classification.perceptron;

import org.apache.log4j.Logger;
import org.wso2.extension.siddhi.execution.streamingml.classification.perceptron.util.PerceptronModel;
import org.wso2.extension.siddhi.execution.streamingml.classification.perceptron.util.PerceptronModelsHolder;
import org.wso2.extension.siddhi.execution.streamingml.util.CoreUtils;
import org.wso2.siddhi.annotation.Example;
import org.wso2.siddhi.annotation.Extension;
import org.wso2.siddhi.annotation.Parameter;
import org.wso2.siddhi.annotation.ReturnAttribute;
import org.wso2.siddhi.annotation.util.DataType;
import org.wso2.siddhi.core.config.SiddhiAppContext;
import org.wso2.siddhi.core.event.ComplexEventChunk;
import org.wso2.siddhi.core.event.stream.StreamEvent;
import org.wso2.siddhi.core.event.stream.StreamEventCloner;
import org.wso2.siddhi.core.event.stream.populater.ComplexEventPopulater;
import org.wso2.siddhi.core.exception.SiddhiAppCreationException;
import org.wso2.siddhi.core.exception.SiddhiAppRuntimeException;
import org.wso2.siddhi.core.executor.ConstantExpressionExecutor;
import org.wso2.siddhi.core.executor.ExpressionExecutor;
import org.wso2.siddhi.core.executor.VariableExpressionExecutor;
import org.wso2.siddhi.core.query.processor.Processor;
import org.wso2.siddhi.core.query.processor.stream.StreamProcessor;
import org.wso2.siddhi.core.util.config.ConfigReader;
import org.wso2.siddhi.query.api.definition.AbstractDefinition;
import org.wso2.siddhi.query.api.definition.Attribute;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Build or update a linear binary classification Perceptron model and emit the weights of the features in the order of
 * the attributes.
 */
@Extension(
        name = "updatePerceptronClassifier",
        namespace = "streamingml",
        description = "This extension builds/updates a linear binary classification Perceptron model.",
        parameters = {
                @Parameter(name = "model.name",
                        description = "The name of the model to be built/updated.",
                        type = {DataType.STRING}),
                @Parameter(name = "model.label",
                        description = "The attribute of the label or the class of the dataset.",
                        type = {DataType.BOOL, DataType.STRING}),
                @Parameter(name = "learning.rate",
                        description = "The learning rate of the Perceptron algorithm.",
                        type = {DataType.DOUBLE}, optional = true, defaultValue = "0.1"),
                @Parameter(name = "model.features",
                        description = "Features of the model that need to be attributes of the stream.",
                        type = {DataType.DOUBLE, DataType.INT})},
        returnAttributes = {
                @ReturnAttribute(name = "featureWeight", description = "Weight of the <feature" +
                        ".name> of the " + "model.", type = {DataType.DOUBLE})},
        examples = {
                @Example(syntax = "define stream StreamA (attribute_0 double, attribute_1 double, attribute_2 double," +
                        " attribute_3 double, attribute_4 " + "string );\n\n" +
                        "from StreamA#streamingml:updatePerceptronClassifier('model1', attribute_4, 0.01, " +
                        "attribute_0, attribute_1, attribute_2, attribute_3) \n" +
                        "insert all events into outputStream;",
                        description = "This query builds/updates a Perceptron model named `model1` with a `0.01` " +
                                "learning rate using `attribute_0`, `attribute_1`, `attribute_2`, and `attribute_3` " +
                                "as features, and `attribute_4` as the label. Updated weights of the model " +
                                "are emitted to the OutputStream stream."),
                @Example(syntax = "define stream StreamA (attribute_0 double, attribute_1 double, attribute_2 double," +
                        "attribute_3 double, attribute_4 string );\n\n " +
                        "from StreamA#streamingml:updatePerceptronClassifier('model1', attribute_4, attribute_0, " +
                        "attribute_1, attribute_2, attribute_3) \n" +
                        "insert all events into outputStream;",
                        description = "This query builds/updates a Perceptron model named `model1` with a default" +
                                " `0.1` learning rate using `attribute_0`, `attribute_1`, `attribute_2`, and " +
                                "`attribute_3` as features, and `attribute_4` as the label. The updated weights of " +
                                "the model are appended to the outputStream.")
        }
)
public class PerceptronClassifierUpdaterStreamProcessorExtension extends StreamProcessor {

    private static Logger logger = Logger.getLogger(PerceptronClassifierUpdaterStreamProcessorExtension.class);
    private String modelName;
    private int numberOfFeatures;
    private VariableExpressionExecutor labelVariableExpressionExecutor;
    private List<VariableExpressionExecutor> featureVariableExpressionExecutors = new ArrayList<>();

    @Override
    protected List<Attribute> init(AbstractDefinition inputDefinition, ExpressionExecutor[]
            attributeExpressionExecutors, ConfigReader configReader, SiddhiAppContext siddhiAppContext) {
        String siddhiAppName = siddhiAppContext.getName();
        PerceptronModel model;
        String modelPrefix;
        double learningRate = -1;
        // maxNumberOfFeatures = number of attributes - label attribute
        int maxNumberOfFeatures = inputDefinition.getAttributeList().size() - 1;

        if (attributeExpressionLength >= 3) {
            if (attributeExpressionLength > 3 + maxNumberOfFeatures) {
                throw new SiddhiAppCreationException(String.format("Invalid number of parameters for " +
                        "streamingml:updatePerceptronClassifier. This Stream Processor requires at most %s " +
                        "parameters, namely, model.name, model.label, learning.rate, model.features but found %s " +
                        "parameters", 3 + maxNumberOfFeatures, attributeExpressionLength));
            }
            if (attributeExpressionExecutors[0] instanceof ConstantExpressionExecutor) {
                if (attributeExpressionExecutors[0].getReturnType() == Attribute.Type.STRING) {
                    modelPrefix = (String) ((ConstantExpressionExecutor) attributeExpressionExecutors[0]).getValue();
                    // model name = user given name + siddhi app name
                    modelName = modelPrefix + "." + siddhiAppName;
                } else {
                    throw new SiddhiAppCreationException("Invalid parameter type found for the model.name argument," +
                            " required " + Attribute.Type.STRING + " but found " + attributeExpressionExecutors[0].
                            getReturnType().toString());
                }
            } else {
                throw new SiddhiAppCreationException("Parameter model.name must be a constant" + " but found " +
                        attributeExpressionExecutors[0].getClass().getCanonicalName());
            }

            if (this.attributeExpressionExecutors[1] instanceof VariableExpressionExecutor) {
                labelVariableExpressionExecutor = (VariableExpressionExecutor) this.attributeExpressionExecutors[1];
                // label attribute should be bool or string types
                Attribute.Type labelAttributeType = inputDefinition.getAttributeType(labelVariableExpressionExecutor
                        .getAttribute().getName());
                if (!CoreUtils.isLabelType(labelAttributeType)) {
                    throw new SiddhiAppCreationException(String.format("[model.label] %s in " +
                                    "updatePerceptronClassifier should be either a %s or a %s (true/false). But found"
                                    + " %s", labelVariableExpressionExecutor.getAttribute().getName(),
                            Attribute.Type.BOOL, Attribute.Type.STRING, labelAttributeType.name()));
                }
            } else {
                throw new SiddhiAppCreationException("model.label attribute in updatePerceptronClassifier should "
                        + "be a variable, but found a " + this.attributeExpressionExecutors[1].getClass()
                        .getCanonicalName());
            }

            if (attributeExpressionExecutors[2] instanceof ConstantExpressionExecutor) {
                // learning rate
                if (attributeExpressionExecutors[2].getReturnType() == Attribute.Type.DOUBLE) {
                    learningRate = (double) ((ConstantExpressionExecutor) attributeExpressionExecutors[2])
                            .getValue();
                } else {
                    throw new SiddhiAppCreationException("Invalid parameter type found for the learning.rate " +
                            "argument. Expected: " + Attribute.Type.DOUBLE + " but found: " +
                            attributeExpressionExecutors[2].getReturnType().toString());
                }

                // set number of features
                numberOfFeatures = attributeExpressionLength - 3;
                // feature values
                featureVariableExpressionExecutors = CoreUtils.extractAndValidateFeatures(inputDefinition,
                        attributeExpressionExecutors, 3, numberOfFeatures);
            } else if (attributeExpressionExecutors[2] instanceof VariableExpressionExecutor) {
                // set number of features
                numberOfFeatures = attributeExpressionLength - 2;
                // feature values
                featureVariableExpressionExecutors = CoreUtils.extractAndValidateFeatures(inputDefinition,
                        attributeExpressionExecutors, 2, numberOfFeatures);
            } else {
                throw new SiddhiAppCreationException("3rd Parameter must either be a constant (learning.rate) or "
                        + "an attribute of the stream (model" + ".features), but found a " +
                        attributeExpressionExecutors[2].getClass().getCanonicalName());
            }
        } else {
            throw new SiddhiAppCreationException(String.format("Invalid number of parameters [%s] for " +
                    "streamingml:updatePerceptronClassifier", attributeExpressionLength));
        }

        model = PerceptronModelsHolder.getInstance().getPerceptronModel(modelName);
        if (model == null) {
            model = new PerceptronModel();
            PerceptronModelsHolder.getInstance().addPerceptronModel(modelName, model);
        }
        if (learningRate != -1) {
            model.setLearningRate(learningRate);
        }
        if (model.getFeatureSize() != -1) {
            // validate the model
            if (numberOfFeatures != model.getFeatureSize()) {
                throw new SiddhiAppCreationException(String.format("Model [%s] expects %s features, but the " +
                        "streamingml:updatePerceptronClassifier specifies %s features", modelPrefix, model
                        .getFeatureSize(), numberOfFeatures));
            }
        } else {
            model.initWeights(numberOfFeatures);
        }

        List<Attribute> attributes = new ArrayList<>();
        for (int i = 0; i < numberOfFeatures; i++) {
            attributes.add(new Attribute(featureVariableExpressionExecutors.get(i).getAttribute().getName() +
                    ".weight", Attribute.Type.DOUBLE));
        }

        return attributes;
    }

    /**
     * Process events received by PerceptronClassifierUpdaterStreamProcessorExtension
     *
     * @param streamEventChunk      the event chunk that need to be processed
     * @param nextProcessor         the next processor to which the success events need to be passed
     * @param streamEventCloner     helps to clone the incoming event for local storage or modification
     * @param complexEventPopulater helps to populate the events with the resultant attributes
     */
    @Override
    protected void process(ComplexEventChunk<StreamEvent> streamEventChunk, Processor nextProcessor,
                           StreamEventCloner streamEventCloner, ComplexEventPopulater complexEventPopulater) {

        synchronized (this) {
            while (streamEventChunk.hasNext()) {
                StreamEvent event = streamEventChunk.next();
                if (logger.isDebugEnabled()) {
                    logger.debug(String.format("Event received; Model name: %s Event:%s", modelName, event));
                }

                Object labelObj = labelVariableExpressionExecutor.execute(event);
                String label;
                if (labelObj instanceof String) {
                    label = (String) labelObj;
                    if (!(label.equalsIgnoreCase("true") || label.equalsIgnoreCase("false"))) {
                        throw new SiddhiAppRuntimeException(String.format("Detected attribute type of the label " +
                                "is String, but the value is not either true or false but %s. Note: Perceptron " +
                                "classifier can be used only for binary classification problems.", label));
                    }
                } else {
                    // should be a boolean as validated at init
                    label = Boolean.toString((boolean) labelObj);
                }

                double[] features = new double[numberOfFeatures];
                for (int i = 0; i < numberOfFeatures; i++) {
                    // attributes cannot ever be any other type than int or double as we've validated the query at init
                    features[i] = (double) featureVariableExpressionExecutors.get(i).execute(event);
                }

                double[] weights = PerceptronModelsHolder.getInstance().getPerceptronModel(modelName).update(Boolean
                        .parseBoolean(label), features);
                // convert weights to object[]
                Object[] data = new Object[weights.length];
                for (int i = 0; i < weights.length; i++) {
                    data[i] = weights[i];
                }

                complexEventPopulater.populateComplexEvent(event, data);
            }
        }
        nextProcessor.process(streamEventChunk);
    }

    @Override
    public void start() {
    }

    @Override
    public void stop() {
        PerceptronModelsHolder.getInstance().deletePerceptronModel(modelName);
    }

    @Override
    public Map<String, Object> currentState() {
        Map<String, Object> currentState = new HashMap<>();
        currentState.put("PerceptronModel", PerceptronModelsHolder.getInstance().getClonedPerceptronModel(modelName));
        return currentState;
    }

    @Override
    public void restoreState(Map<String, Object> state) {
        PerceptronModelsHolder.getInstance().addPerceptronModel(modelName, (PerceptronModel)
                state.get("PerceptronModel"));
    }
}