package org.wso2.extension.siddhi.gpl.execution.streamingml.classification.hoeffdingtree.util;

import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import io.siddhi.core.exception.SiddhiAppRuntimeException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import moa.classifiers.trees.HoeffdingAdaptiveTree;
import moa.core.ObjectRepository;
import moa.options.AbstractOptionHandler;
import moa.streams.InstanceStream;
import moa.tasks.TaskMonitor;
import org.apache.log4j.Logger;
import org.wso2.extension.siddhi.gpl.execution.streamingml.classification.ClassifierPrequentialModelEvaluation;
import org.wso2.extension.siddhi.gpl.execution.streamingml.util.CoreUtils;
import org.wso2.extension.siddhi.gpl.execution.streamingml.util.MathUtil;

/**
 * Wraps a MOA {@link HoeffdingAdaptiveTree} so that it can be trained and queried with plain
 * {@code double[]} events and string class labels.
 *
 * <p>Each event array carries {@code noOfFeatures} slots: {@code noOfFeatures - 1} numeric
 * attributes followed by a trailing class slot. String labels are mapped to numeric class
 * indices in the order they are first seen during training; at most {@code noOfClasses}
 * distinct labels are accepted.
 */
public class AdaptiveHoeffdingTreeModel extends AbstractOptionHandler {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(AdaptiveHoeffdingTreeModel.class);

    /** Number of decimal places used when rounding accuracy and confidence values. */
    private static final int ROUNDING_PRECISION = 3;

    private String modelName;
    private InstancesHeader streamHeader;
    /** Number of attributes per event, including the trailing class attribute. */
    private int noOfFeatures;
    /** Maximum number of distinct class labels this model accepts. */
    private int noOfClasses;
    private HoeffdingAdaptiveTree hoeffdingAdaptiveTree;
    /** Known class labels; a label's position in this list is the class value fed to MOA. */
    private List<String> classes;

    /**
     * Appends a short human-readable description of this model to {@code sb}, as required by
     * the MOA {@code MOAObject} contract (the original only logged it and left {@code sb} empty).
     *
     * @param sb     builder that receives the description
     * @param indent indentation level requested by MOA (not used)
     */
    @Override // moa.MOAObject
    public void getDescription(StringBuilder sb, int indent) {
        sb.append("Hoeffding Adaptive Tree for evolving data streams that uses ADWIN to replace branches for new ones.");
    }

    /**
     * Creates an uninitialised model; {@link #init(int, int)} must be called before use.
     *
     * @param modelName name used in log and error messages
     */
    public AdaptiveHoeffdingTreeModel(String modelName) {
        this.classes = new ArrayList<>();
        this.modelName = modelName;
    }

    /**
     * Creates a shallow copy of {@code model}.
     *
     * <p>The copy shares the same tree, stream header and class-label list as {@code model}, so
     * training through either instance affects both.
     *
     * @param model the model to copy
     */
    public AdaptiveHoeffdingTreeModel(AdaptiveHoeffdingTreeModel model) {
        this.modelName = model.modelName;
        this.streamHeader = model.streamHeader;
        this.noOfFeatures = model.noOfFeatures;
        this.noOfClasses = model.noOfClasses;
        this.hoeffdingAdaptiveTree = model.hoeffdingAdaptiveTree;
        this.classes = model.classes;
    }

    /**
     * Builds the MOA stream header and creates a fresh, prepared Hoeffding Adaptive Tree.
     *
     * @param noOfFeatures number of attributes per event, including the trailing class slot
     * @param noOfClasses  maximum number of distinct class labels
     */
    public void init(int noOfFeatures, int noOfClasses) {
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("Model [%s] is being initialized.", this.modelName));
        }
        this.noOfFeatures = noOfFeatures;
        this.noOfClasses = noOfClasses;
        this.streamHeader = createMOAInstanceHeader(this.noOfFeatures);
        this.hoeffdingAdaptiveTree = new HoeffdingAdaptiveTree();
        this.hoeffdingAdaptiveTree.setModelContext(this.streamHeader);
        this.hoeffdingAdaptiveTree.prepareForUse();
    }

    /**
     * Applies hyper-parameters to the underlying tree. Must be called after
     * {@link #init(int, int)}, which creates the tree.
     *
     * <p>NOTE(review): {@code init} has already called {@code prepareForUse()}. MOA may cache
     * prepared class options at that point, so a changed split criterion might not take
     * effect — confirm against the MOA version in use.
     *
     * @param gracePeriod            number of instances a leaf observes between split attempts
     * @param splitCriterion         {@code 0} selects {@code InfoGainSplitCriterion}; any other
     *                               value selects {@code GiniSplitCriterion}
     * @param splitConfidence        value for the tree's split-confidence option
     * @param tieThreshold           value for the tree's tie-threshold option
     * @param binarySplits           value for the tree's binary-splits option
     * @param noPrePrune             value for the tree's no-pre-prune option
     * @param leafPredictionStrategy index chosen on the tree's leaf-prediction option
     */
    public void setConfigurations(int gracePeriod, int splitCriterion, double splitConfidence,
                                  double tieThreshold, boolean binarySplits, boolean noPrePrune,
                                  int leafPredictionStrategy) {
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("Model [%s] is being configured with hyper-parameters.", this.modelName));
        }
        this.hoeffdingAdaptiveTree.gracePeriodOption.setValue(gracePeriod);
        if (splitCriterion == 0) {
            this.hoeffdingAdaptiveTree.splitCriterionOption.setValueViaCLIString("InfoGainSplitCriterion");
        } else {
            this.hoeffdingAdaptiveTree.splitCriterionOption.setValueViaCLIString("GiniSplitCriterion");
        }
        this.hoeffdingAdaptiveTree.splitConfidenceOption.setValue(splitConfidence);
        this.hoeffdingAdaptiveTree.tieThresholdOption.setValue(tieThreshold);
        this.hoeffdingAdaptiveTree.binarySplitsOption.setValue(binarySplits);
        this.hoeffdingAdaptiveTree.noPrePruneOption.setValue(noPrePrune);
        this.hoeffdingAdaptiveTree.leafpredictionOption.setChosenIndex(leafPredictionStrategy);
    }

    /**
     * Trains the tree on one labelled event.
     *
     * <p>The class slot {@code cepEvent[noOfFeatures - 1]} of the caller's array is overwritten
     * with the numeric index of {@code classValue}.
     *
     * @param cepEvent   attribute values, with a trailing class slot
     * @param classValue the event's class label; registered if not seen before
     * @throws SiddhiAppRuntimeException if the label is new and the class limit is reached
     */
    public void trainOnEvent(double[] cepEvent, String classValue) {
        int classIndex = this.noOfFeatures - 1;
        cepEvent[classIndex] = addClass(classValue);
        Instance instance = createMOAInstance(cepEvent);
        instance.setClassValue(cepEvent[classIndex]);
        this.hoeffdingAdaptiveTree.trainOnInstanceImpl(instance);
    }

    /**
     * Performs one prequential (test-then-train) step.
     *
     * <p>The step first predicts from the attribute values only, then trains on the labelled
     * event. Finally it records the prediction in {@code evaluation}.
     *
     * <p>The last slot of the caller's array is overwritten with the label's numeric index. An
     * unseen label is now registered via {@link #addClass(String)}, as {@link #trainOnEvent}
     * does. Previously it was looked up with {@code indexOf}, so MOA was trained with class
     * value {@code -1}.
     *
     * @param evaluation accumulator of prequential results
     * @param cepEvent   attribute values, with a trailing class slot
     * @param classValue the event's true class label
     * @return the evaluator's fraction of correctly classified events, rounded to 3 decimals
     * @throws SiddhiAppRuntimeException if the label is new and the class limit is reached
     */
    public double evaluationTrainOnEvent(ClassifierPrequentialModelEvaluation evaluation,
                                         double[] cepEvent, String classValue) {
        int classIndex = cepEvent.length - 1;
        // Test first: predict using only the attribute values (class slot excluded).
        Instance testInstance = createMOAInstance(Arrays.copyOfRange(cepEvent, 0, classIndex));
        double[] votes = this.hoeffdingAdaptiveTree.getVotesForInstance(testInstance);

        // Then train on the labelled event.
        cepEvent[classIndex] = addClass(classValue);
        Instance trainInstance = createMOAInstance(cepEvent);
        this.hoeffdingAdaptiveTree.trainOnInstanceImpl(trainInstance);

        evaluation.addResult(trainInstance, votes);
        return MathUtil.roundOff(evaluation.getFractionCorrectlyClassified(), ROUNDING_PRECISION);
    }

    /**
     * Predicts the class of an unlabelled event.
     *
     * @param cepEvent attribute values of the event
     * @return a two-element array: the {@link Integer} index of the highest-voted class
     *         (presumably resolved to a label via {@link #getClasses()} by the caller) and its
     *         {@link Double} confidence
     */
    public Object[] getPrediction(double[] cepEvent) {
        double[] votes = this.hoeffdingAdaptiveTree.getVotesForInstance(createMOAInstance(cepEvent));
        return new Object[]{CoreUtils.argMaxIndex(votes), getPredictionConfidence(votes)};
    }

    /** Wraps the values in a unit-weight dense instance bound to this model's stream header. */
    private Instance createMOAInstance(double[] values) {
        DenseInstance instance = new DenseInstance(1.0d, values);
        instance.setDataset(this.streamHeader);
        return instance;
    }

    /**
     * Builds a header with {@code noOfAttributes - 1} numeric attributes followed by a nominal
     * {@code class} attribute. The class attribute has {@code noOfClasses} placeholder values
     * {@code class1..classN}.
     *
     * @param noOfAttributes total number of attributes, including the class attribute
     * @return the header, with its class index set to the last attribute
     */
    private InstancesHeader createMOAInstanceHeader(int noOfAttributes) {
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (int attributeNo = 1; attributeNo < noOfAttributes; attributeNo++) {
            attributes.add(new Attribute(Attribute.ARFF_ATTRIBUTE_NUMERIC + attributeNo));
        }
        ArrayList<String> classLabels = new ArrayList<>();
        for (int classNo = 1; classNo <= this.noOfClasses; classNo++) {
            classLabels.add("class" + classNo);
        }
        attributes.add(new Attribute("class", classLabels));
        InstancesHeader instancesHeader =
                new InstancesHeader(new Instances(getCLICreationString(InstanceStream.class), attributes, 0));
        // Use the parameter (not the field) so the class index always matches this header.
        instancesHeader.setClassIndex(noOfAttributes - 1);
        return instancesHeader;
    }

    /**
     * Returns the numeric index of {@code classValue}, registering it if it is new.
     *
     * @throws SiddhiAppRuntimeException if the label is new and {@code noOfClasses} labels are
     *                                   already registered
     */
    private int addClass(String classValue) {
        int existing = this.classes.indexOf(classValue);
        if (existing >= 0) {
            return existing;
        }
        if (this.classes.size() >= this.noOfClasses) {
            // Report the count including the rejected label; size() alone equals the limit.
            throw new SiddhiAppRuntimeException(String.format(
                    "Number of classes %s is expected from the model %s but found %s",
                    this.noOfClasses, this.modelName, this.classes.size() + 1));
        }
        this.classes.add(classValue);
        return this.classes.size() - 1;
    }

    /**
     * Returns the share of the winning vote in the total vote, rounded to 3 decimals. Returns
     * {@code 0.0} when the votes sum to zero, which avoids a {@code 0/0 = NaN} division.
     */
    private double getPredictionConfidence(double[] votes) {
        double total = MathUtil.sum(votes);
        if (total == 0.0d) {
            return 0.0d;
        }
        return MathUtil.roundOff(CoreUtils.argMax(votes) / total, ROUNDING_PRECISION);
    }

    public String getModelName() {
        return this.modelName;
    }

    public InstancesHeader getStreamHeader() {
        return this.streamHeader;
    }

    /**
     * Returns the live list of registered class labels, in class-index order. Modifying it
     * alters this model's label mapping.
     */
    public List<String> getClasses() {
        return this.classes;
    }

    /** Returns the number of attributes per event, including the trailing class slot. */
    public int getNoOfFeatures() {
        return this.noOfFeatures;
    }

    /** Intentionally empty: the tree is created and prepared in {@link #init(int, int)}. */
    @Override // moa.options.AbstractOptionHandler
    protected void prepareForUseImpl(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
    }
}
