package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.persistentstorage.interfaces.BigMap;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.machinelearning.common.validators.SoftMaxRegressionValidator;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.utilities.regularization.ElasticNetRegularizer;
import com.datumbox.framework.core.utilities.regularization.L1Regularizer;
import com.datumbox.framework.core.utilities.regularization.L2Regularizer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/classification/SoftMaxRegression.class */
/**
 * SoftMax (multinomial logistic) regression classifier.
 *
 * <p>Weights ("thitas") are stored in a map keyed by {@code [feature, class]}; the per-class
 * intercept uses the feature key {@code "~CONSTANT"}. Training runs batch gradient descent for a
 * fixed number of iterations using a "bold driver" learning-rate schedule: a step that increases
 * the (regularized) cross-entropy error is discarded and the learning rate is halved, otherwise
 * the new weights are kept and the learning rate is increased by 5%. L1, L2 or ElasticNet
 * regularization is applied when the corresponding training parameters are positive.
 */
public class SoftMaxRegression extends AbstractClassifier<SoftMaxRegression.ModelParameters, SoftMaxRegression.TrainingParameters, SoftMaxRegression.ValidationMetrics> implements PredictParallelizable, TrainParallelizable {

    /** Feature key of the per-class intercept weight. */
    private static final String CONSTANT_KEY = "~CONSTANT";

    /** Name of the temporary map that holds candidate weights during one training iteration. */
    private static final String TMP_NEW_THITAS = "tmp_newThitas";

    /** Name of the temporary map that buffers predictions in {@code _predictDataset}. */
    private static final String TMP_RESULTS_BUFFER = "tmp_resultsBuffer";

    private boolean parallelized;
    protected final ForkJoinStream streamExecutor;

    /** Learned parameters of the model. */
    public static class ModelParameters extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        /** Weights keyed by {@code [feature, class]}; intercepts use the {@code "~CONSTANT"} feature key. */
        @BigMap(mapType = DatabaseConnector.MapType.HASHMAP, storageHint = DatabaseConnector.StorageHint.IN_MEMORY, concurrent = true)
        private Map<List<Object>, Double> thitas;

        /**
         * @param databaseConnector connector forwarded to the base class
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the weight map keyed by {@code [feature, class]}
         */
        public Map<List<Object>, Double> getThitas() {
            return thitas;
        }

        /**
         * @param thitas the weight map keyed by {@code [feature, class]}
         */
        protected void setThitas(Map<List<Object>, Double> thitas) {
            this.thitas = thitas;
        }
    }

    /** Hyper-parameters of the training procedure. */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;

        /** Number of gradient descent iterations (default 100). */
        private int totalIterations = 100;

        /** Initial learning rate; adapted during training (default 0.1). */
        private double learningRate = 0.1d;

        /** L1 regularization strength; disabled when not positive (default 0). */
        private double l1 = 0.0d;

        /** L2 regularization strength; disabled when not positive (default 0). */
        private double l2 = 0.0d;

        public int getTotalIterations() {
            return totalIterations;
        }

        public void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        public double getLearningRate() {
            return learningRate;
        }

        public void setLearningRate(double learningRate) {
            this.learningRate = learningRate;
        }

        public double getL1() {
            return l1;
        }

        public void setL1(double l1) {
            this.l1 = l1;
        }

        public double getL2() {
            return l2;
        }

        public void setL2(double l2) {
            this.l2 = l2;
        }
    }

    /** Validation metrics produced by {@code validateModel}. */
    public static class ValidationMetrics extends AbstractClassifier.AbstractValidationMetrics {
        private static final long serialVersionUID = 1;

        // Field names kept as-is for serialization compatibility.
        private double SSE = 0.0d;
        private double CountRSquare = 0.0d;

        /**
         * Despite the name, {@code validateModel} stores here the mean negative log-likelihood
         * (cross-entropy) of the data, plus the regularization penalty if one is configured.
         */
        public double getSSE() {
            return SSE;
        }

        public void setSSE(double sse) {
            this.SSE = sse;
        }

        /** Count R-squared; {@code validateModel} sets it equal to the accuracy. */
        public double getCountRSquare() {
            return CountRSquare;
        }

        public void setCountRSquare(double countRSquare) {
            this.CountRSquare = countRSquare;
        }
    }

    /**
     * @param dbName name passed to the base class
     * @param configuration framework configuration; its concurrency config drives the stream executor
     */
    public SoftMaxRegression(String dbName, Configuration configuration) {
        super(dbName, configuration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, new SoftMaxRegressionValidator());
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(kb().getConf().getConcurrencyConfig());
    }

    @Override
    public boolean isParallelized() {
        return parallelized;
    }

    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts every record of the dataframe using a temporary on-disk results buffer.
     *
     * @param dataframe records to predict
     */
    @Override
    public void _predictDataset(Dataframe dataframe) {
        DatabaseConnector dbc = kb().getDbc();
        Map<Integer, PredictParallelizable.Prediction> resultsBuffer = dbc.getBigMap(TMP_RESULTS_BUFFER,
                DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_DISK, true, true);
        try {
            _predictDatasetParallel(dataframe, resultsBuffer, kb().getConf().getConcurrencyConfig());
        } finally {
            // Release the temporary buffer even if prediction fails.
            dbc.dropBigMap(TMP_RESULTS_BUFFER, resultsBuffer);
        }
    }

    /**
     * Scores a single record against every known class.
     *
     * @param r record to classify
     * @return the highest-scoring class and the softmax-normalized class probabilities
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        Map<List<Object>, Double> thitas = modelParameters.getThitas();

        AssociativeArray predictionScores = new AssociativeArray();
        for (Object theClass : modelParameters.getClasses()) {
            predictionScores.put(theClass, calculateClassScore(r.getX(), theClass, thitas));
        }

        // Select on raw scores first; normalizeExp then converts them to probabilities in place.
        Object predictedClass = getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalizeExp(predictionScores);
        return new PredictParallelizable.Prediction(predictedClass, predictionScores);
    }

    /**
     * Trains the model with batch gradient descent and a bold-driver learning rate.
     *
     * @param trainingData labelled training records
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();
        Map<List<Object>, Double> thitas = modelParameters.getThitas();
        Set<Object> classes = modelParameters.getClasses();

        // Collect the set of observed response classes.
        for (Record r : trainingData) {
            classes.add(r.getY());
        }

        // Initialize every intercept and every feature/class weight to zero.
        for (Object theClass : classes) {
            thitas.put(Arrays.asList(CONSTANT_KEY, theClass), 0.0);
        }
        streamExecutor.forEach(StreamMethods.stream(trainingData.getXDataTypes().keySet().stream(), isParallelized()), feature -> {
            for (Object theClass : classes) {
                thitas.putIfAbsent(Arrays.asList(feature, theClass), 0.0);
            }
        });

        DatabaseConnector dbc = kb().getDbc();
        double minError = Double.POSITIVE_INFINITY;
        double learningRate = trainingParameters.getLearningRate();
        int totalIterations = trainingParameters.getTotalIterations();
        for (int iteration = 0; iteration < totalIterations; iteration++) {
            logger.debug("Iteration {}", iteration);

            Map<List<Object>, Double> newThitas = dbc.getBigMap(TMP_NEW_THITAS,
                    DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_MEMORY, true, true);
            try {
                newThitas.putAll(thitas);
                batchGradientDescent(trainingData, newThitas, learningRate);
                double newError = calculateError(trainingData, newThitas);

                // Bold driver: reject a step that increases the error and shrink the learning
                // rate; otherwise accept the new weights and grow the learning rate.
                if (newError > minError) {
                    learningRate /= 2.0;
                } else {
                    learningRate *= 1.05;
                    minError = newError;
                    thitas.clear();
                    thitas.putAll(newThitas);
                }
            } finally {
                // Release the temporary candidate weights even if an iteration fails.
                dbc.dropBigMap(TMP_NEW_THITAS, newThitas);
            }
        }
    }

    /**
     * Runs the base classifier validation and adds the SoftMax-specific metrics.
     *
     * @param dataframe labelled records to validate against
     * @return metrics where CountRSquare equals accuracy and SSE holds the regularized mean cross-entropy
     */
    @Override
    public ValidationMetrics validateModel(Dataframe dataframe) {
        ValidationMetrics validationMetrics = (ValidationMetrics) super.validateModel(dataframe);
        validationMetrics.setCountRSquare(validationMetrics.getAccuracy());
        validationMetrics.setSSE(calculateError(dataframe, ((ModelParameters) kb().getModelParameters()).getThitas()));
        return validationMetrics;
    }

    /**
     * Applies one batch gradient descent step to {@code newThitas}, then the configured regularizer.
     *
     * @param trainingData labelled training records
     * @param newThitas candidate weights, pre-filled with the current weights; updated in place
     * @param learningRate step size for this iteration
     */
    private void batchGradientDescent(Dataframe trainingData, Map<List<Object>, Double> newThitas, double learningRate) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        double multiplier = learningRate / modelParameters.getN();
        Map<List<Object>, Double> thitas = modelParameters.getThitas();
        Set<Object> classes = modelParameters.getClasses();

        streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), isParallelized()), r -> {
            // Probabilities under the current (not the candidate) weights: this is a batch step.
            AssociativeArray classProbabilities = hypothesisFunction(r.getX(), thitas);
            for (Object theClass : classes) {
                double probability = classProbabilities.getDouble(theClass);
                double error = r.getY().equals(theClass) ? 1.0 - probability : -probability;
                double errorMultiplier = multiplier * error;

                // newThitas is shared between workers; each read-modify-write must be atomic.
                synchronized (newThitas) {
                    for (Map.Entry<?, ?> entry : r.getX().entrySet()) {
                        Double value = TypeInference.toDouble(entry.getValue());
                        List<Object> key = Arrays.asList(entry.getKey(), theClass);
                        newThitas.put(key, newThitas.get(key) + errorMultiplier * value);
                    }
                    List<Object> interceptKey = Arrays.asList(CONSTANT_KEY, theClass);
                    newThitas.put(interceptKey, newThitas.get(interceptKey) + errorMultiplier);
                }
            }
        });

        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();
        double l1 = trainingParameters.getL1();
        double l2 = trainingParameters.getL2();
        if (l1 > 0.0 && l2 > 0.0) {
            ElasticNetRegularizer.updateWeights(l1, l2, learningRate, thitas, newThitas);
        } else if (l1 > 0.0) {
            L1Regularizer.updateWeights(l1, learningRate, thitas, newThitas);
        } else if (l2 > 0.0) {
            L2Regularizer.updateWeights(l2, learningRate, thitas, newThitas);
        }
    }

    /**
     * Computes the linear score {@code intercept + sum(weight * value)} of one class.
     * Features without a weight for this class are ignored.
     *
     * @param x feature values of a record
     * @param theClass class to score
     * @param thitas weights keyed by {@code [feature, class]}
     * @return the linear score
     */
    private Double calculateClassScore(AssociativeArray x, Object theClass, Map<List<Object>, Double> thitas) {
        double score = thitas.get(Arrays.asList(CONSTANT_KEY, theClass));
        for (Map.Entry<?, ?> entry : x.entrySet()) {
            Double value = TypeInference.toDouble(entry.getValue());
            Double weight = thitas.get(Arrays.asList(entry.getKey(), theClass));
            if (weight != null) {
                score += weight * value;
            }
        }
        return score;
    }

    /**
     * Computes the mean negative log-likelihood (cross-entropy) plus any regularization penalty.
     * NOTE(review): divides by the model's N rather than {@code dataframe.size()}; when used for
     * validation on a different dataframe, confirm this normalization is intended.
     *
     * @param dataframe labelled records
     * @param thitas weights to evaluate
     * @return the regularized error
     */
    private double calculateError(Dataframe dataframe, Map<List<Object>, Double> thitas) {
        double logLikelihood = streamExecutor.sum(StreamMethods.stream(dataframe.stream(), isParallelized())
                .mapToDouble(r -> Math.log(hypothesisFunction(r.getX(), thitas).getDouble(r.getY()))));
        int n = ((ModelParameters) kb().getModelParameters()).getN();
        double error = -logLikelihood / n;

        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();
        double l1 = trainingParameters.getL1();
        double l2 = trainingParameters.getL2();
        if (l1 > 0.0 && l2 > 0.0) {
            error += ElasticNetRegularizer.estimatePenalty(l1, l2, thitas);
        } else if (l1 > 0.0) {
            error += L1Regularizer.estimatePenalty(l1, thitas);
        } else if (l2 > 0.0) {
            error += L2Regularizer.estimatePenalty(l2, thitas);
        }
        return error;
    }

    /**
     * Computes the SoftMax class probabilities of a record.
     *
     * @param x feature values of a record
     * @param thitas weights keyed by {@code [feature, class]}
     * @return class probabilities obtained by softmax over the linear class scores
     */
    private AssociativeArray hypothesisFunction(AssociativeArray x, Map<List<Object>, Double> thitas) {
        Set<Object> classes = ((ModelParameters) kb().getModelParameters()).getClasses();
        AssociativeArray probabilities = new AssociativeArray();
        for (Object theClass : classes) {
            probabilities.put(theClass, calculateClassScore(x, theClass, thitas));
        }
        // Softmax normalization, matching _predictRecord. The gradient (1{y=k} - p_k) in
        // batchGradientDescent and the log-likelihood in calculateError both assume softmax.
        Descriptives.normalizeExp(probabilities);
        return probabilities;
    }
}
