package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/* loaded from: input_file:moa/classifiers/meta/TemporallyAugmentedClassifier.class */
public class TemporallyAugmentedClassifier extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;
    protected Classifier baseLearner;
    protected double[] oldLabels;
    protected Instances header;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
    public IntOption numOldLabelsOption = new IntOption("numOldLabels", 'n', "The number of old labels to add to each example.", 1, 0, Integer.MAX_VALUE);
    public FlagOption labelDelayOption = new FlagOption("labelDelay", 'd', "Labels arrive with Delay. Use predictions instead of true Labels.");

    @Override // moa.classifiers.AbstractClassifier, moa.options.AbstractOptionHandler, moa.options.OptionHandler
    public String getPurposeString() {
        return "Add some old labels to every instance";
    }

    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        this.oldLabels = new double[this.numOldLabelsOption.getValue()];
        this.header = null;
        this.baseLearner.resetLearning();
    }

    @Override // moa.classifiers.AbstractClassifier
    public void trainOnInstanceImpl(Instance instance) {
        this.baseLearner.trainOnInstance(extendWithOldLabels(instance));
        if (this.labelDelayOption.isSet()) {
            return;
        }
        addOldLabel(instance.classValue());
    }

    public void addOldLabel(double d) {
        int length = this.oldLabels.length;
        if (length > 0) {
            for (int i = 1; i < length; i++) {
                this.oldLabels[i - 1] = this.oldLabels[i];
            }
            this.oldLabels[length - 1] = d;
        }
    }

    public void initHeader(Instances instances) {
        int value = this.numOldLabelsOption.getValue();
        Attribute classAttribute = instances.classAttribute();
        ArrayList arrayList = new ArrayList();
        int numValues = classAttribute.numValues();
        for (int i = 0; i < numValues; i++) {
            arrayList.add(classAttribute.value(i));
        }
        ArrayList arrayList2 = new ArrayList(value + instances.numAttributes());
        for (int i2 = 0; i2 < value; i2++) {
            arrayList2.add(new Attribute(classAttribute.name() + "_" + i2, arrayList));
        }
        for (int i3 = 0; i3 < instances.numAttributes(); i3++) {
            Attribute attribute = instances.attribute(i3);
            Attribute attribute2 = attribute.isNominal() ? new Attribute(attribute.name(), attribute.getAttributeValues()) : null;
            if (attribute.isNumeric()) {
                attribute2 = new Attribute(attribute.name());
            }
            if (attribute2 != null) {
                arrayList2.add(attribute2);
            }
        }
        this.header = new Instances("extended_" + instances.getRelationName(), arrayList2, 0);
        this.header.setClassIndex(value + instances.classIndex());
    }

    public Instance extendWithOldLabels(Instance instance) {
        if (this.header == null) {
            initHeader(instance.dataset());
            this.baseLearner.setModelContext(new InstancesHeader(this.header));
        }
        int length = this.oldLabels.length;
        if (length == 0) {
            return instance;
        }
        double[] doubleArray = instance.toDoubleArray();
        double[] copyOfRange = Arrays.copyOfRange(this.oldLabels, 0, length + doubleArray.length);
        System.arraycopy(doubleArray, 0, copyOfRange, length, doubleArray.length);
        DenseInstance denseInstance = new DenseInstance(instance.weight(), copyOfRange);
        denseInstance.setDataset(this.header);
        return denseInstance;
    }

    @Override // moa.classifiers.AbstractClassifier, moa.classifiers.Classifier
    public double[] getVotesForInstance(Instance instance) {
        double[] votesForInstance = this.baseLearner.getVotesForInstance(extendWithOldLabels(instance));
        if (this.labelDelayOption.isSet()) {
            addOldLabel(Utils.maxIndex(votesForInstance));
        }
        return votesForInstance;
    }

    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return false;
    }

    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        LinkedList linkedList = new LinkedList();
        Measurement[] modelMeasurements = ((AbstractClassifier) this.baseLearner).getModelMeasurements();
        if (modelMeasurements != null) {
            for (Measurement measurement : modelMeasurements) {
                linkedList.add(measurement);
            }
        }
        return (Measurement[]) linkedList.toArray(new Measurement[linkedList.size()]);
    }

    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
    }

    @Override // moa.AbstractMOAObject
    public String toString() {
        return "TemporallyAugmentedClassifier using " + this.numOldLabelsOption.getValue() + " labels\n" + this.baseLearner;
    }
}
