package org.apache.samoa.learners.classifiers.ensemble;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.IntOption;
import com.google.common.collect.ImmutableSet;
import java.util.Iterator;
import java.util.Set;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.ClassificationLearner;
import org.apache.samoa.learners.Learner;
import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Ensemble classifier that wires a bag of base learners into a SAMOA topology.
 *
 * <p>{@link #init} builds the following layout on the supplied {@link TopologyBuilder}:
 * <ul>
 *   <li>one {@link BaggingDistributorProcessor} (the learner's input processor) that fans incoming
 *       data out to every ensemble member, each over its own shuffle-grouped stream;</li>
 *   <li>{@code ensembleSize} base learners created from {@link #baseLearnerOption};</li>
 *   <li>one {@link PredictionCombinerProcessor} that receives every member's result streams
 *       (key-grouped) and publishes the combined output on {@link #resultStream}.</li>
 * </ul>
 * How instances are resampled/weighted per member is decided inside
 * {@code BaggingDistributorProcessor}, which is not part of this class.
 */
public class Bagging implements ClassificationLearner, Configurable {
    private static final long serialVersionUID = -2971850264864952099L;
    private static final Logger logger = LoggerFactory.getLogger(Bagging.class);

    /** Base learner class (plus its CLI options) instantiated once per ensemble member. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());

    /** Number of members in the bag (default 10, bounded to [1, Integer.MAX_VALUE]). */
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);

    /** Entry processor; forwards data to each member. Null until {@link #init} runs. */
    private BaggingDistributorProcessor distributorP;
    /** One stream per member, from the distributor to that member's input processor. */
    private Stream[] ensembleStreams;
    /** Output stream of the prediction combiner. Null until {@link #init} runs. */
    protected Stream resultStream;
    /** Dataset header passed to every ensemble member on initialization. */
    private Instances dataset;
    /** The ensemble members, in the same order as {@link #ensembleStreams}. */
    protected Learner[] ensemble;
    /** Builder the topology is assembled on; set by {@link #init}. */
    private TopologyBuilder builder;

    /**
     * Creates and connects all processors and streams of the ensemble.
     *
     * <p>The order of builder calls is significant for topology construction and is kept stable:
     * distributor, members, combiner (and its result stream), then the distributor-to-member
     * streams.
     *
     * @throws IllegalArgumentException if any ensemble member cannot be created or initialized
     *     (typically a bad {@code baseLearner} CLI specification); the original failure is the cause
     */
    protected void setLayout() {
        int ensembleSize = this.ensembleSizeOption.getValue();

        this.distributorP = new BaggingDistributorProcessor();
        this.distributorP.setEnsembleSize(ensembleSize);
        this.builder.addProcessor(this.distributorP, 1);

        this.ensemble = createEnsembleMembers(ensembleSize);

        PredictionCombinerProcessor combiner = new PredictionCombinerProcessor();
        combiner.setEnsembleSize(ensembleSize);
        this.builder.addProcessor(combiner, 1);
        this.resultStream = this.builder.createStream(combiner);
        combiner.setOutputStream(this.resultStream);
        // Every member's outputs feed the combiner, grouped by key.
        for (Learner member : this.ensemble) {
            for (Stream memberResult : member.getResultStreams()) {
                this.builder.connectInputKeyStream(memberResult, combiner);
            }
        }

        // Dedicated stream per member so the distributor can address each one individually.
        this.ensembleStreams = new Stream[ensembleSize];
        for (int i = 0; i < ensembleSize; i++) {
            this.ensembleStreams[i] = this.builder.createStream(this.distributorP);
            this.builder.connectInputShuffleStream(this.ensembleStreams[i], this.ensemble[i].getInputProcessor());
        }
        this.distributorP.setOutputStreams(this.ensembleStreams);
    }

    /**
     * Instantiates {@code ensembleSize} base learners from {@link #baseLearnerOption} and
     * initializes each one on the current builder and dataset.
     *
     * @param ensembleSize number of members to create
     * @return the initialized members, indexed 0..ensembleSize-1
     * @throws IllegalArgumentException if creating or initializing any member fails
     */
    private Learner[] createEnsembleMembers(int ensembleSize) {
        Learner[] members = new Learner[ensembleSize];
        for (int i = 0; i < ensembleSize; i++) {
            try {
                members[i] = (Learner) ClassOption.createObject(
                        this.baseLearnerOption.getValueAsCLIString(),
                        this.baseLearnerOption.getRequiredType());
                // Each member is initialized with a parallelism hint of 1.
                members[i].init(this.builder, this.dataset, 1);
            } catch (Exception e) {
                // Log through SLF4J with the cause attached instead of printStackTrace().
                logger.error("Unable to create members of the ensemble. Please check your CLI parameters", e);
                throw new IllegalArgumentException("Unable to create ensemble member " + i, e);
            }
        }
        return members;
    }

    /**
     * Records the builder and dataset, then builds the ensemble topology via {@link #setLayout()}.
     *
     * @param topologyBuilder builder the processors and streams are added to
     * @param instances dataset header handed to each ensemble member
     * @param parallelism parallelism hint; not used by this learner (every processor it adds, and
     *     every member it initializes, is given a parallelism of 1)
     */
    @Override // org.apache.samoa.learners.Learner
    public void init(TopologyBuilder topologyBuilder, Instances instances, int parallelism) {
        this.builder = topologyBuilder;
        this.dataset = instances;
        setLayout();
    }

    /**
     * Returns the processor that receives this learner's input.
     *
     * @return the bagging distributor, or {@code null} if {@link #init} has not been called
     */
    @Override // org.apache.samoa.learners.Learner
    public Processor getInputProcessor() {
        return this.distributorP;
    }

    /**
     * Returns the single stream carrying the combined ensemble output.
     *
     * <p>Must be called after {@link #init}; before that the result stream is {@code null}, which
     * {@code ImmutableSet.of} rejects.
     *
     * @return an immutable one-element set containing the combiner's output stream
     */
    @Override // org.apache.samoa.learners.Learner
    public Set<Stream> getResultStreams() {
        return ImmutableSet.of(this.resultStream);
    }
}
