package io.siddhi.extension.execution.streamingml.bayesian.classification;

import io.siddhi.core.config.SiddhiQueryContext;
import io.siddhi.core.event.ComplexEventChunk;
import io.siddhi.core.event.stream.MetaStreamEvent;
import io.siddhi.core.event.stream.StreamEvent;
import io.siddhi.core.event.stream.StreamEventCloner;
import io.siddhi.core.event.stream.holder.StreamEventClonerHolder;
import io.siddhi.core.event.stream.populater.ComplexEventPopulater;
import io.siddhi.core.exception.SiddhiAppCreationException;
import io.siddhi.core.executor.ConstantExpressionExecutor;
import io.siddhi.core.executor.ExpressionExecutor;
import io.siddhi.core.executor.VariableExpressionExecutor;
import io.siddhi.core.query.processor.ProcessingMode;
import io.siddhi.core.query.processor.Processor;
import io.siddhi.core.query.processor.stream.StreamProcessor;
import io.siddhi.core.util.config.ConfigReader;
import io.siddhi.core.util.snapshot.state.State;
import io.siddhi.core.util.snapshot.state.StateFactory;
import io.siddhi.extension.execution.streamingml.bayesian.classification.util.SoftmaxRegressionModelHolder;
import io.siddhi.extension.execution.streamingml.bayesian.util.BayesianModel;
import io.siddhi.extension.execution.streamingml.bayesian.util.SoftmaxRegression;
import io.siddhi.extension.execution.streamingml.util.CoreUtils;
import io.siddhi.query.api.definition.AbstractDefinition;
import io.siddhi.query.api.definition.Attribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Stream processor behind {@code streamingml:updateBayesianClassification}.
 *
 * <p>Every incoming event is used to update a {@link SoftmaxRegression} model that is registered in
 * {@link SoftmaxRegressionModelHolder} under the key {@code <model.name>.<siddhi app name>}. A single
 * {@code loss} attribute of type DOUBLE is appended to each event; it holds the first element of the
 * array returned by {@code update(features, target)}.
 *
 * <p>Parameters, in order:
 * <ol>
 *   <li>{@code model.name} - constant STRING;</li>
 *   <li>{@code no.of.classes} - constant INT greater than 1 (exactly 2 for a BOOL target);</li>
 *   <li>{@code model.target} - STRING or BOOL stream attribute;</li>
 *   <li>optional constant hyper-parameters: {@code model.samples} (INT, only as the 4th parameter),
 *       {@code model.optimizer} (STRING, 4th or 5th parameter) and {@code learning.rate} (DOUBLE);</li>
 *   <li>one or more {@code model.features} stream attributes (read as numbers).</li>
 * </ol>
 */
public class BayesianClassificationUpdaterStreamProcessorExtension
        extends StreamProcessor<BayesianClassificationUpdaterStreamProcessorExtension.ExtensionState> {
    private static final Logger logger = LogManager.getLogger(BayesianClassificationUpdaterStreamProcessorExtension.class);

    /** model.name, no.of.classes, model.target and at least one feature. */
    private static final int MIN_NUMBER_OF_PARAMETERS = 4;
    /** Index of the first parameter that may be a constant hyper-parameter. */
    private static final int FIRST_HYPER_PARAMETER_INDEX = 3;
    /** model.samples, model.optimizer and learning.rate. */
    private static final int MAX_NUMBER_OF_HYPER_PARAMETERS = 3;
    /** Sentinel meaning "learning.rate not supplied; keep the model's default". */
    private static final double UNSET_LEARNING_RATE = -1.0d;
    /** Sentinel meaning "model.samples not supplied; keep the model's default". */
    private static final int UNSET_SAMPLES = -1;

    private String modelName;
    private int numberOfClasses;
    private String modelPrefix;
    private int numberOfFeatures;
    private SoftmaxRegression model;
    double learningRate;
    int nSamples;
    private BayesianModel.OptimizerType optimizerName;
    private VariableExpressionExecutor targetVariableExpressionExecutor;
    private List<VariableExpressionExecutor> featureVariableExpressionExecutors = new ArrayList<>();
    private List<Attribute> attributes;

    /**
     * Snapshot state of this processor. The live model is kept in {@link SoftmaxRegressionModelHolder};
     * this state copies a clone of it into snapshots and re-registers it on restore.
     */
    static class ExtensionState extends State {
        private static final String KEY_SOFTMAX_REGRESSION_MODEL = "SoftmaxRegressionModel";
        private final Map<String, Object> state = new HashMap<>();
        private final String modelName;

        private ExtensionState(String modelName) {
            this.modelName = modelName;
        }

        @Override
        public boolean canDestroy() {
            return false;
        }

        @Override
        public Map<String, Object> snapshot() {
            state.put(KEY_SOFTMAX_REGRESSION_MODEL,
                    SoftmaxRegressionModelHolder.getInstance().getClonedSoftmaxRegressionModel(modelName));
            return state;
        }

        /**
         * Re-registers the model contained in {@code snapshot} with the model holder.
         *
         * @param snapshot map previously produced by {@link #snapshot()}
         */
        @Override
        public void restore(Map<String, Object> snapshot) {
            // Read from the snapshot being restored, not from this.state: on a freshly created state
            // object this.state is empty, which made every restore fail with a NullPointerException.
            Object stored = snapshot == null ? null : snapshot.get(KEY_SOFTMAX_REGRESSION_MODEL);
            if (stored == null) {
                logger.warn("No SoftmaxRegression model found in snapshot for model {}; keeping the current model.",
                        modelName);
                return;
            }
            SoftmaxRegression softmaxRegression = (SoftmaxRegression) stored;
            // NOTE(review): presumably rebuilds internals that are not carried in the snapshot - confirm.
            softmaxRegression.initiateModel();
            SoftmaxRegressionModelHolder.getInstance().addSoftmaxRegressionModel(modelName, softmaxRegression);
        }
    }

    /**
     * Validates the query parameters and records the model configuration; the model itself is created
     * in {@link #start()}.
     *
     * @throws SiddhiAppCreationException if any parameter is missing, misplaced or of the wrong type
     */
    @Override
    protected StateFactory<ExtensionState> init(MetaStreamEvent metaStreamEvent, AbstractDefinition abstractDefinition,
                                                ExpressionExecutor[] expressionExecutorArr, ConfigReader configReader,
                                                StreamEventClonerHolder streamEventClonerHolder,
                                                boolean outputExpectsExpiredEvents, boolean findToBeExecuted,
                                                SiddhiQueryContext siddhiQueryContext) {
        String siddhiAppName = siddhiQueryContext.getSiddhiAppContext().getName();
        this.model = null;
        this.learningRate = UNSET_LEARNING_RATE;
        this.nSamples = UNSET_SAMPLES;
        this.numberOfClasses = -1;
        this.optimizerName = null;

        // Every input attribute except the target may be used as a feature.
        int maxNumberOfFeatures = inputDefinition.getAttributeList().size() - 1;
        int maxNumberOfParameters = FIRST_HYPER_PARAMETER_INDEX + MAX_NUMBER_OF_HYPER_PARAMETERS + maxNumberOfFeatures;
        if (attributeExpressionLength < MIN_NUMBER_OF_PARAMETERS) {
            throw new SiddhiAppCreationException(String.format("Invalid number of parameters [%s] for streamingml:updateBayesianClassification. Expect at least %s parameters", attributeExpressionLength, MIN_NUMBER_OF_PARAMETERS));
        }
        if (attributeExpressionLength > maxNumberOfParameters) {
            throw new SiddhiAppCreationException(String.format("Invalid number of parameters for streamingml:updateBayesianClassification. This Stream Processor requires at most %s parameters, namely, model.name, no.of.classes, model.target, model.samples[optional], model.optimizer[optional], learning.rate[optional], model.features. but found %s parameters", maxNumberOfParameters, attributeExpressionLength));
        }

        parseModelName(expressionExecutorArr[0], siddhiAppName);
        parseNumberOfClasses(expressionExecutorArr[1]);
        parseTarget(expressionExecutorArr[2]);
        int featureStartIndex = parseHyperParameters(expressionExecutorArr);
        parseFeatures(expressionExecutorArr, featureStartIndex);

        attributes = new ArrayList<>();
        attributes.add(new Attribute("loss", Attribute.Type.DOUBLE));
        return () -> new ExtensionState(modelName);
    }

    /** Validates model.name and derives the holder key {@code <model.name>.<siddhi app name>}. */
    private void parseModelName(ExpressionExecutor executor, String siddhiAppName) {
        if (!(executor instanceof ConstantExpressionExecutor)) {
            throw new SiddhiAppCreationException("Parameter model.name must be a constant. But found " + executor.getClass().getCanonicalName());
        }
        if (executor.getReturnType() != Attribute.Type.STRING) {
            throw new SiddhiAppCreationException("Invalid parameter type found for the model.name argument, required " + Attribute.Type.STRING + " But found " + executor.getReturnType().toString());
        }
        modelPrefix = (String) ((ConstantExpressionExecutor) executor).getValue();
        modelName = modelPrefix + "." + siddhiAppName;
    }

    /** Validates no.of.classes (a constant INT greater than 1) and stores it. */
    private void parseNumberOfClasses(ExpressionExecutor executor) {
        if (!(executor instanceof ConstantExpressionExecutor)) {
            throw new SiddhiAppCreationException("Parameter no.of.classes must be a constant. But found " + executor.getClass().getCanonicalName());
        }
        if (executor.getReturnType() != Attribute.Type.INT) {
            throw new SiddhiAppCreationException("Invalid parameter type found for the no.of.classes argument, required " + Attribute.Type.INT + " But found " + executor.getReturnType().toString());
        }
        int classes = (Integer) ((ConstantExpressionExecutor) executor).getValue();
        if (classes <= 1) {
            throw new SiddhiAppCreationException(String.format("no.of.classes should be greater than 1. But found %d", classes));
        }
        numberOfClasses = classes;
    }

    /** Validates model.target: a STRING or BOOL stream attribute, BOOL only with exactly 2 classes. */
    private void parseTarget(ExpressionExecutor executor) {
        if (!(executor instanceof VariableExpressionExecutor)) {
            throw new SiddhiAppCreationException("model.target attribute in updateBayesianClassification should be a variable, but found a " + executor.getClass().getCanonicalName());
        }
        targetVariableExpressionExecutor = (VariableExpressionExecutor) executor;
        String targetName = targetVariableExpressionExecutor.getAttribute().getName();
        Attribute.Type targetType = inputDefinition.getAttributeType(targetName);
        if (!CoreUtils.isLabelType(targetType)) {
            throw new SiddhiAppCreationException(String.format("[model.target] %s in updateBayesianClassification should be a STRING or BOOLEAN. But found %s", targetName, targetType.name()));
        }
        if (targetType == Attribute.Type.BOOL && numberOfClasses != 2) {
            throw new SiddhiAppCreationException("no.of.classes should be 2, if the type of the attribute model.target is BOOLEAN. But found " + numberOfClasses);
        }
    }

    /**
     * Consumes the constant hyper-parameters that follow model.target.
     *
     * @return index of the first non-constant parameter, or {@code attributeExpressionLength} if every
     *         remaining parameter was a constant
     */
    private int parseHyperParameters(ExpressionExecutor[] executors) {
        int index = FIRST_HYPER_PARAMETER_INDEX;
        // Bounded scan: without the length check a parameter list made only of constants ran past the
        // end of the array instead of being reported as a missing-features error.
        while (index < attributeExpressionLength && executors[index] instanceof ConstantExpressionExecutor) {
            ConstantExpressionExecutor constant = (ConstantExpressionExecutor) executors[index];
            Attribute.Type type = constant.getReturnType();
            if (type == Attribute.Type.INT) {
                parseNumberOfSamples(index, constant);
            } else if (type == Attribute.Type.STRING) {
                parseOptimizer(index, constant);
            } else if (type == Attribute.Type.DOUBLE) {
                parseLearningRate(constant);
            } else {
                throw new SiddhiAppCreationException(String.format("Invalid parameter type found. Expected: %s or %s or %s. But found %s", Attribute.Type.INT, Attribute.Type.STRING, Attribute.Type.DOUBLE, type.toString()));
            }
            index++;
        }
        return index;
    }

    /** model.samples: INT, allowed only as the first hyper-parameter, must be positive. */
    private void parseNumberOfSamples(int index, ConstantExpressionExecutor constant) {
        if (index != FIRST_HYPER_PARAMETER_INDEX) {
            throw new SiddhiAppCreationException(String.format("%dth parameter cannot be type of %s. Only model.sample can be %s, which can be set as the %dth parameter.", index, Attribute.Type.INT, Attribute.Type.INT, FIRST_HYPER_PARAMETER_INDEX));
        }
        int samples = (Integer) constant.getValue();
        if (samples <= 0) {
            throw new SiddhiAppCreationException(String.format("model.sample should be greater than zero. But found %d", samples));
        }
        nSamples = samples;
    }

    /** model.optimizer: STRING naming an {@link BayesianModel.OptimizerType}, first or second hyper-parameter. */
    private void parseOptimizer(int index, ConstantExpressionExecutor constant) {
        if (index > FIRST_HYPER_PARAMETER_INDEX + 1) {
            throw new SiddhiAppCreationException(String.format("%dth parameter cannot be type of %s. Only model.optimizer can be %s.", index, Attribute.Type.STRING, Attribute.Type.STRING));
        }
        if (optimizerName != null) {
            throw new SiddhiAppCreationException(String.format("%dth parameter cannot be type of %s. Only model.optimizer can be %s, which is already set to %s.", index, Attribute.Type.STRING, Attribute.Type.STRING, optimizerName));
        }
        String value = (String) constant.getValue();
        try {
            optimizerName = BayesianModel.OptimizerType.valueOf(value.toUpperCase(Locale.ENGLISH));
        } catch (IllegalArgumentException | NullPointerException e) {
            // IllegalArgumentException: unknown optimizer name; NullPointerException: null constant.
            throw new SiddhiAppCreationException(String.format("model.optimizer should be one of %s. But found %s", Arrays.toString(BayesianModel.OptimizerType.values()), value), e);
        }
    }

    /** learning.rate: DOUBLE that must be strictly positive. */
    private void parseLearningRate(ConstantExpressionExecutor constant) {
        double rate = (Double) constant.getValue();
        // Negated comparison so that NaN is rejected as well.
        if (!(rate > 0.0d)) {
            throw new SiddhiAppCreationException(String.format("learning.rate should be greater than zero. But found %f", rate));
        }
        learningRate = rate;
    }

    /** Validates that the parameters from {@code startIndex} onwards are the model.features attributes. */
    private void parseFeatures(ExpressionExecutor[] executors, int startIndex) {
        if (startIndex >= attributeExpressionLength) {
            throw new SiddhiAppCreationException("No model.features found for streamingml:updateBayesianClassification. At least one attribute of the stream must follow the hyper-parameters.");
        }
        ExpressionExecutor firstFeature = executors[startIndex];
        if (!(firstFeature instanceof VariableExpressionExecutor)) {
            throw new SiddhiAppCreationException("Parameter " + startIndex + " must either be a constant (hyperparameter) or an attribute of the stream (model.features), but found a " + firstFeature.getClass().getCanonicalName());
        }
        numberOfFeatures = attributeExpressionLength - startIndex;
        featureVariableExpressionExecutors = CoreUtils.extractAndValidateFeatures(inputDefinition, executors, startIndex, numberOfFeatures);
    }

    /**
     * Updates the shared model with each event and appends the resulting loss to it, then forwards
     * the chunk to the next processor.
     */
    @Override
    protected void process(ComplexEventChunk<StreamEvent> complexEventChunk, Processor processor,
                           StreamEventCloner streamEventCloner, ComplexEventPopulater complexEventPopulater,
                           ExtensionState extensionState) {
        synchronized (this) {
            while (complexEventChunk.hasNext()) {
                StreamEvent event = complexEventChunk.next();
                if (logger.isDebugEnabled()) {
                    logger.debug(String.format("Event received; Model name: %s Event:%s", modelName, event));
                }
                String target = targetVariableExpressionExecutor.execute(event).toString();
                double[] features = readFeatures(event);
                SoftmaxRegression softmaxRegression =
                        SoftmaxRegressionModelHolder.getInstance().getSoftmaxRegressionModel(modelName);
                // evaluate() is only needed for the debug log, so skip it when debug logging is off.
                if (logger.isDebugEnabled()) {
                    logger.debug("Evaluation before update; Model name: {} Result: {}", modelName,
                            softmaxRegression.evaluate(features, target));
                }
                double loss = softmaxRegression.update(features, target)[0];
                complexEventPopulater.populateComplexEvent(event, new Object[]{loss});
            }
        }
        nextProcessor.process(complexEventChunk);
    }

    /** Reads the configured feature attributes of {@code event} as doubles, in parameter order. */
    private double[] readFeatures(StreamEvent event) {
        double[] features = new double[numberOfFeatures];
        for (int i = 0; i < numberOfFeatures; i++) {
            features[i] = ((Number) featureVariableExpressionExecutors.get(i).execute(event)).doubleValue();
        }
        return features;
    }

    /**
     * Creates and configures the model and registers it with the model holder.
     *
     * @throws SiddhiAppCreationException if a model with the same name is already registered, or the
     *         model's feature count disagrees with the query
     */
    @Override
    public void start() {
        SoftmaxRegressionModelHolder holder = SoftmaxRegressionModelHolder.getInstance();
        if (holder.getSoftmaxRegressionMap().containsKey(modelName)) {
            throw new SiddhiAppCreationException("A model already exists with the name " + modelPrefix + ". Use a different value for model.name argument.");
        }
        model = new SoftmaxRegression(numberOfClasses);
        if (learningRate != UNSET_LEARNING_RATE) {
            logger.debug("set learning rate to : {}", learningRate);
            model.setLearningRate(learningRate);
        }
        if (nSamples != UNSET_SAMPLES) {
            logger.debug("set number of samples to : {}", nSamples);
            model.setNumSamples(nSamples);
        }
        if (optimizerName != null) {
            logger.debug("set optimizer to : {}", optimizerName);
            model.setOptimizerType(optimizerName);
        }
        if (model.getNumFeatures() != -1) {
            if (numberOfFeatures != model.getNumFeatures()) {
                throw new SiddhiAppCreationException(String.format("Model [%s] expects %s features, but the streamingml:updateBayesianClassification specifies %s features", modelPrefix, model.getNumFeatures(), numberOfFeatures));
            }
        } else {
            model.setNumFeatures(numberOfFeatures);
            model.initiateModel();
        }
        // Register only once fully configured, so a failure above cannot leave a half-built model behind.
        holder.addSoftmaxRegressionModel(modelName, model);
    }

    /** Removes this processor's model from the model holder. */
    @Override
    public void stop() {
        SoftmaxRegressionModelHolder.getInstance().deleteSoftmaxRegressionModel(modelName);
    }

    /** Returns the single appended attribute: {@code loss} (DOUBLE). */
    @Override
    public List<Attribute> getReturnAttributes() {
        return attributes;
    }

    @Override
    public ProcessingMode getProcessingMode() {
        return ProcessingMode.BATCH;
    }
}
