/*
 * Decompiled with CFR 0.152.
 */
package org.streaminer.stream.classifier.bayes;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.streaminer.stream.classifier.AbstractClassifier;
import org.streaminer.stream.data.Data;
import org.streaminer.stream.model.Distribution;
import org.streaminer.stream.model.NominalDistributionModel;
import org.streaminer.stream.model.NumericalDistributionModel;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Incremental naive Bayes classifier that learns from {@link Data} items and predicts a
 * {@code String} class label.
 *
 * <p>Learned state lives in {@link #classDistribution} (counts of class labels) and in
 * {@link #distributions}, which holds two kinds of models:
 * <ul>
 *   <li>one nominal model per class label, keyed by the label, counting conditions of the
 *       form {@code attribute='value'} (built by {@link #getNominalCondition});</li>
 *   <li>one numerical model per {@link Double}-valued attribute, keyed by the attribute name.
 *       NOTE(review): these numerical models are NOT conditioned on the class, so every class
 *       is multiplied by the same factor in {@link #predict}. TODO: make them per class.</li>
 * </ul>
 * NOTE(review): both kinds share one map, so an attribute name equal to a class label would
 * collide.
 *
 * <p>{@link #predict} starts every class from a likelihood of 1.0, i.e. no class prior is
 * applied (presumably what the "WOP" suffix refers to — confirm).
 */
public class NaiveBayesWOP extends AbstractClassifier<Data, String> {
    private static final long serialVersionUID = 1095437834368310484L;
    static final Logger log = LoggerFactory.getLogger(NaiveBayesWOP.class);

    /** Name of the label attribute; auto-detected by {@link #learn} when {@code null}. */
    String labelAttribute = null;
    /** Pseudo-count used by {@link #predict} for nominal conditions never seen with a class. */
    Double laplaceCorrection = 1.0E-4;
    /** Counts of the class labels seen by {@link #learn}. */
    Distribution<String> classDistribution = null;
    /** Nominal models keyed by class label, numerical models keyed by attribute name. */
    Map<String, Distribution<?>> distributions = new HashMap<String, Distribution<?>>();

    /**
     * Creates an untrained classifier. The class distribution is obtained from
     * {@link #createNominalDistribution()} (an overridable method called from the constructor).
     */
    public NaiveBayesWOP() {
        this.classDistribution = this.createNominalDistribution();
    }

    /**
     * Creates an untrained classifier with an explicit label attribute.
     *
     * @param labelAttribute name of the attribute holding the class label
     */
    public NaiveBayesWOP(String labelAttribute) {
        this();
        this.setLabelAttribute(labelAttribute);
    }

    /** @return the name of the label attribute, or {@code null} if not yet known */
    public String getLabelAttribute() {
        return this.labelAttribute;
    }

    /** @param labelAttribute name of the attribute holding the class label */
    public void setLabelAttribute(String labelAttribute) {
        this.labelAttribute = labelAttribute;
    }

    /** @return the pseudo-count used for unseen nominal conditions */
    public Double getLaplaceCorrection() {
        return this.laplaceCorrection;
    }

    /** @param laplaceCorrection pseudo-count used for unseen nominal conditions */
    public void setLaplaceCorrection(Double laplaceCorrection) {
        this.laplaceCorrection = laplaceCorrection;
    }

    /**
     * Predicts the class of {@code item} as the class with the highest normalized likelihood.
     * Ties go to the class enumerated first by the class distribution. The label attribute and
     * attributes whose value is {@code null} are ignored.
     *
     * @param item the item to classify
     * @return the most likely class, or {@code null} if no class has been learned yet
     */
    @Override
    public String predict(Data item) {
        Distribution<String> classes = this.getClassDistribution();
        log.debug("Predicting one of these classes: {}", classes.getElements());

        // LinkedHashMap keeps the class enumeration order, which decides ties below.
        LinkedHashMap<String, Double> classLikeli = new LinkedHashMap<String, Double>();
        double totalLikelihood = 0.0;
        for (String clazz : classes.getElements()) {
            Double cl = classes.getCount(clazz).doubleValue();
            log.debug("class likelihood for class '" + clazz + "' is {} / {}", (Object)cl, (Object)classes.getCount());
            double likelihood = this.classLikelihood(clazz, item);
            classLikeli.put(clazz, likelihood);
            totalLikelihood += likelihood;
        }

        String maxClass = null;
        double max = 0.0;
        for (Map.Entry<String, Double> entry : classLikeli.entrySet()) {
            // When every class scored zero, report probability 0 instead of 0/0 = NaN.
            double probability = totalLikelihood > 0.0 ? entry.getValue() / totalLikelihood : 0.0;
            log.debug("probability for {} is {}", (Object)entry.getKey(), (Object)probability);
            if (maxClass == null || probability > max) {
                maxClass = entry.getKey();
                max = probability;
            }
        }
        return maxClass;
    }

    /**
     * Computes the unnormalized likelihood of {@code clazz} for {@code item}, starting from 1.0
     * (no prior) and multiplying in one factor per non-label, non-null attribute.
     */
    private double classLikelihood(String clazz, Data item) {
        double likelihood = 1.0;
        Distribution<String> nominal = this.nominalModel(clazz);
        double classTotal = this.getClassDistribution().getCount(clazz).doubleValue();
        for (String attribute : item.keySet()) {
            Object value = item.get(attribute);
            if (this.isLabelAttribute(attribute) || value == null) {
                continue;
            }
            if (value instanceof Double) {
                // Use the model learn() stored under the attribute name; skip attributes never
                // seen during learning.
                Distribution<Double> numerical = this.numericalModel(attribute);
                if (numerical != null) {
                    likelihood *= numerical.prob((Double)value);
                }
                continue;
            }
            String feature = this.getNominalCondition(attribute, item);
            Number count = nominal == null ? null : nominal.getCount(feature);
            double d = count == null ? 0.0 : count.doubleValue();
            double total = classTotal;
            if (d == 0.0) {
                // Smoothing only for unseen conditions: pseudo-count in numerator and denominator.
                d = this.laplaceCorrection;
                total += this.laplaceCorrection;
            }
            log.debug("  likelihood for {}  is  {}  |" + clazz + " ", (Object)feature, (Object)(d / total));
            likelihood *= d / total;
        }
        return likelihood;
    }

    /**
     * Builds the nominal condition string {@code attribute='value'} for an attribute of
     * {@code item}.
     *
     * @param attribute the attribute name
     * @param item the item to read the value from
     * @return the condition string used as key in the nominal models
     */
    public String getNominalCondition(String attribute, Data item) {
        return attribute + "='" + item.get(attribute) + "'";
    }

    /**
     * Updates the model with one labeled example. If no label attribute is configured, the first
     * attribute whose name starts with {@code _class} is adopted. Items without a label value are
     * skipped with a warning; attributes with a {@code null} value are ignored.
     *
     * @param item the labeled example
     */
    @Override
    public void learn(Data item) {
        if (this.labelAttribute == null) {
            for (String name : item.keySet()) {
                if (name.startsWith("_class")) {
                    this.labelAttribute = name;
                    break;
                }
            }
        }
        Object label = this.labelAttribute == null ? null : item.get(this.labelAttribute);
        if (label == null) {
            log.warn("Not processing unlabeled data item {}", (Object)item);
            return;
        }
        String clazz = label.toString();
        log.debug("Learning from example with label={}", (Object)clazz);
        Distribution<String> classes = this.getClassDistribution();
        if (log.isDebugEnabled()) {
            log.debug("Classes: {}", classes.getElements());
            for (String t : classes.getElements()) {
                log.debug("    {}:  {}", (Object)t, (Object)classes.getCount(t));
            }
        }
        // Count the class exactly once per example, even if several keys match the label name.
        classes.update(clazz);

        for (String attribute : item.keySet()) {
            Object obj = item.get(attribute);
            if (this.isLabelAttribute(attribute) || obj == null) {
                continue;
            }
            if (obj instanceof Double) {
                Double value = (Double)obj;
                log.debug("Handling numerical case ({}) with value  {}", obj, value);
                Distribution<Double> numDist = this.numericalModel(attribute);
                if (numDist == null) {
                    numDist = this.createNumericalDistribution();
                    log.debug("Creating new numerical distribution model for attribute {}", (Object)attribute);
                    this.distributions.put(attribute, numDist);
                }
                numDist.update(value);
                continue;
            }
            String value = this.getNominalCondition(attribute, item);
            log.debug("Handling nominal case for [ {} | {} ]", value, (Object)("class=" + clazz));
            Distribution<String> nomDist = this.nominalModel(clazz);
            if (nomDist == null) {
                nomDist = this.createNominalDistribution();
                log.debug("Creating new nominal distribution model for attribute {}, {}", (Object)attribute, (Object)("class=" + clazz));
                this.distributions.put(clazz, nomDist);
            }
            nomDist.update(value);
        }
    }

    /**
     * Returns the class-label distribution, creating it via {@link #createNominalDistribution()}
     * if it is {@code null}.
     *
     * @return the (never {@code null}) class distribution
     */
    public Distribution<String> getClassDistribution() {
        if (this.classDistribution == null) {
            this.classDistribution = this.createNominalDistribution();
        }
        return this.classDistribution;
    }

    /**
     * Returns the learned numerical models (those that are {@link NumericalDistributionModel}
     * instances), in the iteration order of the backing map.
     *
     * @return a new list; empty if no numerical attribute has been learned
     */
    @SuppressWarnings("unchecked") // numerical models are only ever created as Distribution<Double>
    public List<Distribution<Double>> getNumericalDistributions() {
        List<Distribution<Double>> numDists = new ArrayList<Distribution<Double>>();
        for (Distribution<?> d : this.distributions.values()) {
            if (d instanceof NumericalDistributionModel) {
                numDists.add((Distribution<Double>)d);
            }
        }
        return numDists;
    }

    /** Factory for nominal models; override to supply a different implementation. */
    public Distribution<String> createNominalDistribution() {
        return new NominalDistributionModel<String>();
    }

    /**
     * Factory for numerical models; override to supply a different implementation. The
     * constructor arguments (1000, 1.0) are interpreted by {@link NumericalDistributionModel}.
     */
    public Distribution<Double> createNumericalDistribution() {
        return new NumericalDistributionModel(1000, 1.0);
    }

    /** True if {@code attribute} is the label attribute (case-insensitive, as in learning). */
    private boolean isLabelAttribute(String attribute) {
        return this.labelAttribute != null && this.labelAttribute.equalsIgnoreCase(attribute);
    }

    /** Returns the nominal model stored for class {@code clazz}, or {@code null}. */
    @SuppressWarnings("unchecked") // learn() stores Distribution<String> under class labels
    private Distribution<String> nominalModel(String clazz) {
        return (Distribution<String>)this.distributions.get(clazz);
    }

    /** Returns the numerical model stored for {@code attribute}, or {@code null}. */
    @SuppressWarnings("unchecked") // learn() stores Distribution<Double> under attribute names
    private Distribution<Double> numericalModel(String attribute) {
        return (Distribution<Double>)this.distributions.get(attribute);
    }
}

