/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.experiment.validation.cross;

import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.data.dataset.MapBackedDataset;
import org.openimaj.experiment.dataset.util.DatasetAdaptors;
import org.openimaj.experiment.validation.DefaultValidationData;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.experiment.validation.cross.CrossValidationIterable;
import org.openimaj.experiment.validation.cross.CrossValidator;
import org.openimaj.util.list.AcceptingListView;
import org.openimaj.util.list.SkippingListView;
import org.openimaj.util.pair.IntObjectPair;

/**
 * K-fold cross-validation for {@link GroupedDataset}s.
 * <p>
 * All instances of the dataset (across every group) are pooled, randomly
 * ordered and cut into {@code k} folds: the first {@code k - 1} folds each
 * hold {@code numInstances / k} items and the last fold takes the remainder.
 * Each fold is then mapped back to per-group indices. On each iteration one
 * fold forms the validation set and the remaining folds form the training set,
 * with the group structure preserved in both.
 * <p>
 * Because folds are drawn from the pooled instances, a given group may
 * contribute no instances to a particular fold.
 *
 * @param <KEY>
 *            type of the group keys
 * @param <INSTANCE>
 *            type of the instances
 */
public class GroupedKFold<KEY, INSTANCE>
implements CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> {
    private final int k;

    /**
     * Construct with the given number of folds.
     *
     * @param k
     *            the number of folds; it is validated against the dataset
     *            when {@link #createIterable(GroupedDataset)} is called
     */
    public GroupedKFold(int k) {
        this.k = k;
    }

    /**
     * Create the fold iterable for the given dataset.
     *
     * @param data
     *            the grouped dataset to split
     * @return an iterable producing one train/validation split per fold
     * @throws IllegalArgumentException
     *             if {@code k <= 0} or {@code k} exceeds the number of
     *             instances in {@code data}
     */
    @Override
    public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable(GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> data) {
        return new GroupedKFoldIterable(data, this.k);
    }

    @Override
    public String toString() {
        return this.k + "-Fold Cross-Validation for grouped datasets";
    }

    /**
     * Precomputes, for every group and every fold, the local indices of the
     * group's instances that are held out for validation in that fold.
     */
    private class GroupedKFoldIterable
    implements CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> {
        private final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;
        // group -> fold -> local indices (within the group) of the held-out instances
        private final Map<KEY, int[][]> subsetIndices = new HashMap<KEY, int[][]>();
        private final int numFolds;

        /**
         * Split the dataset into {@code k} folds.
         *
         * @param dataset
         *            the dataset to split
         * @param k
         *            the number of folds
         * @throws IllegalArgumentException
         *             if {@code k <= 0} or {@code k} exceeds the number of
         *             instances in the dataset
         */
        public GroupedKFoldIterable(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset, int k) {
            if (k <= 0) {
                throw new IllegalArgumentException("The number of folds must be at least one");
            }
            final int numInstances = dataset.numInstances();
            if (k > numInstances) {
                throw new IllegalArgumentException("The number of folds must not be greater than the number of items in the dataset");
            }
            this.dataset = dataset;
            this.numFolds = k;

            // Randomly order the pooled (flat) instance indices and cut them into k folds;
            // the last fold absorbs any remainder.
            final int[] allIndices = RandomData.getUniqueRandomInts(numInstances, 0, numInstances);
            final int[][] flatSubsetIndices = new int[k][];
            final int splitSize = numInstances / k;
            for (int i = 0; i < k - 1; ++i) {
                flatSubsetIndices[i] = Arrays.copyOfRange(allIndices, splitSize * i, splitSize * (i + 1));
            }
            flatSubsetIndices[k - 1] = Arrays.copyOfRange(allIndices, splitSize * (k - 1), allIndices.length);

            // A fixed group order (and the group sizes, hoisted out of the per-index loop)
            // defines how a flat index maps back to a (local index, group) pair.
            final List<KEY> groups = new ArrayList<KEY>(dataset.getGroups());
            final int[] groupSizes = new int[groups.size()];
            for (int g = 0; g < groupSizes.length; ++g) {
                groupSizes[g] = dataset.getInstances(groups.get(g)).size();
            }

            // Default every (group, fold) slot to an empty array. A group that receives no
            // instances in some fold would otherwise leave a null here, which iterator()
            // would hand to SkippingListView/AcceptingListView.
            final int[] empty = new int[0];
            for (final KEY key : groups) {
                final int[][] perFold = new int[k][];
                Arrays.fill(perFold, empty);
                this.subsetIndices.put(key, perFold);
            }

            for (int fold = 0; fold < k; ++fold) {
                final Map<KEY, TIntArrayList> perGroup = new HashMap<KEY, TIntArrayList>();
                for (final int flatIdx : flatSubsetIndices[fold]) {
                    final IntObjectPair<KEY> idx = this.computeIndex(groups, groupSizes, flatIdx);
                    TIntArrayList list = perGroup.get(idx.second);
                    if (list == null) {
                        list = new TIntArrayList();
                        perGroup.put(idx.second, list);
                    }
                    list.add(idx.first);
                }
                for (final Map.Entry<KEY, TIntArrayList> kv : perGroup.entrySet()) {
                    this.subsetIndices.get(kv.getKey())[fold] = kv.getValue().toArray();
                }
            }
        }

        /**
         * Map a flat index over the concatenation of all groups (in the order of
         * {@code groups}) to the index within its group.
         *
         * @param groups
         *            the groups, in concatenation order
         * @param groupSizes
         *            number of instances in each group, parallel to {@code groups}
         * @param flatIdx
         *            the flat index
         * @return the local index paired with the owning group
         * @throws RuntimeException
         *             if {@code flatIdx} is beyond the total number of instances
         */
        private IntObjectPair<KEY> computeIndex(List<KEY> groups, int[] groupSizes, int flatIdx) {
            int offset = 0;
            for (int g = 0; g < groupSizes.length; ++g) {
                final int size = groupSizes[g];
                if (flatIdx < offset + size) {
                    return new IntObjectPair<KEY>(flatIdx - offset, groups.get(g));
                }
                offset += size;
            }
            throw new RuntimeException("Index not found");
        }

        @Override
        public int numberIterations() {
            return this.numFolds;
        }

        /**
         * Iterate over the folds. Each element uses one fold as validation data
         * and all other folds as training data.
         */
        @Override
        public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
            return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
                private int validationSubset = 0;

                @Override
                public boolean hasNext() {
                    return this.validationSubset < GroupedKFoldIterable.this.numFolds;
                }

                @Override
                public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
                    if (!this.hasNext()) {
                        throw new NoSuchElementException();
                    }
                    final Map<KEY, ListDataset<INSTANCE>> train = new HashMap<KEY, ListDataset<INSTANCE>>();
                    final Map<KEY, ListDataset<INSTANCE>> valid = new HashMap<KEY, ListDataset<INSTANCE>>();
                    for (final Map.Entry<KEY, int[][]> entry : GroupedKFoldIterable.this.subsetIndices.entrySet()) {
                        final KEY group = entry.getKey();
                        final int[] heldOut = entry.getValue()[this.validationSubset];
                        final List<INSTANCE> keyData = DatasetAdaptors.asList(GroupedKFoldIterable.this.dataset.getInstances(group));
                        // Training view skips the held-out indices; validation view keeps only them.
                        train.put(group, new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(keyData, heldOut)));
                        valid.put(group, new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(keyData, heldOut)));
                    }
                    final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvTrain =
                            new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(train);
                    final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvValid =
                            new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(valid);
                    ++this.validationSubset;
                    return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(cvTrain, cvValid);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }
}

