package org.apache.samoa.learners.classifiers;

import java.util.HashMap;
import java.util.Map;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver;
import org.apache.samoa.moa.core.GaussianEstimator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/samoa/learners/classifiers/NaiveBayes.class */
/**
 * Incremental Naive Bayes learner that keeps one
 * {@link GaussianNumericAttributeClassObserver} per attribute index and asks it for a
 * per-class {@link GaussianEstimator} when scoring.
 *
 * <p>Votes returned by {@link #getVotesForInstance(Instance)} are sums of natural logarithms
 * (log prior + log densities + log class prototype), not normalised probabilities.
 *
 * <p>The class holds plain {@link HashMap}s and performs no synchronisation of its own.
 */
public class NaiveBayes implements LocalLearner {
    /** Density used in place of a missing estimator so that its logarithm stays finite. */
    private static final double ADDITIVE_SMOOTHING_FACTOR = 1.0E-20d;
    private static final long serialVersionUID = 1325775209672996822L;
    private static final Logger logger = LoggerFactory.getLogger(NaiveBayes.class);

    /** Observer for each attribute, keyed by the attribute index reported by the instance. */
    protected Map<Integer, GaussianNumericAttributeClassObserver> attributeObservers;
    /** Accumulated instance weight per class value. */
    protected Map<Integer, Double> classInstances;
    /** Per class: 1.0 plus the running sum of each attribute estimator's density at 0.0. */
    protected Map<Integer, Double> classPrototypes;
    /** Number of calls to {@link #trainOnInstance(Instance)} since the last reset. */
    protected long instancesSeen = 0;

    /**
     * Returns the number of distinct class values observed during training.
     *
     * @return the size of {@link #classInstances}
     */
    protected int getNumberOfClasses() {
        return this.classInstances.size();
    }

    /** Creates an untrained learner. */
    public NaiveBayes() {
        resetLearning();
    }

    /**
     * Creates a fresh, untrained learner of the same type.
     *
     * @return a new {@code NaiveBayes}
     */
    @Override
    public LocalLearner create() {
        return new NaiveBayes();
    }

    /**
     * Computes a log-space score for every class index.
     *
     * <p>The returned array is indexed by class value. A class that has never been trained on
     * has a prior of zero and therefore scores {@code -Infinity}.
     *
     * @param instance the instance to score
     * @return one log-space vote per class index; empty if nothing has been learned yet
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        double[] votes = new double[getVoteArrayLength()];
        for (int classIndex = 0; classIndex < votes.length; classIndex++) {
            double vote = Math.log(getPrior(classIndex));
            for (int i = 0; i < instance.numAttributes(); i++) {
                int attributeIndex = instance.index(i);
                if (attributeIndex == instance.classIndex()) {
                    continue; // the class attribute is not a predictor
                }
                double value = instance.value(attributeIndex);
                GaussianNumericAttributeClassObserver observer = this.attributeObservers.get(attributeIndex);
                GaussianEstimator estimator = observer != null ? observer.getEstimator(classIndex) : null;
                // No estimator for this (attribute, class) pair: fall back to a tiny density
                // instead of taking log(0).
                double density = estimator != null
                        ? estimator.probabilityDensity(value)
                        : ADDITIVE_SMOOTHING_FACTOR;
                vote += Math.log(density);
            }
            // The prototype is absent for classes that were never trained on.
            Double prototype = this.classPrototypes.get(classIndex);
            if (prototype != null) {
                vote += Math.log(prototype);
            }
            votes[classIndex] = vote;
        }
        return votes;
    }

    /**
     * Determines how many slots the vote array needs.
     *
     * <p>Class values are used directly as array indices, so the array must reach the largest
     * observed class value even when smaller class values have not been seen yet. The result
     * is never smaller than {@link #getNumberOfClasses()}.
     */
    private int getVoteArrayLength() {
        int length = getNumberOfClasses();
        for (Integer observedClass : this.classInstances.keySet()) {
            length = Math.max(length, observedClass + 1);
        }
        return length;
    }

    /**
     * Returns the prior of a class: its accumulated weight divided by the number of training
     * instances seen. Note that the numerator is a weight sum while the denominator is a plain
     * count; the denominator is the same for every class.
     *
     * @param i the class value
     * @return the prior, or 0.0 when the class has no recorded weight
     */
    private double getPrior(int i) {
        Double classWeight = this.classInstances.get(i);
        if (classWeight == null || classWeight == 0.0d) {
            return 0.0d;
        }
        return (classWeight * 1.0d) / this.instancesSeen;
    }

    /** Discards everything learned so far. */
    @Override
    public void resetLearning() {
        this.instancesSeen = 0L;
        this.classInstances = new HashMap<>();
        this.classPrototypes = new HashMap<>();
        this.attributeObservers = new HashMap<>();
    }

    /**
     * Updates the class weight, the per-attribute observers and the class prototype with one
     * training instance.
     *
     * @param instance the labelled instance to learn from
     */
    @Override
    public void trainOnInstance(Instance instance) {
        int classValue = (int) instance.classValue();

        Double classWeight = this.classInstances.get(classValue);
        if (classWeight == null) {
            classWeight = 0.0d;
        }
        this.classInstances.put(classValue, classWeight + instance.weight());

        Double storedPrototype = this.classPrototypes.get(classValue);
        double prototype = storedPrototype != null ? storedPrototype : 1.0d;

        for (int i = 0; i < instance.numAttributes(); i++) {
            int attributeIndex = instance.index(i);
            if (attributeIndex == instance.classIndex()) {
                continue; // the class attribute is not a predictor
            }
            GaussianNumericAttributeClassObserver observer = this.attributeObservers.get(attributeIndex);
            if (observer == null) {
                observer = new GaussianNumericAttributeClassObserver();
                this.attributeObservers.put(attributeIndex, observer);
            }
            // Swap this attribute's contribution to the prototype: remove the density at 0.0
            // from before the update, then add the density at 0.0 after it.
            GaussianEstimator before = observer.getEstimator(classValue);
            if (before != null) {
                prototype -= before.probabilityDensity(0.0d);
            }
            observer.observeAttributeClass(instance.valueSparse(i), classValue, instance.weight());
            // The estimator can still be absent if the observer did not record the value
            // (presumably the case for missing values; verify against the observer).
            GaussianEstimator after = observer.getEstimator(classValue);
            if (after != null) {
                prototype += after.probabilityDensity(0.0d);
            }
        }
        this.classPrototypes.put(classValue, prototype);
        this.instancesSeen++;
    }

    /**
     * Intentionally a no-op: this learner takes no information from the dataset header.
     *
     * @param instances ignored
     */
    @Override
    public void setDataset(Instances instances) {
    }
}
