package moa.classifiers.multitarget;

import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.MultiLabelPrediction;
import com.yahoo.labs.samoa.instances.Prediction;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.Classifier;
import moa.classifiers.rules.AMRulesRegressor;
import moa.core.DoubleVector;
import moa.core.FastVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.options.ClassOption;
import moa.streams.InstanceStream;

/* loaded from: input_file:moa/classifiers/multitarget/BasicMultiLabelLearner.class */
/**
 * Multi-target learner that trains one independent copy of a base classifier
 * per output attribute.
 *
 * <p>Each incoming {@link MultiLabelInstance} is projected, for every output
 * attribute, onto a single-target {@link Instance} made of all input
 * attributes plus that one output attribute (used as the class). The
 * per-output models are created lazily on the first training instance.
 */
public class BasicMultiLabelLearner extends AbstractMultiLabelLearner {
    private static final long serialVersionUID = 1;

    /** Option selecting the base learner that is copied once per output attribute. */
    public ClassOption baseLearnerOption;

    /** One model per output attribute; {@code null} until the first training instance. */
    protected Classifier[] ensemble;

    /** Single-target header per output attribute, built lazily by {@link #transformInstance}. */
    protected InstancesHeader[] header;

    public IntOption randomSeedOption = new IntOption("randomSeedOption", 'r', "randomSeedOption", 1, Integer.MIN_VALUE, Integer.MAX_VALUE);

    /** {@code true} once the ensemble has been built from a training instance. */
    protected boolean hasStarted = false;

    public BasicMultiLabelLearner() {
        // NOTE(review): as written this is a self-assignment with no effect. It may
        // have been meant to assign a same-named field of a superclass - confirm
        // against the class hierarchy before removing or changing it.
        this.randomSeedOption = this.randomSeedOption;
        init();
    }

    /** Creates the base-learner option; the default base learner is {@link AMRulesRegressor}. */
    protected void init() {
        this.baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, AMRulesRegressor.class.getName());
    }

    /**
     * Resets the learner so that the next training instance rebuilds the ensemble.
     *
     * <p>The cached headers are dropped as well: {@link #transformInstance} only
     * hands a header to an ensemble member at the moment it creates that header,
     * so keeping the old headers would leave the rebuilt members without a model
     * context (and with a wrongly sized array if the number of outputs changes).
     */
    @Override
    public void resetLearningImpl() {
        this.hasStarted = false;
        this.header = null;
        if (this.ensemble != null) {
            for (Classifier member : this.ensemble) {
                member.resetLearning();
            }
        }
    }

    /**
     * Trains every per-output model on its single-target view of the instance,
     * building the ensemble first if this is the first instance since a reset.
     *
     * @param multiLabelInstance the multi-target training instance
     */
    @Override
    public void trainOnInstanceImpl(MultiLabelInstance multiLabelInstance) {
        if (!this.hasStarted) {
            this.ensemble = new Classifier[multiLabelInstance.numberOutputTargets()];
            Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
            if (baseLearner.isRandomizable()) {
                baseLearner.setRandomSeed(this.randomSeed);
            }
            baseLearner.resetLearning();
            for (int i = 0; i < this.ensemble.length; i++) {
                this.ensemble[i] = baseLearner.copy();
            }
            this.hasStarted = true;
        }
        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i].trainOnInstance(transformInstance(multiLabelInstance, i));
        }
    }

    /**
     * Builds the single-target instance for one output attribute.
     *
     * <p>On first use for a given output, this also creates the matching header
     * (all input attributes followed by the output attribute as class) and sets
     * it as the model context of the corresponding ensemble member.
     *
     * @param multiLabelInstance the multi-target instance to project
     * @param i index of the output attribute (and of the ensemble member)
     * @return a dense instance holding the input values and the i-th output as class value
     */
    protected Instance transformInstance(MultiLabelInstance multiLabelInstance, int i) {
        if (this.header == null) {
            this.header = new InstancesHeader[this.ensemble.length];
        }
        if (this.header[i] == null) {
            FastVector attributes = new FastVector();
            for (int a = 0; a < multiLabelInstance.numInputAttributes(); a++) {
                attributes.addElement(multiLabelInstance.inputAttribute(a));
            }
            attributes.addElement(multiLabelInstance.outputAttribute(i));
            this.header[i] = new InstancesHeader(new Instances(getCLICreationString(InstanceStream.class), attributes, 0));
            this.header[i].setClassIndex(attributes.size() - 1);
            this.ensemble[i].setModelContext(this.header[i]);
        }
        int numInputs = this.header[i].numInputAttributes();
        // One extra slot at the end for the class value, filled by setClassValue below.
        double[] values = new double[numInputs + 1];
        for (int a = 0; a < numInputs; a++) {
            values[a] = multiLabelInstance.valueInputAttribute(a);
        }
        DenseInstance singleTarget = new DenseInstance(1.0d, values);
        singleTarget.setDataset(this.header[i]);
        singleTarget.setClassValue(multiLabelInstance.valueOutputAttribute(i));
        return singleTarget;
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Returns the base learner's measurements summed over all ensemble members
     * (named {@code "Sum <name>"}), or the base learner's own measurements when
     * no ensemble has been built yet.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] baseMeasurements = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).getModelMeasurements();
        if (this.ensemble == null) {
            return baseMeasurements.clone();
        }
        int count = baseMeasurements.length;
        double[] sums = new double[count];
        // Query each member once instead of once per measurement index.
        for (Classifier member : this.ensemble) {
            Measurement[] memberMeasurements = member.getModelMeasurements();
            for (int m = 0; m < count; m++) {
                sums[m] += memberMeasurements[m].getValue();
            }
        }
        Measurement[] result = new Measurement[count];
        for (int m = 0; m < count; m++) {
            result[m] = new Measurement("Sum " + baseMeasurements[m].getName(), sums[m]);
        }
        return result;
    }

    /**
     * Appends the description of every per-output model to {@code sb}.
     * Nothing is appended when no ensemble exists yet, when it is empty, or when
     * its members are not {@link AbstractClassifier}s.
     *
     * @param sb buffer receiving the description
     * @param i indentation level of the caller; models are written one level deeper
     */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
        // The ensemble is only created on the first training instance.
        if (this.ensemble == null || this.ensemble.length <= 0 || !(this.ensemble[0] instanceof AbstractClassifier)) {
            return;
        }
        for (int m = 0; m < this.ensemble.length; m++) {
            StringUtils.appendIndented(sb, i + 1, "\nModel output attribute #" + m);
            ((AbstractClassifier) this.ensemble[m]).getModelDescription(sb, i + 1);
        }
    }

    /**
     * Predicts every output attribute with its dedicated model.
     *
     * @param multiLabelInstance the instance to predict
     * @return a prediction with one vote per output attribute, or {@code null}
     *         if the learner has not been trained since construction or reset
     */
    @Override
    public Prediction getPredictionForInstance(MultiLabelInstance multiLabelInstance) {
        if (!this.hasStarted) {
            return null;
        }
        MultiLabelPrediction prediction = new MultiLabelPrediction(this.ensemble.length);
        for (int i = 0; i < this.ensemble.length; i++) {
            double[] votes = this.ensemble[i].getVotesForInstance(transformInstance(multiLabelInstance, i));
            // A member that returns no votes leaves this output's vote unset
            // rather than failing the whole prediction.
            if (votes != null && votes.length > 0) {
                prediction.setVote(i, 0, votes[0]);
            }
        }
        return prediction;
    }
}
