PerceptronClassifierStreamProcessorExtension.java
/*
* Copyright (c) 2017, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.wso2.extension.siddhi.execution.streamingml.classification.perceptron;
import org.apache.log4j.Logger;
import org.wso2.extension.siddhi.execution.streamingml.classification.perceptron.util.PerceptronModel;
import org.wso2.extension.siddhi.execution.streamingml.classification.perceptron.util.PerceptronModelsHolder;
import org.wso2.extension.siddhi.execution.streamingml.util.CoreUtils;
import org.wso2.siddhi.annotation.Example;
import org.wso2.siddhi.annotation.Extension;
import org.wso2.siddhi.annotation.Parameter;
import org.wso2.siddhi.annotation.ReturnAttribute;
import org.wso2.siddhi.annotation.util.DataType;
import org.wso2.siddhi.core.config.SiddhiAppContext;
import org.wso2.siddhi.core.event.ComplexEventChunk;
import org.wso2.siddhi.core.event.stream.StreamEvent;
import org.wso2.siddhi.core.event.stream.StreamEventCloner;
import org.wso2.siddhi.core.event.stream.populater.ComplexEventPopulater;
import org.wso2.siddhi.core.exception.SiddhiAppCreationException;
import org.wso2.siddhi.core.executor.ConstantExpressionExecutor;
import org.wso2.siddhi.core.executor.ExpressionExecutor;
import org.wso2.siddhi.core.executor.VariableExpressionExecutor;
import org.wso2.siddhi.core.query.processor.Processor;
import org.wso2.siddhi.core.query.processor.stream.StreamProcessor;
import org.wso2.siddhi.core.util.config.ConfigReader;
import org.wso2.siddhi.query.api.definition.AbstractDefinition;
import org.wso2.siddhi.query.api.definition.Attribute;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Predicts using a linear binary classification Perceptron model built via
* {@link PerceptronClassifierUpdaterStreamProcessorExtension}.
*/
@Extension(
name = "perceptronClassifier",
namespace = "streamingml",
description = "This extension predicts using a linear binary classification Perceptron model.",
parameters = {
@Parameter(name = "model.name",
description = "The name of the model to be used.",
type = {DataType.STRING}),
@Parameter(name = "model.bias",
description = "The bias of the Perceptron algorithm.",
type = {DataType.DOUBLE}, defaultValue = "0.0"),
@Parameter(name = "model.threshold",
description = "The threshold that separates the two classes. The value specified must be " +
"between zero and one.",
type = {DataType.DOUBLE}, defaultValue = " The output is a probability."),
@Parameter(name = "model.features",
description = "The features of the model that need to be attributes of the stream.",
type = {DataType.DOUBLE})
},
returnAttributes = {
@ReturnAttribute(name = "prediction",
description = "The predicted value (`true/false`)",
type = {DataType.BOOL}),
@ReturnAttribute(name = "confidenceLevel",
description = "The probability of the prediction",
type = {DataType.DOUBLE})
},
examples = {
@Example(
syntax = "define stream StreamA (attribute_0 double, attribute_1 double, attribute_2 double, " +
"attribute_3 double);\n" +
"\n" +
"from StreamA#streamingml:perceptronClassifier('model1',0.0,0.5, attribute_0, " +
"attribute_1, attribute_2, attribute_3) \n" +
"insert all events into OutputStream;",
description = "This query uses a Perceptron model named `model1` with a `0.0` bias and a " +
"`0.5` threshold learning rate to predict the label of the feature vector " +
"represented by `attribute_0`, `attribute_1`, `attribute_2`, and `attribute_3`. " +
"The predicted label (`true/false`) is emitted to the `OutputStream` stream" +
"along with the prediction confidence level(probability) and the feature vector. " +
"As a result, the OutputStream stream is defined as follows: " +
"(attribute_0 double, attribute_1 double, attribute_2" +
" double, attribute_3 double, prediction bool, confidenceLevel double)."
),
@Example(
syntax = "define stream StreamA (attribute_0 double, attribute_1 double, attribute_2 double, " +
"attribute_3 double);\n" +
"\n" +
"from StreamA#streamingml:perceptronClassifier('model1',0.0, attribute_0, " +
"attribute_1, attribute_2, attribute_3) \n" +
"insert all events into OutputStream;",
description = "This query uses a Perceptron model named `model1` with a `0.0` bias to predict" +
" the label of the feature vector represented by `attribute_0`, `attribute_1`, " +
"`attribute_2`, and `attribute_3`. The prediction(`true/false`) is emitted to the " +
"`OutputStream`stream along with the prediction confidence level(probability) and " +
"the feature. " +
"As a result, the OutputStream stream is defined as follows: " +
"(attribute_0 double, attribute_1 double, attribute_2 double, attribute_3 double, " +
"prediction bool, confidenceLevel double)."
),
@Example(
syntax = "define stream StreamA (attribute_0 double, attribute_1 double, attribute_2 double, " +
"attribute_3 double);\n" +
"\n" +
"from StreamA#streamingml:perceptronClassifier(`model1`, attribute_0, attribute_1, " +
"attribute_2) \n" +
"insert all events into OutputStream;",
description = "This query uses a Perceptron model named `model1` with a default 0.0 bias" +
" to predict the label of the feature vector represented by `attribute_0`, " +
"`attribute_1`, and `attribute_2`. The predicted probability is emitted to the " +
"OutputStream stream along with the feature vector. As a result, the OutputStream is " +
"defined as follows: " +
"(attribute_0 double, attribute_1 double, attribute_2 double, attribute_3 double, " +
"prediction bool, confidenceLevel double)."
)
}
)
public class PerceptronClassifierStreamProcessorExtension extends StreamProcessor {
private static final Logger logger = Logger.getLogger(PerceptronClassifierStreamProcessorExtension.class);
private String modelName;
private int numberOfFeatures;
private List<VariableExpressionExecutor> featureVariableExpressionExecutors = new ArrayList<>();
@Override
protected List<Attribute> init(AbstractDefinition inputDefinition, ExpressionExecutor[]
attributeExpressionExecutors, ConfigReader configReader, SiddhiAppContext siddhiAppContext) {
String siddhiAppName = siddhiAppContext.getName();
PerceptronModel model;
String modelPrefix;
double bias = -1, threshold = -1;
// maxNumberOfFeatures = number of attributes of the input stream
int maxNumberOfFeatures = inputDefinition.getAttributeList().size();
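// Three parameter layouts are accepted:
//   (model.name, model.features ...)                               -> bias and threshold keep their defaults
//   (model.name, model.bias, model.features ...)                   -> threshold keeps its default
//   (model.name, model.bias, model.threshold, model.features ...)
// The branches below detect which layout was used by checking whether the 2nd and 3rd
// expression executors are constants or stream-attribute variables.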
if (attributeExpressionLength >= 2) {
if (attributeExpressionLength > 3 + maxNumberOfFeatures) {
throw new SiddhiAppCreationException(String.format("Invalid number of parameters for " +
"streamingml:perceptronClassifier. This Stream Processor requires at most %s " + "parameters," +
" namely, model.name, model.bias, model.threshold, model.features but found %s " +
"parameters", 3 + maxNumberOfFeatures, attributeExpressionLength));
}
if (attributeExpressionExecutors[0] instanceof ConstantExpressionExecutor) {
if (attributeExpressionExecutors[0].getReturnType() == Attribute.Type.STRING) {
modelPrefix = (String) ((ConstantExpressionExecutor) attributeExpressionExecutors[0]).getValue();
// model name = user given name + siddhi app name
modelName = modelPrefix + "." + siddhiAppName;
} else {
throw new SiddhiAppCreationException("Invalid parameter type found for the model.name argument," +
"" + " required " + Attribute.Type.STRING + " but found " +
attributeExpressionExecutors[0].getReturnType().toString());
}
} else {
throw new SiddhiAppCreationException("Parameter model.name must be a constant but found " +
attributeExpressionExecutors[0].getClass().getCanonicalName());
}
// 2nd param
if (attributeExpressionExecutors[1] instanceof ConstantExpressionExecutor) {
// bias
if (attributeExpressionExecutors[1].getReturnType() == Attribute.Type.DOUBLE) {
bias = (double) ((ConstantExpressionExecutor) attributeExpressionExecutors[1]).getValue();
} else {
throw new SiddhiAppCreationException("Invalid parameter type found for the model.bias " +
"argument. Expected: " + Attribute.Type.DOUBLE + " but found: " +
attributeExpressionExecutors[1].getReturnType().toString());
}
// 3rd param
if (attributeExpressionExecutors[2] instanceof ConstantExpressionExecutor) {
// threshold
if (attributeExpressionExecutors[2].getReturnType() == Attribute.Type.DOUBLE) {
threshold = (double) ((ConstantExpressionExecutor) attributeExpressionExecutors[2])
.getValue();
if (threshold <= 0 || threshold >= 1) {
throw new SiddhiAppCreationException("Invalid parameter value found for the model" + "" +
".threshold argument. Expected a value between 0 & 1, but found: " + threshold);
}
} else {
throw new SiddhiAppCreationException("Invalid parameter type found for the model.threshold " +
"" + "argument. Expected: " + Attribute.Type.DOUBLE + " but found: " +
attributeExpressionExecutors[2].getReturnType().toString());
}
// set number of features
numberOfFeatures = attributeExpressionLength - 3;
// feature variables
featureVariableExpressionExecutors = CoreUtils.extractAndValidateFeatures(inputDefinition,
attributeExpressionExecutors, 3, numberOfFeatures);
} else if (attributeExpressionExecutors[2] instanceof VariableExpressionExecutor) {
// set number of features
numberOfFeatures = attributeExpressionLength - 2;
// feature variables
featureVariableExpressionExecutors = CoreUtils.extractAndValidateFeatures(inputDefinition,
attributeExpressionExecutors, 2, numberOfFeatures);
} else {
throw new SiddhiAppCreationException("3rd Parameter must either be a constant (model.threshold)" +
"" + " or an attribute of the stream (model.features), but found a " +
attributeExpressionExecutors[2].getClass().getCanonicalName());
}
} else if (attributeExpressionExecutors[1] instanceof VariableExpressionExecutor) {
// set number of features
numberOfFeatures = attributeExpressionLength - 1;
// feature variables
featureVariableExpressionExecutors = CoreUtils.extractAndValidateFeatures(inputDefinition,
attributeExpressionExecutors, 1, numberOfFeatures);
} else {
throw new SiddhiAppCreationException("2nd Parameter must either be a constant (model.bias) or " +
"an attribute of the stream (model.features), but found a " + attributeExpressionExecutors[1]
.getClass().getCanonicalName());
}
} else {
throw new SiddhiAppCreationException(String.format("Invalid number of parameters [%s] for " +
"streamingml:perceptronClassifier", attributeExpressionLength));
}
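// Retrieve the model that streamingml:updatePerceptronClassifier should already have built;
// prediction cannot proceed against a model that has not been initialized.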
model = PerceptronModelsHolder.getInstance().getPerceptronModel(modelName);
if (model != null) {
if (bias != -1) {
model.setBias(bias);
}
if (threshold != -1) {
model.setThreshold(threshold);
}
if (model.getFeatureSize() != -1) {
// validate the model
if (numberOfFeatures != model.getFeatureSize()) {
throw new SiddhiAppCreationException(String.format("Model [%s] expects %s features, but the " +
"streamingml:perceptronClassifier specifies %s features", modelPrefix, model.getFeatureSize()
, numberOfFeatures));
}
}
} else {
throw new SiddhiAppCreationException(String.format("Model [%s] needs to initialized "
+ "prior to be used with streamingml:perceptronClassifier. "
+ "Perform streamingml:updatePerceptronClassifier process first.", modelName));
}
List<Attribute> attributes = new ArrayList<>();
attributes.add(new Attribute("prediction", Attribute.Type.BOOL));
attributes.add(new Attribute("confidenceLevel", Attribute.Type.DOUBLE));
return attributes;
}
/**
* Process events received by PerceptronClassifierStreamProcessorExtension
*
* @param streamEventChunk the event chunk that needs to be processed
* @param nextProcessor the next processor to which the success events need to be passed
* @param streamEventCloner helps to clone the incoming event for local storage or modification
* @param complexEventPopulater helps to populate the events with the resultant attributes
*/
@Override
protected void process(ComplexEventChunk<StreamEvent> streamEventChunk, Processor nextProcessor,
StreamEventCloner streamEventCloner, ComplexEventPopulater complexEventPopulater) {
synchronized (this) {
while (streamEventChunk.hasNext()) {
StreamEvent event = streamEventChunk.next();
if (logger.isDebugEnabled()) {
logger.debug(String.format("Event received; Model name: %s Event:%s", modelName, event));
}
double[] features = new double[numberOfFeatures];
for (int i = 0; i < numberOfFeatures; i++) {
// attributes cannot ever be any other type than double as we've validated the query at init
features[i] = (double) featureVariableExpressionExecutors.get(i).execute(event);
}
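// classify() is expected to return the predicted label (bool) and the confidence level
// (double), matching the return attributes declared in the @Extension annotation.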
Object[] data = PerceptronModelsHolder.getInstance().getPerceptronModel(modelName).classify(features);
// append the prediction and confidence level to the output event
complexEventPopulater.populateComplexEvent(event, data);
}
}
nextProcessor.process(streamEventChunk);
}
@Override
public void start() {
}
@Override
public void stop() {
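// Drop this Siddhi app's model from the holder when the query stops.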
PerceptronModelsHolder.getInstance().deletePerceptronModel(modelName);
}
@Override
public Map<String, Object> currentState() {
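// No processor-level state to persist; the model itself is held in PerceptronModelsHolder.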
return new HashMap<>();
}
@Override
public void restoreState(Map<String, Object> state) {
}
}