/*
 * Decompiled with CFR 0.152.
 */
package ws.palladian.kaggle.restaurants.utils;

import java.util.Objects;

import ws.palladian.classification.evaluation.roc.RocCurves;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.helper.math.ThresholdAnalyzer;

/**
 * Pairs a {@link Learner} with a {@link Classifier} that share the same model type, so a model can
 * be trained on one dataset and evaluated on another in a single call.
 *
 * @param <M> the model type produced by the learner and consumed by the classifier
 */
public final class ClassifierCombination<M extends Model> {

    /** Category treated as the positive class when the caller does not name one. */
    private static final String DEFAULT_TRUE_CATEGORY = "true";

    private final Learner<M> learner;
    private final Classifier<M> classifier;

    /**
     * Creates a combination from a single object that acts as both learner and classifier.
     *
     * @param learnerClassifier object used both to train the model and to classify with it
     * @param <LC> a type implementing both {@link Learner} and {@link Classifier} for {@code M}
     */
    public <LC extends Learner<M> & Classifier<M>> ClassifierCombination(LC learnerClassifier) {
        this.learner = learnerClassifier;
        this.classifier = learnerClassifier;
    }

    /**
     * Creates a combination from a separate learner and classifier.
     *
     * @param learner used to train the model
     * @param classifier used to classify test instances with the trained model
     */
    public ClassifierCombination(Learner<M> learner, Classifier<M> classifier) {
        this.learner = learner;
        this.classifier = classifier;
    }

    /**
     * Trains on {@code trainingInstances}, evaluates on {@code testingInstances} using
     * {@code "true"} as the positive category, and returns only the confusion matrix.
     *
     * @param trainingInstances data used to train the model
     * @param testingInstances data used to evaluate the model
     * @return the confusion matrix of actual vs. most-likely classified categories
     * @deprecated use {@link #runEvaluation(Dataset, Dataset)} and
     *     {@link EvaluationResult#getConfusionMatrix()} instead
     */
    @Deprecated
    public ConfusionMatrix evaluate(Dataset trainingInstances, Dataset testingInstances) {
        return runEvaluation(trainingInstances, testingInstances).getConfusionMatrix();
    }

    /**
     * Trains and evaluates using {@code "true"} as the positive category.
     *
     * @param trainingInstances data used to train the model
     * @param testingInstances data used to evaluate the model
     * @return the evaluation result
     * @see #runEvaluation(Dataset, Dataset, String)
     */
    public EvaluationResult<M> runEvaluation(Dataset trainingInstances, Dataset testingInstances) {
        return runEvaluation(trainingInstances, testingInstances, DEFAULT_TRUE_CATEGORY);
    }

    /**
     * Trains a model on {@code trainingInstances}, classifies every instance of
     * {@code testingInstances} with it, and collects the following:
     * <ul>
     * <li>a confusion matrix of actual vs. most-likely category;</li>
     * <li>a {@link ThresholdAnalyzer} and ROC curves fed with (is-positive, probability-of-positive)
     * pairs;</li>
     * <li>training and testing durations in milliseconds.</li>
     * </ul>
     *
     * @param trainingInstances data used to train the model
     * @param testingInstances data used to evaluate the model
     * @param trueCategory the category treated as the positive class for the threshold and ROC
     *     analysis
     * @return the evaluation result, including the trained model
     */
    public EvaluationResult<M> runEvaluation(Dataset trainingInstances, Dataset testingInstances, String trueCategory) {
        long trainingStart = System.nanoTime();
        M model = learner.train(trainingInstances);
        long trainingTime = elapsedMillis(trainingStart);

        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        // NOTE(review): 100 is passed through unchanged; presumably the threshold resolution — confirm.
        ThresholdAnalyzer thresholdAnalyzer = new ThresholdAnalyzer(100);
        RocCurves.RocCurvesBuilder rocCurvesBuilder = new RocCurves.RocCurvesBuilder();

        long testingStart = System.nanoTime();
        for (Instance testInstance : testingInstances) {
            CategoryEntries classification = classifier.classify(testInstance.getVector(), model);
            String classifiedCategory = classification.getMostLikelyCategory();
            String actualCategory = testInstance.getCategory();
            double probability = classification.getProbability(trueCategory);
            // Null-safe: an instance without a category is simply not the positive class.
            boolean isTrueCategory = Objects.equals(actualCategory, trueCategory);
            confusionMatrix.add(actualCategory, classifiedCategory);
            thresholdAnalyzer.add(isTrueCategory, probability);
            rocCurvesBuilder.add(isTrueCategory, probability);
        }
        long testingTime = elapsedMillis(testingStart);

        RocCurves rocCurves = rocCurvesBuilder.create();
        return new EvaluationResult<M>(confusionMatrix, model, trainingTime, testingTime, thresholdAnalyzer, rocCurves);
    }

    /**
     * Returns the milliseconds elapsed since {@code startNanos}.
     * Uses the monotonic {@link System#nanoTime()} clock so that wall-clock adjustments cannot
     * distort the measurement.
     */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }

    /** @return the learner used for training */
    public Learner<M> getLearner() {
        return learner;
    }

    /** @return the classifier used for evaluation */
    public Classifier<M> getClassifier() {
        return classifier;
    }

    /** @return the classifier's string representation */
    @Override
    public String toString() {
        return classifier.toString();
    }

    /**
     * Holder for everything produced by one {@code runEvaluation} call.
     * The contained analysis objects are returned as-is, without copying.
     *
     * @param <M> the type of the trained model
     */
    public static final class EvaluationResult<M extends Model> {
        private final ConfusionMatrix confusionMatrix;
        private final M model;
        private final long trainingTime;
        private final long testingTime;
        private final ThresholdAnalyzer thresholdAnalyzer;
        private final RocCurves rocCurves;

        EvaluationResult(ConfusionMatrix confusionMatrix, M model, long trainingTime, long testingTime, ThresholdAnalyzer thresholdAnalyzer, RocCurves rocCurves) {
            this.confusionMatrix = confusionMatrix;
            this.model = model;
            this.trainingTime = trainingTime;
            this.testingTime = testingTime;
            this.thresholdAnalyzer = thresholdAnalyzer;
            this.rocCurves = rocCurves;
        }

        /** @return confusion matrix of actual vs. most-likely classified categories */
        public ConfusionMatrix getConfusionMatrix() {
            return confusionMatrix;
        }

        /** @return the model trained during the evaluation */
        public M getModel() {
            return model;
        }

        /** @return time spent training, in milliseconds */
        public long getTrainingTime() {
            return trainingTime;
        }

        /** @return time spent classifying the test instances, in milliseconds */
        public long getTestingTime() {
            return testingTime;
        }

        /** @return threshold analysis for the positive category */
        public ThresholdAnalyzer getThresholdAnalyzer() {
            return thresholdAnalyzer;
        }

        /** @return ROC curves for the positive category */
        public RocCurves getRocCurves() {
            return rocCurves;
        }
    }
}

