package moa.classifiers.meta;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.options.ClassOption;

/* loaded from: input_file:moa/classifiers/meta/LearnNSE.class */
/**
 * Chunk-based weighted ensemble classifier.
 *
 * <p>Incoming instances are buffered until {@code periodOption} instances have been collected.
 * At that point a new base learner is trained on the buffered chunk, every ensemble member is
 * re-evaluated on the chunk, and a voting weight is derived for each member from its history of
 * weighted errors. Optionally the ensemble is pruned (oldest member or highest-error member)
 * once it grows beyond {@code ensembleSizeOption}. The class name suggests this follows the
 * Learn++.NSE scheme; the comments below describe only what this implementation does.
 */
public class LearnNSE extends AbstractClassifier implements MultiClassClassifier {
    /** Index of the "AGE" entry in {@link #pruningStrategyOption}. */
    private static final int PRUNE_BY_AGE = 1;

    /** Index of the "ERROR" entry in {@link #pruningStrategyOption}. */
    private static final int PRUNE_BY_ERROR = 2;

    /**
     * Lower bound applied to a member's averaged error ratio before taking the logarithm, so
     * that a member with a perfect error history gets a large but finite voting weight.
     */
    private static final double MIN_WEIGHTED_BETA = Double.MIN_NORMAL;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "bayes.NaiveBayes");
    public IntOption periodOption = new IntOption("period", 'p', "Size of the environments.", 250, 1, Integer.MAX_VALUE);
    public FloatOption sigmoidSlopeOption = new FloatOption("sigmoidSlope", 'a', "Slope of the sigmoid function controlling the number of previous periods taken into account during weighting.", 0.5d, 0.0d, Float.MAX_VALUE);
    public FloatOption sigmoidCrossingPointOption = new FloatOption("sigmoidCrossingPoint", 'b', "Halfway crossing point of the sigmoid function controlling the number of previous periods taken into account during weighting.", 10.0d, 0.0d, Float.MAX_VALUE);
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 'e', "Ensemble size.", 15, 1, Integer.MAX_VALUE);
    public MultiChoiceOption pruningStrategyOption = new MultiChoiceOption("pruningStrategy", 's', "Classifiers pruning strategy to be used.", new String[]{"NO", "AGE", "ERROR"}, new String[]{"Don't prune classifiers", "Age-based", "Error-based"}, 0);

    /** Ensemble members, oldest first. */
    protected List<Classifier> ensemble;

    /** Voting weight of each member; parallel to {@link #ensemble}, rebuilt after every chunk. */
    protected List<Double> ensembleWeights;

    /** Per member: one {@code error / (1 - error)} value for every chunk it was evaluated on. */
    protected List<ArrayList<Double>> bkts;

    /** Per member: one sigmoid-derived averaging coefficient for every chunk it was evaluated on. */
    protected List<ArrayList<Double>> wkts;

    /** Instances collected for the chunk currently being filled. */
    protected Instances buffer;

    /** Number of instances seen since the last chunk was processed. */
    protected long index;

    /** Cached value of {@link #sigmoidSlopeOption}. */
    protected double slope;

    /** Cached value of {@link #sigmoidCrossingPointOption}. */
    protected double crossingPoint;

    /** Cached chosen index of {@link #pruningStrategyOption} (0 = none, 1 = age, 2 = error). */
    protected int pruning;

    /** Cached value of {@link #ensembleSizeOption}. */
    protected int ensembleSize;

    /** Clears the ensemble and all bookkeeping, and re-reads the cached option values. */
    @Override
    public void resetLearningImpl() {
        this.ensemble = new ArrayList<Classifier>();
        this.ensembleWeights = new ArrayList<Double>();
        this.bkts = new ArrayList<ArrayList<Double>>();
        this.wkts = new ArrayList<ArrayList<Double>>();
        this.index = 0L;
        this.buffer = null;
        this.slope = this.sigmoidSlopeOption.getValue();
        this.crossingPoint = this.sigmoidCrossingPointOption.getValue();
        this.pruning = this.pruningStrategyOption.getChosenIndex();
        this.ensembleSize = this.ensembleSizeOption.getValue();
    }

    /**
     * Buffers the instance and, once a full period of instances has been collected, trains a
     * new member on the chunk, recomputes all voting weights, prunes if configured, and starts
     * a new empty buffer.
     *
     * @param instance the training instance to add to the current chunk
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        this.index++;
        if (this.buffer == null) {
            this.buffer = new Instances(instance.dataset());
        }
        this.buffer.add(instance);
        if (this.index % this.periodOption.getValue() != 0) {
            return;
        }

        this.index = 0L;
        int chunkSize = this.buffer.numInstances();

        Classifier newMember = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        newMember.resetLearning();

        // The penalty distribution must be computed with the ensemble as it was BEFORE the
        // new member is added, so this has to precede ensemble.add(...) below.
        assignChunkWeights(chunkSize);

        // The new member is trained on unit-weight copies; the distribution stored on the
        // buffered instances is only used for evaluating members afterwards.
        for (int i = 0; i < chunkSize; i++) {
            Instance trainingCopy = this.buffer.instance(i).copy();
            trainingCopy.setWeight(1.0d);
            newMember.trainOnInstance(trainingCopy);
        }

        this.ensemble.add(newMember);
        this.bkts.add(new ArrayList<Double>());
        this.wkts.add(new ArrayList<Double>());
        this.ensembleWeights.clear();

        int size = this.ensemble.size();
        double highestError = Double.NEGATIVE_INFINITY;
        int highestErrorIndex = -1;
        for (int k = 0; k < size; k++) {
            double error = weightedChunkError(this.ensemble.get(k), chunkSize);
            if (error > 0.5d) {
                if (k == size - 1) {
                    // The freshly trained member is worse than chance on its own chunk:
                    // swap it for an untrained learner (its recorded error is kept as is).
                    Classifier replacement = (Classifier) getPreparedClassOption(this.baseLearnerOption);
                    replacement.resetLearning();
                    this.ensemble.set(k, replacement);
                } else {
                    // Older members have their error saturated at 0.5.
                    error = 0.5d;
                }
            }
            if (error > highestError) {
                highestError = error;
                highestErrorIndex = k;
            }

            List<Double> errorRatios = this.bkts.get(k);
            errorRatios.add(error / (1.0d - error));

            // Sigmoid of the member's age in chunks (newest member has age 0).
            int age = size - (k + 1);
            double sigmoid = 1.0d / (1.0d + Math.exp((-this.slope) * (age - this.crossingPoint)));

            List<Double> coefficients = this.wkts.get(k);
            double previousCoefficientSum = 0.0d;
            for (Double coefficient : coefficients) {
                previousCoefficientSum += coefficient.doubleValue();
            }
            coefficients.add(sigmoid / (previousCoefficientSum + sigmoid));

            double averagedRatio = 0.0d;
            for (int j = 0; j < coefficients.size(); j++) {
                averagedRatio += coefficients.get(j).doubleValue() * errorRatios.get(j).doubleValue();
            }
            // Without the floor, an all-zero error history yields log(1/0) = +Infinity, and
            // scaling a normalized vote vector by an infinite weight would turn its zero
            // entries into NaN (0 * Infinity).
            this.ensembleWeights.add(Math.log(1.0d / Math.max(averagedRatio, MIN_WEIGHTED_BETA)));
        }

        if (size > this.ensembleSize) {
            if (this.pruning == PRUNE_BY_AGE) {
                removeMember(0);
            } else if (this.pruning == PRUNE_BY_ERROR && highestErrorIndex >= 0) {
                removeMember(highestErrorIndex);
            }
        }
        this.buffer = new Instances(getModelContext());
    }

    /**
     * Stores a normalized penalty distribution as the weights of the buffered instances.
     *
     * <p>With an empty ensemble every instance gets {@code 1/chunkSize}. Otherwise instances the
     * current ensemble classifies correctly are weighted by the ensemble's error rate on the
     * chunk, misclassified ones by 1, and the result is normalized to sum to 1. If the ensemble
     * makes no mistake on the chunk all raw weights are zero; the uniform distribution is used
     * then, because normalizing would otherwise divide 0 by 0 and store NaN weights.
     */
    private void assignChunkWeights(int chunkSize) {
        double uniform = 1.0d / chunkSize;
        if (this.ensemble.isEmpty()) {
            setUniformChunkWeights(chunkSize, uniform);
            return;
        }

        // Cache the ensemble's verdicts so each instance is only voted on once.
        boolean[] correct = new boolean[chunkSize];
        double ensembleError = 0.0d;
        for (int i = 0; i < chunkSize; i++) {
            correct[i] = correctlyClassifies(this.buffer.instance(i));
            if (!correct[i]) {
                ensembleError += uniform;
            }
        }

        double total = 0.0d;
        for (int i = 0; i < chunkSize; i++) {
            double raw = uniform * (correct[i] ? ensembleError : 1.0d);
            this.buffer.instance(i).setWeight(raw);
            total += raw;
        }

        if (total > 0.0d) {
            for (int i = 0; i < chunkSize; i++) {
                Instance example = this.buffer.instance(i);
                example.setWeight(example.weight() / total);
            }
        } else {
            setUniformChunkWeights(chunkSize, uniform);
        }
    }

    /** Sets the weight of each of the first {@code chunkSize} buffered instances to {@code weight}. */
    private void setUniformChunkWeights(int chunkSize, double weight) {
        for (int i = 0; i < chunkSize; i++) {
            this.buffer.instance(i).setWeight(weight);
        }
    }

    /** Returns the summed weight of the buffered instances that {@code member} misclassifies. */
    private double weightedChunkError(Classifier member, int chunkSize) {
        double error = 0.0d;
        for (int i = 0; i < chunkSize; i++) {
            Instance example = this.buffer.instance(i);
            if (!member.correctlyClassifies(example)) {
                error += example.weight();
            }
        }
        return error;
    }

    /** Removes the member at {@code position} together with its weight and both histories. */
    private void removeMember(int position) {
        this.ensemble.remove(position);
        this.ensembleWeights.remove(position);
        this.bkts.remove(position);
        this.wkts.remove(position);
    }

    /** This learner reports itself as not randomizable. */
    @Override
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Combines the members' votes: each member with a strictly positive voting weight
     * contributes its normalized vote vector scaled by that weight. Members whose votes do not
     * sum to a positive value are skipped.
     *
     * @param instance the instance to classify
     * @return the accumulated votes; empty if nothing has been trained or no member qualified
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        DoubleVector combinedVotes = new DoubleVector();
        if (this.trainingWeightSeenByModel > 0.0d) {
            for (int k = 0; k < this.ensemble.size(); k++) {
                double memberWeight = this.ensembleWeights.get(k).doubleValue();
                if (!(memberWeight > 0.0d)) {
                    continue;
                }
                DoubleVector memberVotes = new DoubleVector(this.ensemble.get(k).getVotesForInstance(instance));
                if (!(memberVotes.sumOfValues() > 0.0d)) {
                    continue;
                }
                memberVotes.normalize();
                memberVotes.scaleValues(memberWeight);
                combinedVotes.addValues(memberVotes);
            }
        }
        return combinedVotes.getArrayRef();
    }

    /** Intentionally produces no textual model description. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /**
     * Reports one measurement per member carrying its current voting weight.
     *
     * @return the measurements, or {@code null} if the learner has not been reset yet
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        if (this.ensembleWeights == null) {
            return null;
        }
        Measurement[] measurements = new Measurement[this.ensembleWeights.size()];
        for (int k = 0; k < measurements.length; k++) {
            measurements[k] = new Measurement("member weight " + (k + 1), this.ensembleWeights.get(k).doubleValue());
        }
        return measurements;
    }
}
