/*
 * Decompiled with CFR 0.152.
 */
package ws.palladian.kaggle.restaurants.aggregation;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import ws.palladian.classification.evaluation.ConfusionMatrixEvaluator;
import ws.palladian.classification.liblinear.LibLinearClassifier;
import ws.palladian.classification.liblinear.LibLinearLearner;
import ws.palladian.classification.liblinear.LibLinearModel;
import ws.palladian.classification.utils.CsvDatasetReader;
import ws.palladian.classification.utils.CsvDatasetReaderConfig;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.Classifier;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.DefaultDataset;
import ws.palladian.core.value.ImmutableStringValue;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.DefaultMultiMap;
import ws.palladian.helper.collection.LazyMap;
import ws.palladian.helper.collection.MultiMap;
import ws.palladian.helper.date.DateHelper;
import ws.palladian.helper.functional.Factory;
import ws.palladian.helper.io.FileHelper;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.helper.math.FatStats;
import ws.palladian.helper.math.Stats;
import ws.palladian.kaggle.restaurants.dataset.Label;
import ws.palladian.kaggle.restaurants.utils.Config;

public class MultiClassificationAggregator {
    public static void main(String[] args) throws IOException {
        MultiClassificationAggregator.classify();
    }

    private static void classify() throws IOException {
        Object probability;
        String businessId;
        Dataset classificationResult1 = MultiClassificationAggregator.read("/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/joined_classified_test_with_bizIds_2016-04-02_13-36-15.csv");
        Dataset classificationResult2 = MultiClassificationAggregator.read("/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/joined_classified_test_with_bizIds_2016-04-13_00-28-31.csv");
        LazyMap allStats = new LazyMap((Factory)new Factory<Map<String, Stats>>(){

            public Map<String, Stats> create() {
                return new LazyMap(FatStats.FACTORY);
            }
        });
        for (Label[] instance : classificationResult1) {
            businessId = ((Value)instance.getVector().get((Object)"businessId")).toString();
            for (Label label : Label.values()) {
                probability = (NumericValue)instance.getVector().get((Object)label.toString());
                ((Stats)((Map)allStats.get(businessId)).get(label.toString() + "_1")).add((Number)probability.getDouble());
            }
        }
        for (Label[] instance : classificationResult2) {
            businessId = ((Value)instance.getVector().get((Object)"businessId")).toString();
            for (Label label : Label.values()) {
                probability = (NumericValue)instance.getVector().get((Object)label.toString());
                ((Stats)((Map)allStats.get(businessId)).get(label.toString() + "_2")).add((Number)probability.getDouble());
            }
        }
        HashMap<Label, LibLinearModel> models = new HashMap<Label, LibLinearModel>();
        for (Label labelToTrain : Label.values()) {
            File modelPath = Config.getFilePath("model.aggregation.multi." + labelToTrain.toString().toLowerCase());
            LibLinearModel model = (LibLinearModel)FileHelper.deserialize((String)modelPath.getAbsolutePath());
            models.put(labelToTrain, model);
        }
        StringBuilder result = new StringBuilder();
        result.append("business_id,labels").append('\n');
        for (Map.Entry instance : allStats.entrySet()) {
            String businessId2 = (String)instance.getKey();
            InstanceBuilder builder = new InstanceBuilder();
            Map stats = (Map)allStats.get(businessId2);
            for (Map.Entry entry : stats.entrySet()) {
                builder.set(((String)entry.getKey()).toLowerCase() + "_mean_probability", ((Stats)entry.getValue()).getMean());
                builder.set(((String)entry.getKey()).toLowerCase() + "_max_probability", ((Stats)entry.getValue()).getMax());
            }
            FeatureVector fv = builder.create();
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.append(businessId2);
            stringBuilder.append(',');
            for (Label l : Label.values()) {
                boolean isTrue;
                CategoryEntries res = new LibLinearClassifier().classify(fv, (LibLinearModel)models.get((Object)l));
                boolean bl = isTrue = res.getProbability("true") > 0.5;
                if (!isTrue) continue;
                stringBuilder.append(l.getLabelId()).append(' ');
            }
            String line = stringBuilder.toString().trim();
            result.append(line).append('\n');
        }
        FileHelper.writeToFile((String)("/Users/pk/Desktop/submission_multi_" + DateHelper.getCurrentDatetime() + ".csv"), (CharSequence)result.toString());
    }

    private static void train() throws IOException {
        NumericValue probability;
        HashMap<String, Double> currentProbabilities;
        Dataset classificationResult1 = MultiClassificationAggregator.read("/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/classified_train_true_2016-04-02_08-30-41.csv");
        Dataset classificationResult2 = MultiClassificationAggregator.read("/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/classified_train_true_2016-04-12_22-59-48.csv");
        MultiMap businessIdResults = DefaultMultiMap.createWithSet();
        for (Instance instance : classificationResult1) {
            currentProbabilities = new HashMap<String, Double>();
            for (Label label : Label.values()) {
                probability = (NumericValue)instance.getVector().get((Object)label.toString());
                currentProbabilities.put(label.toString() + "_1", probability.getDouble());
            }
            String string = ((Value)instance.getVector().get((Object)"businessId")).toString();
            businessIdResults.add((Object)string, currentProbabilities);
        }
        for (Instance instance : classificationResult2) {
            currentProbabilities = new HashMap();
            for (Label label : Label.values()) {
                probability = (NumericValue)instance.getVector().get((Object)label.toString());
                currentProbabilities.put(label.toString() + "_2", probability.getDouble());
            }
            String string = ((Value)instance.getVector().get((Object)"businessId")).toString();
            businessIdResults.add((Object)string, currentProbabilities);
        }
        for (Label label : Label.values()) {
            HashMap allStats = new HashMap();
            for (Map.Entry businessPhotos : businessIdResults.entrySet()) {
                Collection input = (Collection)businessPhotos.getValue();
                LazyMap statsMap = new LazyMap(FatStats.FACTORY);
                for (Map singleInput : input) {
                    for (Map.Entry inputEntry : singleInput.entrySet()) {
                        ((Stats)statsMap.get(inputEntry.getKey())).add((Number)inputEntry.getValue());
                    }
                }
                allStats.put(businessPhotos.getKey(), statsMap);
            }
            File trainCsv = Config.getFilePath("dataset.yelp.restaurants.train.csv");
            CsvDatasetReaderConfig.Builder trainConfigBuilder = CsvDatasetReaderConfig.filePath((File)trainCsv);
            trainConfigBuilder.readClassFromLastColumn(false);
            trainConfigBuilder.setFieldSeparator(',');
            trainConfigBuilder.treatAsNullValue("");
            trainConfigBuilder.parser("business_id", ImmutableStringValue.PARSER);
            CsvDatasetReader trainCsvReader = trainConfigBuilder.create();
            ArrayList<Instance> labelTrainingInstances = new ArrayList<Instance>();
            for (Instance instance : trainCsvReader) {
                String businessId2 = ((Value)instance.getVector().get((Object)"business_id")).toString();
                InstanceBuilder builder = new InstanceBuilder();
                Map stats = (Map)allStats.get(businessId2);
                for (Map.Entry statsEntry : stats.entrySet()) {
                    builder.set(((String)statsEntry.getKey()).toLowerCase() + "_mean_probability", ((Stats)statsEntry.getValue()).getMean());
                    builder.set(((String)statsEntry.getKey()).toLowerCase() + "_max_probability", ((Stats)statsEntry.getValue()).getMax());
                }
                String labelsString = ((Value)instance.getVector().get((Object)"labels")).toString();
                HashSet<String> labelSet = new HashSet<String>(Arrays.asList(labelsString.split(" ")));
                boolean positiveClass = labelSet.contains(label.getLabelId() + "");
                Instance inst = builder.create(positiveClass);
                labelTrainingInstances.add(inst);
            }
            DefaultDataset dataset = new DefaultDataset(labelTrainingInstances);
            LibLinearModel model = new LibLinearLearner().train((Dataset)dataset);
            ConfusionMatrix cf = new ConfusionMatrixEvaluator().evaluate((Classifier)new LibLinearClassifier(), (Model)model, (Dataset)dataset);
            System.out.println((Object)label);
            System.out.println(cf);
            FileHelper.serialize((Serializable)model, (String)("/Users/pk/Desktop/aggregation_multi_" + label.toString() + "_" + DateHelper.getCurrentDatetime() + ".ser.gz"));
        }
    }

    private static Dataset read(String filePath) {
        CsvDatasetReaderConfig.Builder configBuilder = CsvDatasetReaderConfig.filePath((File)new File(filePath));
        configBuilder.parser("businessId", ImmutableStringValue.PARSER);
        return configBuilder.create();
    }
}

