package org.wso2.extension.siddhi.execution.ml.samoa.utils.classification;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
import org.apache.samoa.evaluation.ClassificationPerformanceEvaluator;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;

/* loaded from: input_file:org/wso2/extension/siddhi/execution/ml/samoa/utils/classification/StreamingClassificationPerformanceEvaluator.class */
/**
 * Streaming {@link ClassificationPerformanceEvaluator} that reports overall accuracy, the kappa
 * statistic, the temporal kappa statistic, and per-class support, precision, recall and F1-score.
 *
 * <p>Accuracy and both kappa statistics are weighted by {@link Instance#weight()}. The weighted
 * accuracy/kappa sums are only updated for instances with positive weight. The per-class
 * confusion counts ({@link #support}, {@link #truePositive}, ...) are incremented by one for every
 * instance passed to {@link #addResult(Instance, double[])}, regardless of its weight.
 *
 * <p>NOTE(review): this class performs no synchronization; confirm callers confine each
 * instance to a single thread.
 */
public class StreamingClassificationPerformanceEvaluator extends AbstractMOAObject
        implements ClassificationPerformanceEvaluator {

    private static final long serialVersionUID = 1L;

    /** Number of target classes, or {@code -1} while still unknown (before the first instance). */
    protected int numClasses = -1;
    /** Per class: number of instances whose true class is that class. */
    protected long[] support;
    /** Per class: number of instances of that class that were predicted as that class. */
    protected long[] truePositive;
    /** Per class: number of instances predicted as that class whose true class differs. */
    protected long[] falsePositive;
    /** Per class: number of instances that were neither of, nor predicted as, that class. */
    protected long[] trueNegative;
    /** Per class: number of instances of that class that were predicted as another class. */
    protected long[] falseNegative;
    /** Sum of the weights of all positively weighted instances. */
    protected double weightObserved;
    /** Sum of the weights of positively weighted instances that were classified correctly. */
    protected double weightCorrect;
    /** Per class: summed weight of instances whose true class is that class (kappa marginal). */
    protected double[] columnKappa;
    /** Per class: summed weight of instances predicted as that class (kappa marginal). */
    protected double[] rowKappa;
    /** Summed weight of instances whose true class equals the previous instance's true class. */
    private double weightCorrectNoChangeClassifier;
    /** True class of the most recently added instance; used by the temporal kappa baseline. */
    private int lastSeenClass;

    /**
     * Clears every statistic while keeping the current number of classes. When no instance has
     * been seen yet the class count stays unknown and is set on the next {@link #addResult} call.
     */
    @Override // org.apache.samoa.evaluation.PerformanceEvaluator
    public void reset() {
        reset(this.numClasses);
    }

    /**
     * Clears every statistic and sizes the per-class arrays for {@code classCount} classes.
     *
     * @param classCount number of target classes; a negative value (the {@code -1} "unknown"
     *     sentinel) leaves the per-class arrays empty so they are sized lazily by
     *     {@link #addResult(Instance, double[])}
     */
    public void reset(int classCount) {
        this.numClasses = classCount;
        // Allocating with -1 would throw NegativeArraySizeException (e.g. reset() before any instance).
        int size = Math.max(classCount, 0);
        // Freshly allocated Java arrays are already zero-filled, so no clearing loop is needed.
        this.support = new long[size];
        this.truePositive = new long[size];
        this.falsePositive = new long[size];
        this.trueNegative = new long[size];
        this.falseNegative = new long[size];
        this.rowKappa = new double[size];
        this.columnKappa = new double[size];
        this.weightObserved = 0.0d;
        this.weightCorrect = 0.0d;
        this.weightCorrectNoChangeClassifier = 0.0d;
        this.lastSeenClass = 0;
    }

    /**
     * Records one prediction.
     *
     * @param instance the labelled instance; its class value is truncated to an {@code int} index
     * @param classVotes the classifier's votes; the predicted class is the index of the largest vote
     */
    @Override // org.apache.samoa.evaluation.PerformanceEvaluator
    public void addResult(Instance instance, double[] classVotes) {
        if (this.numClasses == -1) {
            reset(instance.numClasses());
        }
        double weight = instance.weight();
        int trueClass = (int) instance.classValue();
        // NOTE(review): assumes classVotes.length <= numClasses; a longer array could yield an
        // out-of-range predicted index — confirm against callers.
        int predictedClass = Utils.maxIndex(classVotes);

        // Weighted statistics (accuracy and kappa) only consider positively weighted instances.
        if (weight > 0.0d) {
            if (this.weightObserved == 0.0d) {
                // First positively weighted instance since the last reset: start from a clean slate.
                reset(instance.numClasses());
            }
            this.weightObserved += weight;
            if (predictedClass == trueClass) {
                this.weightCorrect += weight;
            }
            if (this.rowKappa.length > 0) {
                this.rowKappa[predictedClass] += weight;
            }
            if (this.columnKappa.length > 0) {
                this.columnKappa[trueClass] += weight;
            }
        }

        // Baseline for temporal kappa: a classifier that always predicts the previous true class.
        if (this.lastSeenClass == trueClass) {
            this.weightCorrectNoChangeClassifier += weight;
        }
        this.lastSeenClass = trueClass;

        // Unweighted one-vs-rest confusion counts, updated for every instance.
        this.support[trueClass]++;
        if (predictedClass == trueClass) {
            this.truePositive[trueClass]++;
        } else {
            this.falsePositive[predictedClass]++;
            this.falseNegative[trueClass]++;
        }
        for (int c = 0; c < this.numClasses; c++) {
            if (c != predictedClass && c != trueClass) {
                this.trueNegative[c]++;
            }
        }
    }

    /**
     * Returns the aggregate measurements followed by the per-class support, precision, recall and
     * F1-score measurements.
     */
    @Override // org.apache.samoa.evaluation.PerformanceEvaluator
    public Measurement[] getPerformanceMeasurements() {
        List<Measurement> measurements = new ArrayList<>();
        measurements.add(new Measurement("classified instances", getTotalWeightObserved()));
        measurements.add(new Measurement("classifications correct (percent)",
                getFractionCorrectlyClassified() * 100.0d));
        measurements.add(new Measurement("Kappa Statistic (percent)", getKappaStatistic() * 100.0d));
        measurements.add(new Measurement("Kappa Temporal Statistic (percent)",
                getKappaTemporalStatistic() * 100.0d));
        Collections.addAll(measurements, getSupportMeasurements());
        Collections.addAll(measurements, getPrecisionMeasurements());
        Collections.addAll(measurements, getRecallMeasurements());
        Collections.addAll(measurements, getF1Measurements());
        return measurements.toArray(new Measurement[0]);
    }

    /** Number of classes usable as an array size; 0 while the class count is still unknown. */
    private int knownClassCount() {
        return Math.max(this.numClasses, 0);
    }

    /** One "class N support" measurement per class. */
    private Measurement[] getSupportMeasurements() {
        Measurement[] result = new Measurement[knownClassCount()];
        for (int c = 0; c < result.length; c++) {
            result[c] = new Measurement(String.format("class %s support", c), this.support[c]);
        }
        return result;
    }

    /** One "class N precision" measurement per class, formatted with 10 fraction digits. */
    private Measurement[] getPrecisionMeasurements() {
        Measurement[] result = new Measurement[knownClassCount()];
        for (int c = 0; c < result.length; c++) {
            result[c] = new Measurement(String.format("class %s precision", c), getPrecision(c), 10);
        }
        return result;
    }

    /** One "class N recall" measurement per class, formatted with 10 fraction digits. */
    private Measurement[] getRecallMeasurements() {
        Measurement[] result = new Measurement[knownClassCount()];
        for (int c = 0; c < result.length; c++) {
            result[c] = new Measurement(String.format("class %s recall", c), getRecall(c), 10);
        }
        return result;
    }

    /** One "class N f1-score" measurement per class, formatted with 10 fraction digits. */
    private Measurement[] getF1Measurements() {
        Measurement[] result = new Measurement[knownClassCount()];
        for (int c = 0; c < result.length; c++) {
            result[c] = new Measurement(String.format("class %s f1-score", c), getF1Score(c), 10);
        }
        return result;
    }

    /**
     * Appends a textual description of all performance measurements to {@code sb}.
     *
     * @param sb destination buffer
     * @param indent indentation passed through to {@link Measurement#getMeasurementsDescription}
     */
    @Override // org.apache.samoa.moa.MOAObject
    public void getDescription(StringBuilder sb, int indent) {
        // getPerformanceMeasurements() already includes the per-class support/precision/recall/F1
        // entries; describing those arrays separately as well would print each of them twice.
        Measurement.getMeasurementsDescription(getPerformanceMeasurements(), sb, indent);
    }

    /** Precision of class {@code c}: TP / (TP + FP), or 0 when the class was never predicted. */
    private double getPrecision(int c) {
        return ratio(this.truePositive[c], this.truePositive[c] + this.falsePositive[c]);
    }

    /** Recall of class {@code c}: TP / (TP + FN), or 0 when the class never occurred. */
    private double getRecall(int c) {
        return ratio(this.truePositive[c], this.truePositive[c] + this.falseNegative[c]);
    }

    /** Harmonic mean of precision and recall for class {@code c}, or 0 when both are 0. */
    private double getF1Score(int c) {
        double precision = getPrecision(c);
        double recall = getRecall(c);
        double sum = precision + recall;
        return sum == 0.0d ? 0.0d : (2.0d * precision * recall) / sum;
    }

    /**
     * Floating-point {@code numerator / denominator}. Returns 0 for a zero denominator; plain
     * {@code long} division would both truncate the result and throw ArithmeticException on 0.
     */
    private static double ratio(long numerator, long denominator) {
        return denominator == 0L ? 0.0d : (double) numerator / (double) denominator;
    }

    /** Returns the summed weight of all positively weighted instances seen since the last reset. */
    public double getTotalWeightObserved() {
        return this.weightObserved;
    }

    /** Returns the weighted fraction of correctly classified instances, or 0 if none were observed. */
    public double getFractionCorrectlyClassified() {
        if (this.weightObserved > 0.0d) {
            return this.weightCorrect / this.weightObserved;
        }
        return 0.0d;
    }

    /** Returns {@code 1 - getFractionCorrectlyClassified()}. */
    public double getFractionIncorrectlyClassified() {
        return 1.0d - getFractionCorrectlyClassified();
    }

    /**
     * Returns the kappa statistic (p0 - pc) / (1 - pc): p0 is the weighted accuracy and pc is the
     * chance agreement computed from the row/column marginals. Returns 0 if nothing was observed.
     * NOTE(review): when pc == 1 (e.g. a single class both observed and predicted) the division
     * yields a non-finite value.
     */
    public double getKappaStatistic() {
        if (this.weightObserved <= 0.0d) {
            return 0.0d;
        }
        double observedAgreement = getFractionCorrectlyClassified();
        double chanceAgreement = 0.0d;
        for (int c = 0; c < this.numClasses; c++) {
            chanceAgreement += (this.rowKappa[c] / this.weightObserved)
                    * (this.columnKappa[c] / this.weightObserved);
        }
        return (observedAgreement - chanceAgreement) / (1.0d - chanceAgreement);
    }

    /**
     * Returns the temporal kappa statistic: accuracy relative to a classifier that always predicts
     * the previous instance's true class. Returns 0 if nothing was observed.
     * NOTE(review): yields a non-finite value when that baseline is always right.
     */
    public double getKappaTemporalStatistic() {
        if (this.weightObserved <= 0.0d) {
            return 0.0d;
        }
        double accuracy = this.weightCorrect / this.weightObserved;
        double noChangeAccuracy = this.weightCorrectNoChangeClassifier / this.weightObserved;
        return (accuracy - noChangeAccuracy) / (1.0d - noChangeAccuracy);
    }
}
