package org.apache.samoa.learners.classifiers.ensemble;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.SerializableInstance;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.moa.core.Utils;
import org.apache.samoa.topology.Stream;

/**
 * Prediction combiner for the distributed online boosting ensemble.
 *
 * <p>Besides combining the members' class votes through the parent
 * {@link PredictionCombinerProcessor}, this processor tracks, for every ensemble member, the
 * boosting weight (lambda) of instances the member classified correctly ({@link #scms}) and
 * wrongly ({@link #swms}). Once all votes for an instance are in (or the last event arrives), it
 * emits the combined prediction and sends sequentially re-weighted copies of the instance to the
 * members over the training stream, following the Oza &amp; Russell online boosting update with
 * lambda used directly as the training weight (no Poisson sampling).
 *
 * <p>NOTE(review): this class does not override {@code newProcessor}; if the parent's
 * implementation instantiates a plain {@code PredictionCombinerProcessor}, replicated copies would
 * lose the boosting logic and the training stream — confirm against the parent class.
 */
public class BoostingPredictionCombinerProcessor extends PredictionCombinerProcessor {
    private static final long serialVersionUID = -1606045723451191232L;

    /** Per-member sum of the boosting weights (lambda) of correctly classified instances. */
    protected double[] scms;

    /** Per-member sum of the boosting weights (lambda) of misclassified instances. */
    protected double[] swms;

    /** Random source re-created by {@link #reset()}; not read by the code in this class. */
    protected Random random;

    /**
     * Total weight of all instances used for training so far. Scales the lambda update in
     * {@link #computeBoosting}; held as {@code double} because instance weights are real-valued.
     */
    protected double trainingWeightSeenByModel;

    /**
     * Pending per-instance predictions, keyed by instance index. Each vector maps a classifier
     * index to the class index that member voted for (argmax of its class votes).
     */
    protected Map<Integer, DoubleVector> mapPredictions;

    private Stream trainingStream;

    /**
     * Records one member's vote and, once the instance is complete, emits the combined prediction
     * and triggers boosting-based training of the ensemble members.
     *
     * @param contentEvent the incoming event; must be a {@link ResultContentEvent}
     * @return {@code true} if a combined prediction was emitted for the instance, {@code false}
     *     if the processor is still waiting for more member votes
     */
    @Override // org.apache.samoa.learners.classifiers.ensemble.PredictionCombinerProcessor, org.apache.samoa.core.Processor
    public boolean process(ContentEvent contentEvent) {
        ResultContentEvent inEvent = (ResultContentEvent) contentEvent;
        double[] classVotes = inEvent.getClassVotes();
        int instanceIndex = (int) inEvent.getInstanceIndex();

        addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), classVotes, 1);
        addPredictions(instanceIndex, inEvent, classVotes);

        if (!inEvent.isLastEvent() && !hasAllVotesArrivedInstance(instanceIndex)) {
            return false;
        }

        DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
        if (combinedVote == null) {
            combinedVote = new DoubleVector();
        }
        ResultContentEvent outEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
                inEvent.getInstance(), inEvent.getClassId(), combinedVote.getArrayCopy(),
                inEvent.isLastEvent());
        outEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
        this.outputStream.put(outEvent);
        clearStatisticsInstance(instanceIndex);

        computeBoosting(inEvent, instanceIndex);
        // The per-member predictions are only needed by computeBoosting; drop them so the map
        // does not grow by one entry for every instance ever processed.
        this.mapPredictions.remove(instanceIndex);
        return true;
    }

    /**
     * Returns the voting weight of ensemble member {@code i}: {@code log((1 - e) / e)} where
     * {@code e = swms[i] / (scms[i] + swms[i])} is the member's weighted error.
     *
     * @param i the ensemble member index
     * @return the member weight, or {@code 0.0} when the error is zero, above 0.5, or undefined
     *     (the member has not accumulated any weight yet)
     */
    @Override // org.apache.samoa.learners.classifiers.ensemble.PredictionCombinerProcessor
    protected double getEnsembleMemberWeight(int i) {
        double errorRate = this.swms[i] / (this.scms[i] + this.swms[i]);
        // !(errorRate > 0.0) rejects 0 (log(1/0) would be infinite) and also NaN, which arises
        // from 0/0 before the member has seen any weight and would otherwise poison the votes.
        if (!(errorRate > 0.0) || errorRate > 0.5) {
            return 0.0d;
        }
        double beta = errorRate / (1.0d - errorRate);
        return Math.log(1.0d / beta);
    }

    /** Clears all boosting state and re-allocates the per-member counters. */
    @Override // org.apache.samoa.learners.classifiers.ensemble.PredictionCombinerProcessor
    public void reset() {
        this.random = new Random();
        this.trainingWeightSeenByModel = 0.0d;
        this.scms = new double[this.ensembleSize];
        this.swms = new double[this.ensembleSize];
        this.mapPredictions = new HashMap<>();
    }

    /**
     * Tells whether member {@code classifierIndex} voted for the true class of {@code instance}.
     * Requires that {@link #addPredictions} has already been called for {@code instanceIndex}.
     */
    private boolean correctlyClassifies(int classifierIndex, Instance instance, int instanceIndex) {
        int predictedClass = (int) this.mapPredictions.get(instanceIndex).getValue(classifierIndex);
        return predictedClass == (int) instance.classValue();
    }

    /**
     * Stores the class index voted by the event's classifier (argmax of {@code classVotes})
     * under {@code instanceIndex}, creating the map and the per-instance vector on demand.
     */
    private void addPredictions(int instanceIndex, ResultContentEvent inEvent, double[] classVotes) {
        if (this.mapPredictions == null) {
            this.mapPredictions = new HashMap<>();
        }
        DoubleVector predictions = this.mapPredictions.get(instanceIndex);
        if (predictions == null) {
            predictions = new DoubleVector();
            this.mapPredictions.put(instanceIndex, predictions);
        }
        predictions.setValue(inEvent.getClassifierIndex(), Utils.maxIndex(classVotes));
    }

    /**
     * Sends one re-weighted training copy of the instance to each member in turn, updating the
     * boosting weight lambda after each member according to whether it classified the instance
     * correctly.
     *
     * @param inEvent the completed result event whose instance is used for training
     * @param instanceIndex key of the instance's entry in {@link #mapPredictions}
     */
    private void computeBoosting(ResultContentEvent inEvent, int instanceIndex) {
        Instance instance = inEvent.getInstance();
        // The lambda update below scales by the total weight seen; if it stayed at zero, lambda
        // would collapse to 0 after the first member (and then to NaN via 0/0).
        this.trainingWeightSeenByModel += instance.weight();

        double lambda = 1.0d;
        for (int i = 0; i < this.ensembleSize; i++) {
            if (!(lambda > 0.0d)) {
                // Remaining members would get zero weight: nothing left to train or count, and
                // continuing could only produce 0/0 = NaN in the counters.
                break;
            }
            Instance weightedInstance = instance.copy();
            weightedInstance.setWeight(instance.weight() * lambda);
            InstanceContentEvent trainEvent = new InstanceContentEvent(inEvent.getInstanceIndex(),
                    weightedInstance, true, false);
            trainEvent.setClassifierIndex(i);
            trainEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
            this.trainingStream.put(trainEvent);

            if (correctlyClassifies(i, instance, instanceIndex)) {
                this.scms[i] += lambda;
                lambda *= this.trainingWeightSeenByModel / (2.0d * this.scms[i]);
            } else {
                this.swms[i] += lambda;
                lambda *= this.trainingWeightSeenByModel / (2.0d * this.swms[i]);
            }
        }
    }

    /** @return the stream on which weighted training instances are sent to the members */
    public Stream getTrainingStream() {
        return this.trainingStream;
    }

    /** @param stream the stream on which weighted training instances are sent to the members */
    public void setTrainingStream(Stream stream) {
        this.trainingStream = stream;
    }
}
