package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.ListOption;
import com.github.javacliparser.Option;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;

/* loaded from: input_file:moa/classifiers/meta/WeightedMajorityAlgorithm.class */
/**
 * Weighted majority ensemble for data streams.
 *
 * <p>Every member starts with weight 1. When a member misclassifies a training instance, its
 * weight is multiplied by {@code beta^w} (where {@code w} is the instance weight). This
 * multiplication only happens while the weight is above {@code gamma / ensembleSize}. Below
 * that floor the member is either left alone or, when {@code -p} is set, removed from the
 * ensemble. After each training step the weights are renormalized to sum to 1. Predictions
 * are the weighted sum of each member's normalized vote vector.
 */
public class WeightedMajorityAlgorithm extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;

    /** The base learners to combine. */
    public ListOption learnerListOption = new ListOption("learners", 'l', "The learners to combine.", new ClassOption("learner", ' ', "", Classifier.class, "trees.HoeffdingTree"), new Option[]{new ClassOption("", ' ', "", Classifier.class, "trees.HoeffdingTree -l MC"), new ClassOption("", ' ', "", Classifier.class, "trees.HoeffdingTree -l NB"), new ClassOption("", ' ', "", Classifier.class, "trees.HoeffdingTree -l NBAdaptive"), new ClassOption("", ' ', "", Classifier.class, "bayes.NaiveBayes")}, ',');

    /** Multiplicative penalty applied to a member's weight for each (unit-weight) mistake. */
    public FloatOption betaOption = new FloatOption("beta", 'b', "Factor to punish mistakes by.", 0.9d, 0.0d, 1.0d);

    /** Weight floor, expressed as a fraction of the uniform weight {@code 1 / ensembleSize}. */
    public FloatOption gammaOption = new FloatOption("gamma", 'g', "Minimum fraction of weight per model.", 0.01d, 0.0d, 0.5d);

    /** When set, members that err while at or below the weight floor are removed. */
    public FlagOption pruneOption = new FlagOption("prune", 'p', "Prune poorly performing models from ensemble.");

    /** Current ensemble members; shrinks when pruning is enabled. */
    protected Classifier[] ensemble;

    /** Weight of each member, index-aligned with {@link #ensemble}. */
    protected double[] ensembleWeights;

    @Override
    public String getPurposeString() {
        return "Weighted majority algorithm for data streams.";
    }

    /**
     * Materializes and prepares every configured learner, then the ensemble itself.
     * Returns early (leaving the ensemble partially built) if the task monitor requests abort.
     */
    @Override
    public void prepareForUseImpl(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
        Option[] learnerOptions = this.learnerListOption.getList();
        this.ensemble = new Classifier[learnerOptions.length];
        for (int i = 0; i < learnerOptions.length; i++) {
            taskMonitor.setCurrentActivity("Materializing learner " + (i + 1) + "...", -1.0d);
            this.ensemble[i] = (Classifier) ((ClassOption) learnerOptions[i]).materializeObject(taskMonitor, objectRepository);
            if (taskMonitor.taskShouldAbort()) {
                return;
            }
            taskMonitor.setCurrentActivity("Preparing learner " + (i + 1) + "...", -1.0d);
            this.ensemble[i].prepareForUse(taskMonitor, objectRepository);
            if (taskMonitor.taskShouldAbort()) {
                return;
            }
        }
        super.prepareForUseImpl(taskMonitor, objectRepository);
    }

    /** Resets every current member and restores all weights to 1. */
    @Override
    public void resetLearningImpl() {
        this.ensembleWeights = new double[this.ensemble.length];
        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i].resetLearning();
            this.ensembleWeights[i] = 1.0d;
        }
    }

    /**
     * Test-then-train update: penalizes members that misclassify {@code instance}, optionally
     * prunes members stuck at the weight floor, trains the survivors, and renormalizes weights.
     *
     * @param instance the training instance; its weight scales the penalty exponent
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        // beta^w: a weight-2 instance is punished exactly like two weight-1 instances, and a
        // weight-0 instance is not punished at. Identical to beta * w when w == 1.
        double penalty = Math.pow(this.betaOption.getValue(), instance.weight());
        double totalWeight = 0.0d;
        int i = 0;
        while (i < this.ensemble.length) {
            if (!this.ensemble[i].correctlyClassifies(instance)) {
                // Re-evaluated each iteration because pruning shrinks the ensemble.
                double weightFloor = this.gammaOption.getValue() / this.ensembleWeights.length;
                if (this.ensembleWeights[i] > weightFloor) {
                    this.ensembleWeights[i] *= penalty;
                } else if (this.pruneOption.isSet() && this.ensemble.length > 1) {
                    // Never prune the last member: an empty ensemble can never predict again.
                    discardModel(i);
                    continue; // slot i now holds the next member; do not advance
                }
            }
            totalWeight += this.ensembleWeights[i];
            this.ensemble[i].trainOnInstance(instance);
            i++;
        }
        // Guard against 0/0: with beta == 0 every member can be zeroed in one step.
        if (totalWeight > 0.0d) {
            for (int j = 0; j < this.ensembleWeights.length; j++) {
                this.ensembleWeights[j] /= totalWeight;
            }
        }
    }

    /**
     * Combines the members' votes. Each member with positive weight and a non-empty vote vector
     * contributes its normalized votes scaled by its weight.
     *
     * @return the combined vote vector; empty before any training weight has been seen
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        DoubleVector combinedVotes = new DoubleVector();
        if (this.trainingWeightSeenByModel > 0.0d) {
            for (int i = 0; i < this.ensemble.length; i++) {
                if (this.ensembleWeights[i] > 0.0d) {
                    DoubleVector memberVotes = new DoubleVector(this.ensemble[i].getVotesForInstance(instance));
                    if (memberVotes.sumOfValues() > 0.0d) {
                        memberVotes.normalize();
                        memberVotes.scaleValues(this.ensembleWeights[i]);
                        combinedVotes.addValues(memberVotes);
                    }
                }
            }
        }
        return combinedVotes.getArrayRef();
    }

    /** No textual model description is produced for this ensemble. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /**
     * Reports one "member weight k" measurement per current member.
     *
     * @return the measurements, or an empty array before weights are initialized
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        if (this.ensembleWeights == null) {
            return new Measurement[0];
        }
        Measurement[] measurements = new Measurement[this.ensembleWeights.length];
        for (int i = 0; i < this.ensembleWeights.length; i++) {
            measurements[i] = new Measurement("member weight " + (i + 1), this.ensembleWeights[i]);
        }
        return measurements;
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    /** Returns a shallow copy of the current member array. */
    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble.clone();
    }

    /**
     * Removes the member at {@code index} together with its weight, preserving the order of
     * the remaining members. Remaining weights are not renormalized.
     *
     * @param index position of the member to remove
     * @throws IndexOutOfBoundsException if {@code index} is not a valid member position
     */
    public void discardModel(int index) {
        if (index < 0 || index >= this.ensemble.length) {
            throw new IndexOutOfBoundsException(
                    "Model index " + index + " out of range for ensemble of size " + this.ensemble.length);
        }
        int remaining = this.ensemble.length - 1;
        int tail = remaining - index; // number of members after the removed one
        Classifier[] keptModels = new Classifier[remaining];
        double[] keptWeights = new double[remaining];
        System.arraycopy(this.ensemble, 0, keptModels, 0, index);
        System.arraycopy(this.ensemble, index + 1, keptModels, index, tail);
        System.arraycopy(this.ensembleWeights, 0, keptWeights, 0, index);
        System.arraycopy(this.ensembleWeights, index + 1, keptWeights, index, tail);
        this.ensemble = keptModels;
        this.ensembleWeights = keptWeights;
    }

    /**
     * Removes the lowest-weight member and reports how many bytes it occupied.
     *
     * @return the removed member's measured byte size, or 0 if the ensemble is already empty
     */
    protected int removePoorestModelBytes() {
        if (this.ensemble == null || this.ensemble.length == 0) {
            return 0;
        }
        int poorest = Utils.minIndex(this.ensembleWeights);
        int freedBytes = this.ensemble[poorest].measureByteSize();
        discardModel(poorest);
        return freedBytes;
    }
}
