package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractNaiveBayes;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayes.class */
/**
 * Bernoulli Naive Bayes classifier.
 * <p>
 * Features are treated as binary indicators. A feature is "present" in a record when its
 * value, converted with {@link TypeInference#toDouble(Object)}, is non-null and strictly
 * greater than zero. Training and prediction use the same rule (see {@link #isPresent(Object)}).
 */
public class BernoulliNaiveBayes extends AbstractNaiveBayes<BernoulliNaiveBayes.ModelParameters, BernoulliNaiveBayes.TrainingParameters, BernoulliNaiveBayes.ValidationMetrics> {

    /**
     * Model parameters of the Bernoulli Naive Bayes.
     * <p>
     * In addition to what the parent keeps, this class caches one value per class: the sum
     * over all known features of {@code log(1 - P(feature|class))}, i.e. the log-probability
     * of observing none of the features.
     */
    public static class ModelParameters extends AbstractNaiveBayes.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        // Per class: sum over all known features of log(1 - P(feature|class)).
        private Map<Object, Double> sumOfLog1minusProb;

        /**
         * @param databaseConnector connector handed to the parent parameters object
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
            this.sumOfLog1minusProb = new HashMap<>();
        }

        /**
         * @return per class, the sum over all features of log(1 - P(feature|class))
         */
        public Map<Object, Double> getSumOfLog1minusProb() {
            return this.sumOfLog1minusProb;
        }

        /**
         * @param map per class, the sum over all features of log(1 - P(feature|class))
         */
        protected void setSumOfLog1minusProb(Map<Object, Double> map) {
            this.sumOfLog1minusProb = map;
        }
    }

    /** Training parameters of the Bernoulli Naive Bayes; adds nothing to the parent's. */
    public static class TrainingParameters extends AbstractNaiveBayes.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;
    }

    /** Validation metrics of the Bernoulli Naive Bayes; adds nothing to the parent's. */
    public static class ValidationMetrics extends AbstractClassifier.AbstractValidationMetrics {
        private static final long serialVersionUID = 1;
    }

    /**
     * Creates a Bernoulli Naive Bayes model.
     *
     * @param str           name forwarded to the parent constructor (presumably the model's
     *                      storage name — TODO confirm against AbstractNaiveBayes)
     * @param configuration framework configuration forwarded to the parent
     */
    public BernoulliNaiveBayes(String str, Configuration configuration) {
        // NOTE(review): the meaning of the trailing 'true' flag is defined by AbstractNaiveBayes.
        super(str, configuration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, true);
    }

    /**
     * Decides whether a raw feature value marks its feature as present.
     * <p>
     * Training and prediction both use this method so that a record is scored with the
     * same binary encoding the model was trained on. A value counts as present only if
     * {@link TypeInference#toDouble(Object)} returns a non-null value strictly greater than
     * zero. Null, zero, negative and NaN values are all treated as absent.
     *
     * @param rawValue the feature value taken from a record's X
     * @return {@code true} when the feature should be counted as present
     */
    private static boolean isPresent(Object rawValue) {
        Double value = TypeInference.toDouble(rawValue);
        return value != null && value > 0.0;
    }

    /**
     * Scores a single record against every known class.
     * <p>
     * Each class score starts at {@code log P(c) + sum_f log(1 - P(f|c))}, which is the
     * log-probability of observing no feature at all. For every feature that the model knows
     * and that is present in the record, the "absent" term of that feature is swapped for the
     * "present" one by adding {@code log P(f|c) - log(1 - P(f|c))}.
     *
     * @param record the record to classify; only its X values are read
     * @return the selected class together with the exp-normalised class scores
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record record) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classes = modelParameters.getClasses();
        Map<Object, Double> sumOfLog1minusProb = modelParameters.getSumOfLog1minusProb();

        // _fit stores a likelihood for every (feature, class) pair, so probing one class is
        // enough to tell whether a feature is known to the model.
        Object someClass = classes.iterator().next();

        AssociativeArray predictionScores = new AssociativeArray(new HashMap<>(logPriors));
        for (Map.Entry<Object, Double> entry : sumOfLog1minusProb.entrySet()) {
            Object theClass = entry.getKey();
            predictionScores.put(theClass, predictionScores.getDouble(theClass) + entry.getValue());
        }

        for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
            Object feature = entry.getKey();

            if (!logLikelihoods.containsKey(Arrays.asList(feature, someClass))) {
                continue; // feature unseen during training
            }
            if (!isPresent(entry.getValue())) {
                continue; // absent features are already covered by the sumOfLog1minusProb term
            }

            for (Object theClass : classes) {
                // After _fit the map holds plain probabilities, not log-probabilities.
                double probability = logLikelihoods.get(Arrays.asList(feature, theClass));
                predictionScores.put(theClass, predictionScores.getDouble(theClass) + Math.log(probability) - Math.log(1.0 - probability));
            }
        }

        Object selectedClass = getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalizeExp(predictionScores);
        return new PredictParallelizable.Prediction(selectedClass, predictionScores);
    }

    /**
     * Estimates priors, per-feature probabilities and the cached "no feature" sums.
     * <p>
     * Steps:
     * <ol>
     * <li>count the records of each class (sequential);</li>
     * <li>initialise a zero count for every (feature, class) pair;</li>
     * <li>count, per class, the records in which each feature is present (possibly parallel);</li>
     * <li>turn the class counts into log-priors;</li>
     * <li>replace the counts with add-one smoothed probabilities and accumulate
     *     {@code sum_f log(1 - P(f|c))} per class.</li>
     * </ol>
     *
     * @param dataframe the training data
     */
    @Override
    public void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        int n = modelParameters.getN();
        int d = modelParameters.getD();

        // Presumably disables weighting features by multiplicity: this model only records presence.
        ((TrainingParameters) kb().getTrainingParameters()).setMultiProbabilityWeighted(false);

        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classes = modelParameters.getClasses();
        Map<Object, Double> sumOfLog1minusProb = modelParameters.getSumOfLog1minusProb();

        // Per class: total number of present features across all of its records.
        Map<Object, Integer> activeFeaturesPerClass = new HashMap<>();

        // 1) Discover classes. logPriors temporarily holds raw record counts.
        for (Record r : dataframe) {
            Object y = r.getY();
            if (classes.add(y)) {
                logPriors.put(y, 1.0);
                activeFeaturesPerClass.put(y, 0);
                sumOfLog1minusProb.put(y, 0.0);
            } else {
                logPriors.put(y, logPriors.get(y) + 1.0);
            }
        }

        // 2) Every known feature gets a zero count for every class.
        this.streamExecutor.forEach(StreamMethods.stream(dataframe.getXDataTypes().keySet().stream(), isParallelized()), feature -> {
            for (Object theClass : classes) {
                logLikelihoods.put(Arrays.asList(feature, theClass), 0.0);
            }
        });

        // 3) Count feature presence per class. Records may be processed in parallel.
        this.streamExecutor.forEach(StreamMethods.stream(dataframe.stream(), isParallelized()), r -> {
            Object theClass = r.getY();
            int activeCount = 0;
            for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                if (isPresent(entry.getValue())) {
                    // merge() rather than get()+put(): two records of the same class can update
                    // the same key concurrently, and a get/put pair would lose increments.
                    // NOTE(review): merge is atomic only if the backing map is a ConcurrentMap — confirm.
                    logLikelihoods.merge(Arrays.asList(entry.getKey(), theClass), 1.0, Double::sum);
                    activeCount++;
                }
            }
            synchronized (activeFeaturesPerClass) {
                activeFeaturesPerClass.merge(theClass, activeCount, Integer::sum);
            }
        });

        // 4) Raw class counts -> log-priors.
        for (Map.Entry<Object, Double> entry : logPriors.entrySet()) {
            logPriors.put(entry.getKey(), Math.log(entry.getValue() / n));
        }

        // 5) Add-one smoothing; the map now holds probabilities, not log-probabilities.
        // NOTE(review): the denominator is (present-feature occurrences of the class + d), not the
        // textbook Bernoulli (records of the class + 2). It is kept unchanged so that existing
        // models behave the same.
        for (Object theClass : classes) {
            double denominator = activeFeaturesPerClass.get(theClass) + d;
            double log1minusProbSum = this.streamExecutor.sum(StreamMethods.stream(dataframe.getXDataTypes().keySet().stream(), isParallelized()).mapToDouble(feature -> {
                List<Object> featureClass = Arrays.asList(feature, theClass);
                double probability = (logLikelihoods.get(featureClass) + 1.0) / denominator;
                logLikelihoods.put(featureClass, probability);
                return Math.log(1.0 - probability); // log-probability of the feature being absent
            }));
            sumOfLog1minusProb.put(theClass, sumOfLog1minusProb.get(theClass) + log1minusProbSum);
        }
    }
}
