CoreUtils.java
package org.wso2.extension.siddhi.execution.streamingml.util;
import org.wso2.siddhi.core.executor.ExpressionExecutor;
import org.wso2.siddhi.core.executor.VariableExpressionExecutor;
import org.wso2.siddhi.query.api.definition.AbstractDefinition;
import org.wso2.siddhi.query.api.definition.Attribute;
import org.wso2.siddhi.query.api.exception.SiddhiAppValidationException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Common utils for Streaming Machine Learning tasks.
*/
public class CoreUtils {
private static final List<Attribute.Type> numericTypes = Arrays.asList(Attribute.Type.INT,
Attribute.Type.DOUBLE, Attribute.Type.LONG, Attribute.Type.FLOAT);
private static final List<Attribute.Type> labelTypes = Arrays.asList(Attribute.Type.STRING, Attribute.Type.BOOL);
/**
* @param inputDefinition
* @param attributeExpressionExecutors
* @param startIndex starting index
* @param noOfFeatures
* @return
*/
public static List<VariableExpressionExecutor> extractAndValidateFeatures(
AbstractDefinition inputDefinition, ExpressionExecutor[]
attributeExpressionExecutors,
int startIndex, int noOfFeatures) {
List<VariableExpressionExecutor> featureVariableExpressionExecutors = new ArrayList<>();
// feature values start
for (int i = startIndex; i < (startIndex + noOfFeatures); i++) {
if (attributeExpressionExecutors[i] instanceof VariableExpressionExecutor) {
featureVariableExpressionExecutors.add((VariableExpressionExecutor)
attributeExpressionExecutors[i]);
// other attributes should be numeric type.
String attributeName = ((VariableExpressionExecutor)
attributeExpressionExecutors[i]).getAttribute().getName();
Attribute.Type featureAttributeType = inputDefinition.
getAttributeType(attributeName);
//feature attributes not numerical type
if (!isNumeric(featureAttributeType)) {
throw new SiddhiAppValidationException("model.features in " + (i + 1) + "th parameter is not " +
"a numerical type attribute. Found " +
attributeExpressionExecutors[i].getReturnType()
+ ". Check the input stream definition.");
}
} else {
throw new SiddhiAppValidationException((i + 1) + "th parameter is not " +
"an attribute (VariableExpressionExecutor) present in the stream definition. Found a "
+ attributeExpressionExecutors[i].getClass().getCanonicalName());
}
}
return featureVariableExpressionExecutors;
}
private static boolean isNumeric(Attribute.Type attributeType) {
return numericTypes.contains(attributeType);
}
public static boolean isLabelType(Attribute.Type attributeType) {
return labelTypes.contains(attributeType);
}
/**
* @param inputDefinition
* @param attributeExpressionExecutors
* @param classIndex
* @return
*/
public static VariableExpressionExecutor extractAndValidateClassLabel
(AbstractDefinition inputDefinition, ExpressionExecutor[] attributeExpressionExecutors, int classIndex) {
VariableExpressionExecutor classLabelVariableExecutor;
if (attributeExpressionExecutors[classIndex] instanceof VariableExpressionExecutor) {
// other attributes should be numeric type.
String attributeName = ((VariableExpressionExecutor)
attributeExpressionExecutors[classIndex]).getAttribute().getName();
Attribute.Type classLabelAttributeType = inputDefinition.
getAttributeType(attributeName);
//class label should be String or Bool
if (isLabelType(classLabelAttributeType)) {
classLabelVariableExecutor = (VariableExpressionExecutor) attributeExpressionExecutors[classIndex];
} else {
throw new SiddhiAppValidationException(String.format("[label attribute] in %s th index of " +
"classifierUpdate should be either a %s or a %s but found %s",
classIndex, Attribute.Type.BOOL, Attribute
.Type.STRING, classLabelAttributeType));
}
} else {
throw new SiddhiAppValidationException((classIndex) + "th parameter is not " +
"an attribute (VariableExpressionExecutor) present in the stream definition. Found a "
+ attributeExpressionExecutors[classIndex].getClass().getCanonicalName());
}
return classLabelVariableExecutor;
}
}