/*
 * Decompiled with CFR 0.152.
 */
package ws.palladian.kaggle.restaurants.dataset;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Stream;
import ws.palladian.core.Instance;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.dataset.AbstractDataset;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.FeatureInformation;
import ws.palladian.core.dataset.FeatureInformationBuilder;
import ws.palladian.core.value.ImmutableIntegerValue;
import ws.palladian.core.value.ImmutableStringValue;
import ws.palladian.core.value.NullValue;
import ws.palladian.core.value.Value;
import ws.palladian.dataset.ImageValue;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.ProgressReporter;
import ws.palladian.helper.collection.AbstractIterator;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.io.CloseableIterator;
import ws.palladian.kaggle.restaurants.dataset.Label;
import ws.palladian.kaggle.restaurants.utils.Config;
import ws.palladian.utils.ModuloFilter;

public class YelpKaggleDatasetReader
extends AbstractDataset {
    private final File photoToBizCsv;
    private final File baseImagePath;
    private final Map<Integer, int[]> businessIdToLabels = new HashMap<Integer, int[]>();
    private final long numPhotos;
    private final boolean training;

    public YelpKaggleDatasetReader(File photoToBizCsv, File trainCsv, File baseImagePath) throws IOException {
        this.photoToBizCsv = Objects.requireNonNull(photoToBizCsv);
        this.baseImagePath = Objects.requireNonNull(baseImagePath);
        this.numPhotos = Files.lines(photoToBizCsv.toPath()).count() - 1L;
        boolean bl = this.training = trainCsv != null;
        if (trainCsv != null) {
            try (Stream<String> lines = Files.lines(trainCsv.toPath());){
                boolean[] first = new boolean[]{true};
                lines.forEach(line -> {
                    if (first[0]) {
                        first[0] = false;
                        return;
                    }
                    String[] split = line.split(",");
                    int businessId = Integer.valueOf(split[0]);
                    int[] labels = split.length == 2 ? Arrays.stream(split[1].split(" ")).mapToInt(Integer::valueOf).toArray() : new int[]{};
                    this.businessIdToLabels.put(businessId, labels);
                });
            }
        }
    }

    @Deprecated
    public YelpKaggleDatasetReader(File photoToBizCsv, File baseImagePath) throws IOException {
        this(photoToBizCsv, null, baseImagePath);
    }

    public CloseableIterator<Instance> iterator() {
        return new DatasetIterator();
    }

    public Iterable<Instance> subset(Subset subset, boolean train) {
        Objects.requireNonNull(subset);
        return CollectionHelper.filter((Iterable)((Object)this), (Predicate)new ModuloFilter(subset.mod, train ? 0 : 1));
    }

    public FeatureInformation getFeatureInformation() {
        FeatureInformationBuilder builder = new FeatureInformationBuilder();
        builder.set("photoId", ImmutableIntegerValue.class);
        builder.set("businessId", ImmutableStringValue.class);
        return builder.create();
    }

    public long size() {
        return this.numPhotos;
    }

    public static void main(String[] args) throws IOException {
        File trainCsv = Config.getFilePath("dataset.yelp.restaurants.train.csv");
        File trainPhotoToBizCsv = Config.getFilePath("dataset.yelp.restaurants.train.photoToBizCsv");
        File baseTrainPath = Config.getFilePath("dataset.yelp.restaurants.train.photos");
        YelpKaggleDatasetReader trainSet = new YelpKaggleDatasetReader(trainPhotoToBizCsv, trainCsv, baseTrainPath);
        Dataset trainingInstances = trainSet.subset(BusinessFilter.TRAIN);
        Dataset validationInstances = trainSet.subset(BusinessFilter.VALIDATE);
        System.out.println("# training = " + CollectionHelper.count(trainingInstances.iterator()));
        System.out.println("# validation = " + CollectionHelper.count(validationInstances.iterator()));
    }

    private final class DatasetIterator
    extends AbstractIterator<Instance>
    implements CloseableIterator<Instance> {
        private final Stream<String> lines;
        private final Iterator<String> iterator;
        private final ProgressReporter progress;

        DatasetIterator() {
            try {
                this.lines = Files.lines(YelpKaggleDatasetReader.this.photoToBizCsv.toPath());
                this.iterator = this.lines.iterator();
                this.iterator.next();
                this.progress = new ProgressMonitor(0.5);
                this.progress.startTask(YelpKaggleDatasetReader.this.photoToBizCsv.toString(), YelpKaggleDatasetReader.this.numPhotos);
            }
            catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }

        public void close() throws IOException {
            this.lines.close();
        }

        protected Instance getNext() throws AbstractIterator.Finished {
            if (this.iterator.hasNext()) {
                this.progress.increment();
                String[] split = this.iterator.next().split(",");
                int photoId = Integer.valueOf(split[0]);
                String businessIdString = split[1];
                InstanceBuilder instanceBuilder = new InstanceBuilder();
                instanceBuilder.set("photoId", photoId);
                instanceBuilder.set("businessId", businessIdString);
                if (YelpKaggleDatasetReader.this.training) {
                    int[] labels = (int[])YelpKaggleDatasetReader.this.businessIdToLabels.get(Integer.valueOf(businessIdString));
                    if (labels == null) {
                        throw new IllegalStateException("No entry for businessId " + businessIdString);
                    }
                    int[] idx = new int[]{0};
                    while (idx[0] < Label.values().length) {
                        String label = Label.getById(idx[0]).toString();
                        boolean value = Arrays.stream(labels).anyMatch(v -> v == idx[0]);
                        instanceBuilder.set(label, value);
                        idx[0] = idx[0] + 1;
                    }
                }
                instanceBuilder.set("image", (Value)new ImageValue(new File(YelpKaggleDatasetReader.this.baseImagePath, photoId + ".jpg")));
                return instanceBuilder.create(false);
            }
            throw FINISHED;
        }
    }

    public static enum BusinessFilter implements Predicate<Instance>
    {
        TRAIN(true),
        VALIDATE(false);

        private final boolean b;

        private BusinessFilter(boolean b) {
            this.b = b;
        }

        @Override
        public boolean test(Instance item) {
            Value businessIdValue = (Value)item.getVector().get((Object)"businessId");
            if (businessIdValue == null || businessIdValue == NullValue.NULL) {
                throw new IllegalArgumentException("businessId is missing");
            }
            return Integer.valueOf(businessIdValue.toString()) % 2 == 0 ^ this.b;
        }
    }

    public static enum Subset {
        SMALL(200),
        MEDIUM(20),
        FULL(2);

        public final int mod;

        private Subset(int mod) {
            this.mod = mod;
        }

        public String toString() {
            return this.name().toLowerCase();
        }
    }
}

