/*
 * Decompiled with CFR 0.152.
 */
package ws.palladian.kaggle.restaurants;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.core.Classifier;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.DatasetTransformer;
import ws.palladian.core.dataset.DatasetWithFeatureAsCategory;
import ws.palladian.core.dataset.IdentityDatasetTransformer;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.ProgressReporter;
import ws.palladian.helper.date.DateHelper;
import ws.palladian.helper.io.FileHelper;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.kaggle.restaurants.dataset.Label;
import ws.palladian.kaggle.restaurants.utils.ClassifierCombination;

/**
 * Runs a grid of classification experiments and writes their results to a directory.
 *
 * <p>Every registered classifier is evaluated on each of its feature sets, once per registered
 * {@link DatasetTransformer} (or once with {@link IdentityDatasetTransformer} when none is
 * registered). When class labels are registered, the whole grid is repeated per label, with the
 * datasets wrapped in {@link DatasetWithFeatureAsCategory} for that label.
 *
 * <p>For each evaluated combination this writes {@code result-<stamp>.txt} (a textual report),
 * {@code model-<stamp>.ser.gz} (the serialized model) and {@code roc-<stamp>.png} (ROC curves)
 * into the results directory, and appends one row to {@code _summary.csv} there.
 */
public class Experimenter {
    private static final Logger LOGGER = LoggerFactory.getLogger(Experimenter.class);

    /** Column header of the summary CSV (without the optional leading {@code classLabel;}). */
    private static final String SUMMARY_HEADER = "details;learner;classifier;featureSet;numFeatures;transformer;timeTraining;timeTesting;precision;recall;f1;accuracy;superiority;matthewsCorrelationCoefficient;rocAuc\n";

    private final Dataset training;
    private final Dataset testing;
    private final File resultsDirectory;
    /** Feature names used as the category when non-empty; each is evaluated separately. */
    private final List<String> classLabels = new ArrayList<String>();
    private final List<Experiment> experiments = new ArrayList<Experiment>();
    private final List<DatasetTransformer> transformers = new ArrayList<DatasetTransformer>();
    /** Category value passed to the evaluation and used for precision/recall/F1. */
    private final String trueClass;

    /**
     * Creates an experimenter whose positive class is {@code "true"}.
     *
     * @param training dataset to train on; must not be null
     * @param testing dataset to evaluate on; must not be null
     * @param resultsDirectory directory receiving all result files; must not be null
     */
    public Experimenter(Dataset training, Dataset testing, File resultsDirectory) {
        this(training, testing, resultsDirectory, "true");
    }

    /**
     * Creates an experimenter.
     *
     * @param training dataset to train on; must not be null
     * @param testing dataset to evaluate on; must not be null
     * @param resultsDirectory directory receiving all result files; must not be null
     * @param trueClass category value used for evaluation and for precision/recall/F1
     */
    public Experimenter(Dataset training, Dataset testing, File resultsDirectory, String trueClass) {
        this.training = Objects.requireNonNull(training);
        this.testing = Objects.requireNonNull(testing);
        this.resultsDirectory = Objects.requireNonNull(resultsDirectory);
        this.trueClass = trueClass;
    }

    /** Adds one class label (by its {@code toString()}) to evaluate separately. */
    public Experimenter withClassLabel(Label label) {
        this.classLabels.add(label.toString());
        return this;
    }

    /** Adds several class labels (by their {@code toString()}) to evaluate separately. */
    public Experimenter withClassLabels(Label... labels) {
        for (Label label : labels) {
            this.classLabels.add(label.toString());
        }
        return this;
    }

    /** Adds a transformer; every classifier/feature-set combination is run once per transformer. */
    public Experimenter withTransformer(DatasetTransformer transformer) {
        this.transformers.add(transformer);
        return this;
    }

    /**
     * Registers a learner/classifier pair, evaluated once per given feature-name filter.
     *
     * @param featureSets predicates selecting the feature names to keep for each run
     */
    public <M extends Model> Experimenter withClassifier(Learner<M> learner, Classifier<M> classifier, Collection<? extends Predicate<? super String>> featureSets) {
        this.experiments.add(new Experiment(new ClassifierCombination<M>(learner, classifier), featureSets));
        return this;
    }

    /** Registers a learner/classifier pair, evaluated with a single feature-name filter. */
    public <M extends Model> Experimenter withClassifier(Learner<M> learner, Classifier<M> classifier, Predicate<? super String> featureSet) {
        return this.withClassifier(learner, classifier, Collections.singleton(featureSet));
    }

    /**
     * Registers an object that is both learner and classifier, evaluated once per given
     * feature-name filter.
     */
    // Raw constructor call kept as in the original; the single-argument constructor's generic
    // signature is not visible here.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public <LC extends Learner<M> & Classifier<M>, M extends Model> Experimenter withClassifier(LC learnerClassifier, Collection<? extends Predicate<? super String>> featureSets) {
        this.experiments.add(new Experiment(new ClassifierCombination(learnerClassifier), featureSets));
        return this;
    }

    /** Runs all combinations, writing result files and the summary CSV. */
    public void run() {
        this.run(false);
    }

    /** Prints the combinations that {@link #run()} would evaluate, without training anything. */
    public void dryRun() {
        this.run(true);
    }

    private void run(boolean dryRun) {
        int numCombinations = this.getNumCombinations();
        if (!this.classLabels.isEmpty()) {
            LOGGER.info("# class labels: {}", this.classLabels.size());
        }
        LOGGER.info("# total combinations: {}", numCombinations);
        ProgressMonitor progress = new ProgressMonitor();
        progress.startTask("Experiments", (long) numCombinations);
        if (this.classLabels.isEmpty()) {
            this.runExperiments((ProgressReporter) progress, null, this.training, this.testing, dryRun);
            return;
        }
        for (String classLabel : this.classLabels) {
            if (dryRun) {
                System.out.println("class label: " + classLabel);
            }
            DatasetWithFeatureAsCategory classTraining = new DatasetWithFeatureAsCategory(this.training, classLabel);
            DatasetWithFeatureAsCategory classTesting = new DatasetWithFeatureAsCategory(this.testing, classLabel);
            this.runExperiments((ProgressReporter) progress, classLabel, (Dataset) classTraining, (Dataset) classTesting, dryRun);
        }
    }

    /**
     * Returns the number of evaluations {@link #run()} performs: for each classifier, its number of
     * feature sets times the number of class labels and transformers (each counted as at least 1).
     */
    public int getNumCombinations() {
        int numClassLabels = Math.max(1, this.classLabels.size());
        int numTransformers = Math.max(1, this.transformers.size());
        int numCombinations = 0;
        for (Experiment experiment : this.experiments) {
            numCombinations += numClassLabels * experiment.featureSets.size() * numTransformers;
        }
        return numCombinations;
    }

    /**
     * Runs (or, in dry-run mode, prints) every classifier x feature set x transformer combination
     * on the given datasets.
     *
     * @param classLabel the class label these datasets were built for, or null if none
     */
    private void runExperiments(ProgressReporter progress, String classLabel, Dataset classTraining, Dataset classTesting, boolean dryRun) {
        List<DatasetTransformer> transformers = new ArrayList<DatasetTransformer>(this.transformers);
        if (transformers.isEmpty()) {
            transformers.add((DatasetTransformer) IdentityDatasetTransformer.INSTANCE);
        }
        for (Experiment experiment : this.experiments) {
            if (dryRun) {
                System.out.println("\tclassifier: " + experiment.classifierCombination);
            }
            for (Predicate<? super String> featureSet : experiment.featureSets) {
                Dataset experimentTraining = classTraining.filterFeatures(featureSet);
                Dataset experimentTesting = classTesting.filterFeatures(featureSet);
                Set<String> featureNames = experimentTraining.getFeatureInformation().getFeatureNames();
                if (dryRun) {
                    System.out.println("\t\tfeature set: " + featureSet + " (" + featureNames.size() + ")");
                    // only list transformers when there is an actual choice between them
                    if (transformers.size() > 1) {
                        for (DatasetTransformer transformer : transformers) {
                            System.out.println("\t\t\ttransformer: " + transformer);
                        }
                    }
                    continue;
                }
                for (DatasetTransformer transformer : transformers) {
                    this.runSingleExperiment(classLabel, experiment, featureSet, featureNames, transformer, experimentTraining, experimentTesting);
                    progress.increment();
                }
            }
        }
    }

    /** Evaluates one combination and writes its report, summary row, model and ROC curves. */
    private void runSingleExperiment(String classLabel, Experiment experiment, Predicate<? super String> featureSet, Set<String> featureNames, DatasetTransformer transformer, Dataset experimentTraining, Dataset experimentTesting) {
        Dataset trainingTransformed = experimentTraining.transform(transformer);
        Dataset testingTransformed = experimentTesting.transform(transformer);
        ClassifierCombination.EvaluationResult<?> evaluationResult = experiment.classifierCombination.runEvaluation(trainingTransformed, testingTransformed, this.trueClass);
        long secondsTraining = TimeUnit.MILLISECONDS.toSeconds(evaluationResult.getTrainingTime());
        long secondsTesting = TimeUnit.MILLISECONDS.toSeconds(evaluationResult.getTestingTime());

        CharSequence report = buildReport(classLabel, experiment, featureSet, featureNames, transformer, evaluationResult, secondsTraining, secondsTesting);
        String timestamp = this.uniqueTimestamp();
        File resultFile = new File(this.resultsDirectory, "result-" + timestamp + ".txt");
        FileHelper.writeToFile(resultFile.getAbsolutePath(), report);

        File summaryCsv = new File(this.resultsDirectory, "_summary.csv");
        CharSequence summaryRow = this.buildSummaryRow(!summaryCsv.exists(), classLabel, resultFile.getName(), experiment, featureSet, featureNames.size(), transformer, evaluationResult, secondsTraining, secondsTesting);
        FileHelper.appendFile(summaryCsv.getAbsolutePath(), summaryRow);

        File serializedFile = new File(this.resultsDirectory, "model-" + timestamp + ".ser.gz");
        FileHelper.trySerialize(evaluationResult.getModel(), serializedFile.getAbsolutePath());
        try {
            evaluationResult.getRocCurves().saveCurves(new File(this.resultsDirectory, "roc-" + timestamp + ".png"));
        } catch (Exception e) {
            throw new IllegalStateException("Could not save ROC curves", e);
        }
    }

    /** Builds the human-readable report written to {@code result-<stamp>.txt}. */
    private static CharSequence buildReport(String classLabel, Experiment experiment, Predicate<? super String> featureSet, Set<String> featureNames, DatasetTransformer transformer, ClassifierCombination.EvaluationResult<?> evaluationResult, long secondsTraining, long secondsTesting) {
        StringBuilder result = new StringBuilder();
        if (classLabel != null) {
            result.append("Class:       ").append(classLabel).append('\n');
        }
        result.append("Learner:     ").append(experiment.classifierCombination.getLearner()).append('\n');
        result.append("Classifier:  ").append(experiment.classifierCombination.getClassifier()).append('\n');
        result.append("\n\n");
        result.append("Features:    ").append(featureNames.size()).append('\n');
        result.append("Filter:      ").append(featureSet).append('\n');
        result.append("Transformer: ").append(transformer).append('\n');
        result.append('\n');
        for (String featureName : featureNames) {
            result.append(featureName).append('\n');
        }
        result.append('\n');
        result.append("Training:    ").append(secondsTraining).append(" seconds\n");
        result.append("Testing:     ").append(secondsTesting).append(" seconds\n");
        result.append("\n\n");
        result.append("ROC AUC:     ").append(evaluationResult.getRocCurves().getAreaUnderCurve());
        result.append("\n\n");
        result.append(evaluationResult.getConfusionMatrix().toString());
        result.append("\n\n").append("Threshold analysis:\n");
        result.append(evaluationResult.getThresholdAnalyzer().toString());
        return result;
    }

    /**
     * Builds one semicolon-separated summary row, preceded by the header line when
     * {@code withHeader} is set (i.e. when the summary file does not exist yet).
     */
    private CharSequence buildSummaryRow(boolean withHeader, String classLabel, String resultFileName, Experiment experiment, Predicate<? super String> featureSet, int numFeatures, DatasetTransformer transformer, ClassifierCombination.EvaluationResult<?> evaluationResult, long secondsTraining, long secondsTesting) {
        ConfusionMatrix confusionMatrix = evaluationResult.getConfusionMatrix();
        StringBuilder csvResult = new StringBuilder();
        if (withHeader) {
            if (classLabel != null) {
                csvResult.append("classLabel;");
            }
            csvResult.append(SUMMARY_HEADER);
        }
        if (classLabel != null) {
            csvResult.append(classLabel).append(';');
        }
        csvResult.append(resultFileName).append(';');
        csvResult.append(experiment.classifierCombination.getLearner()).append(';');
        csvResult.append(experiment.classifierCombination.getClassifier()).append(';');
        csvResult.append(featureSet).append(';');
        csvResult.append(numFeatures).append(';');
        // a multi-line transformer description would break the CSV row; keep only its first line
        csvResult.append(firstLine(transformer.toString())).append(';');
        csvResult.append(secondsTraining).append(';');
        csvResult.append(secondsTesting).append(';');
        csvResult.append(confusionMatrix.getPrecision(this.trueClass)).append(';');
        csvResult.append(confusionMatrix.getRecall(this.trueClass)).append(';');
        csvResult.append(confusionMatrix.getF(1.0, this.trueClass)).append(';');
        csvResult.append(confusionMatrix.getAccuracy()).append(';');
        csvResult.append(confusionMatrix.getSuperiority()).append(';');
        csvResult.append(matthewsCorrelationCoefficient(confusionMatrix)).append(';');
        csvResult.append(evaluationResult.getRocCurves().getAreaUnderCurve()).append('\n');
        return csvResult;
    }

    /** Returns the MCC, or -1.0 (after logging a warning) if it cannot be computed. */
    private static double matthewsCorrelationCoefficient(ConfusionMatrix confusionMatrix) {
        try {
            return confusionMatrix.getMatthewsCorrelationCoefficient();
        } catch (Exception e) {
            LOGGER.warn("Could not compute Matthews correlation coefficient; writing -1.0", e);
            return -1.0;
        }
    }

    /** Returns {@code text} up to (excluding) its first {@code '\n'}, or all of it if none. */
    private static String firstLine(String text) {
        int newlineIdx = text.indexOf('\n');
        return newlineIdx == -1 ? text : text.substring(0, newlineIdx);
    }

    /**
     * Returns the current {@link DateHelper#getCurrentDatetime()} stamp. If a
     * {@code result-<stamp>.txt} already exists in the results directory, appends {@code -1},
     * {@code -2}, ... until it does not, so that experiments getting the same stamp do not overwrite
     * each other's result, model and ROC files.
     */
    private String uniqueTimestamp() {
        String base = DateHelper.getCurrentDatetime();
        String candidate = base;
        for (int suffix = 1; new File(this.resultsDirectory, "result-" + candidate + ".txt").exists(); suffix++) {
            candidate = base + "-" + suffix;
        }
        return candidate;
    }

    /** One registered classifier together with the feature-name filters it is evaluated with. */
    private static final class Experiment {
        final ClassifierCombination<?> classifierCombination;
        final Collection<? extends Predicate<? super String>> featureSets;

        Experiment(ClassifierCombination<?> classifierCombination, Collection<? extends Predicate<? super String>> featureSets) {
            this.classifierCombination = Objects.requireNonNull(classifierCombination);
            this.featureSets = Objects.requireNonNull(featureSets);
        }
    }
}

