package moa.classifiers.meta;

import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.trees.HoeffdingTree;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;

/* loaded from: input_file:moa/classifiers/meta/AccuracyUpdatedEnsemble.class */
/**
 * Chunk-based ensemble of stream classifiers in the style of the Accuracy Updated Ensemble.
 *
 * <p>Training instances are buffered until {@code chunkSize} of them have arrived. Each full
 * chunk is then used to:
 * <ol>
 *   <li>re-weight every existing member with {@code 1 / (MSE_r + MSE_i)}, where
 *       {@code MSE_r = sum_c p(c) * (1 - p(c))^2} is computed from the chunk's class frequencies
 *       and {@code MSE_i} is the member's mean squared error on the chunk;</li>
 *   <li>offer a fresh candidate learner with weight {@code 1 / MSE_r}: it is appended while fewer
 *       than {@code memberCount} members exist, otherwise it replaces the lowest-weighted member
 *       if its weight is higher;</li>
 *   <li>train every member (including a newly added candidate copy) on the chunk; and</li>
 *   <li>give each member that is a {@link HoeffdingTree} a memory budget of
 *       {@code maxByteSize / (members + 1)}.</li>
 * </ol>
 */
public class AccuracyUpdatedEnsemble extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1L;

    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree -e 2000000 -g 100 -c 0.01");
    public IntOption memberCountOption = new IntOption("memberCount", 'n', "The maximum number of classifiers in an ensemble.", 10, 1, Integer.MAX_VALUE);
    public IntOption chunkSizeOption = new IntOption("chunkSize", 'c', "The chunk size used for classifier creation and evaluation.", 500, 1, Integer.MAX_VALUE);
    public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm', "Maximum memory consumed by ensemble.", 33554432, 0, Integer.MAX_VALUE);

    /**
     * Per-member bookkeeping: {@code weights[i][0]} is the voting weight and {@code weights[i][1]}
     * is the index (stored as a double) of the corresponding entry in {@link #learners}.
     */
    protected double[][] weights;

    /** Class counts of the instances buffered in {@link #currentChunk}; {@code null} between chunks. */
    protected long[] classDistributions;

    /** Current ensemble members. */
    protected Classifier[] learners;

    /** Number of training instances seen since the last reset. */
    protected int processedInstances;

    /** Untrained learner whose copy is offered to the ensemble when the current chunk completes. */
    protected Classifier candidate;

    /** Instances buffered for the chunk being collected; {@code null} between chunks. */
    protected Instances currentChunk;

    @Override
    public void prepareForUseImpl(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
        this.candidate = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidate.resetLearning();
        super.prepareForUseImpl(taskMonitor, objectRepository);
    }

    /** Discards all members, their weights, and any partially collected chunk. */
    @Override
    public void resetLearningImpl() {
        this.currentChunk = null;
        this.classDistributions = null;
        this.processedInstances = 0;
        this.learners = new Classifier[0];
        // Clear weights together with learners so measurements never report discarded members.
        this.weights = new double[0][2];
        this.candidate = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidate.resetLearning();
    }

    /**
     * Buffers {@code instance} into the current chunk and processes the chunk once it is full.
     *
     * @param instance labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        initVariables();
        this.classDistributions[(int) instance.classValue()]++;
        this.currentChunk.add(instance);
        this.processedInstances++;
        // Trigger on the buffered size rather than processedInstances % chunkSize, so the int
        // counter overflowing on very long streams cannot misalign chunk boundaries.
        if (this.currentChunk.numInstances() >= this.chunkSizeOption.getValue()) {
            processChunk();
        }
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Combines the members' votes. Each member with a positive weight contributes its normalized
     * vote vector scaled by {@code weight / (members + 1)}. No member is consulted before any
     * training weight has been seen.
     *
     * @param instance instance to classify
     * @return combined (unnormalized) class votes
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        DoubleVector combinedVote = new DoubleVector();
        if (this.trainingWeightSeenByModel > 0.0d) {
            for (int i = 0; i < this.learners.length; i++) {
                if (this.weights[i][0] > 0.0d) {
                    DoubleVector memberVote = new DoubleVector(this.learners[(int) this.weights[i][1]].getVotesForInstance(instance));
                    if (memberVote.sumOfValues() > 0.0d) {
                        memberVote.normalize();
                        memberVote.scaleValues(this.weights[i][0] / ((1.0d * this.learners.length) + 1.0d));
                        combinedVote.addValues(memberVote);
                    }
                }
            }
        }
        return combinedVote.getArrayRef();
    }

    /** No textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /** @return a shallow copy of the member array */
    @Override
    public Classifier[] getSubClassifiers() {
        return this.learners.clone();
    }

    /**
     * Re-weights members on the completed chunk, adds or swaps in the candidate, trains all
     * members on the chunk, then starts a new chunk with a fresh candidate.
     */
    protected void processChunk() {
        double mseR = computeMseR();
        // NOTE(review): if every instance in the chunk has the same class, mseR is 0 and
        // 1 / Double.MIN_VALUE overflows to Infinity, so the candidate weight becomes Infinity.
        double candidateWeight = 1.0d / (mseR + Double.MIN_VALUE);

        // Re-weight existing members by their error on the just-completed chunk.
        for (int i = 0; i < this.learners.length; i++) {
            Classifier member = this.learners[(int) this.weights[i][1]];
            this.weights[i][0] = 1.0d / (mseR + computeMse(member, this.currentChunk) + Double.MIN_VALUE);
        }

        if (this.learners.length < this.memberCountOption.getValue()) {
            addToStored(this.candidate, candidateWeight);
        } else {
            int poorest = getPoorestClassifierIndex();
            if (this.weights[poorest][0] < candidateWeight) {
                this.weights[poorest][0] = candidateWeight;
                this.learners[(int) this.weights[poorest][1]] = this.candidate.copy();
            }
        }

        // Every member, including a newly added copy of the candidate, learns from this chunk.
        for (int i = 0; i < this.learners.length; i++) {
            trainOnChunk(this.learners[(int) this.weights[i][1]]);
        }

        this.classDistributions = null;
        this.currentChunk = null;
        this.candidate = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidate.resetLearning();
        enforceMemoryLimit();
    }

    /**
     * Splits {@code maxByteSize} into {@code members + 1} equal shares and applies one share to
     * each member that is a {@link HoeffdingTree}. Other learner types have no such limit and are
     * left untouched.
     */
    protected void enforceMemoryLimit() {
        // Floating-point division so Math.round actually rounds instead of seeing a floored value.
        double memoryLimit = this.maxByteSizeOption.getValue() / (double) (this.learners.length + 1);
        for (int i = 0; i < this.learners.length; i++) {
            Classifier member = this.learners[(int) this.weights[i][1]];
            // learnerOption accepts any Classifier, so an unconditional cast would throw.
            if (member instanceof HoeffdingTree) {
                HoeffdingTree tree = (HoeffdingTree) member;
                tree.maxByteSizeOption.setValue((int) Math.round(memoryLimit));
                tree.enforceTrackerLimit();
            }
        }
    }

    /**
     * Computes {@code MSE_r = sum_c p(c) * (1 - p(c))^2} from the class counts of the current chunk.
     *
     * @return the chunk's class-prior error term
     */
    protected double computeMseR() {
        double chunkSize = this.chunkSizeOption.getValue();
        double mseR = 0.0d;
        for (long classCount : this.classDistributions) {
            // Must be floating-point division: long / int would truncate every p to 0 or 1.
            double p = classCount / chunkSize;
            mseR += p * (1.0d - p) * (1.0d - p);
        }
        return mseR;
    }

    /**
     * Mean squared error of {@code classifier} on {@code instances}. The error per instance is
     * {@code (1 - f)^2}, where {@code f} is the classifier's normalized vote for the true class.
     * When called from {@link #processChunk()}, {@code instances} holds exactly one chunk.
     *
     * @param classifier member to evaluate
     * @param instances  instances to evaluate on
     * @return mean per-instance squared error, or 0 when {@code instances} is empty
     */
    protected double computeMse(Classifier classifier, Instances instances) {
        int size = instances.numInstances();
        if (size == 0) {
            return 0.0d;
        }
        double sum = 0.0d;
        for (int i = 0; i < size; i++) {
            sum += squaredErrorOnTrueClass(classifier, instances.instance(i));
        }
        return sum / size;
    }

    /** Squared error of a single prediction; 1.0 (a total miss) when no usable prediction exists. */
    private static double squaredErrorOnTrueClass(Classifier classifier, Instance instance) {
        try {
            // Score once; the vote array is reused for both the sum and the true-class lookup.
            double[] votes = classifier.getVotesForInstance(instance);
            double voteSum = 0.0d;
            for (double vote : votes) {
                voteSum += vote;
            }
            int trueClass = (int) instance.classValue();
            // No positive vote mass (including NaN), or a vote array too short to cover the
            // true class, counts as a total miss.
            if (!(voteSum > 0.0d) || trueClass < 0 || trueClass >= votes.length) {
                return 1.0d;
            }
            double p = votes[trueClass] / voteSum;
            return (1.0d - p) * (1.0d - p);
        } catch (RuntimeException e) {
            // Deliberate best effort: a member that fails to score an instance is counted as
            // maximally wrong instead of aborting the whole chunk update.
            return 1.0d;
        }
    }

    /**
     * Reports one "Member weight k" measurement per allowed member. The value is -1 for slots that
     * hold no member.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        int memberCount = this.memberCountOption.getValue();
        Measurement[] measurements = new Measurement[memberCount];
        for (int i = 0; i < memberCount; i++) {
            double weight = (this.weights != null && i < this.weights.length) ? this.weights[i][0] : -1.0d;
            measurements[i] = new Measurement("Member weight " + (i + 1), weight);
        }
        return measurements;
    }

    /**
     * Appends a copy of {@code classifier} to the ensemble with the given weight.
     *
     * @param classifier learner to copy into the ensemble
     * @param weight     initial voting weight of the new member
     * @return the stored copy
     */
    protected Classifier addToStored(Classifier classifier, double weight) {
        int oldSize = this.learners.length;
        Classifier[] grownLearners = new Classifier[oldSize + 1];
        double[][] grownWeights = new double[oldSize + 1][2];
        System.arraycopy(this.learners, 0, grownLearners, 0, oldSize);
        for (int i = 0; i < oldSize; i++) {
            grownWeights[i][0] = this.weights[i][0];
            grownWeights[i][1] = this.weights[i][1];
        }
        Classifier added = classifier.copy();
        grownLearners[oldSize] = added;
        grownWeights[oldSize][0] = weight;
        grownWeights[oldSize][1] = oldSize;
        this.learners = grownLearners;
        this.weights = grownWeights;
        return added;
    }

    /** @return row index in {@link #weights} of the lowest-weighted member (first one on ties) */
    private int getPoorestClassifierIndex() {
        int poorest = 0;
        for (int i = 1; i < this.weights.length; i++) {
            if (this.weights[i][0] < this.weights[poorest][0]) {
                poorest = i;
            }
        }
        return poorest;
    }

    /** Lazily creates the chunk buffer and class-count array for a new chunk. */
    private void initVariables() {
        if (this.currentChunk == null) {
            this.currentChunk = new Instances(getModelContext());
        }
        if (this.classDistributions == null) {
            // New Java arrays are zero-initialised, so no explicit clearing is needed.
            this.classDistributions = new long[getModelContext().classAttribute().numValues()];
        }
    }

    /** Trains {@code classifier} on every instance buffered in the current chunk. */
    private void trainOnChunk(Classifier classifier) {
        for (int i = 0; i < this.currentChunk.numInstances(); i++) {
            classifier.trainOnInstance(this.currentChunk.instance(i));
        }
    }
}
