package edu.stanford.nlp.classify;

import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import java.util.Iterator;
import java.util.NoSuchElementException;

/* loaded from: input_file:edu/stanford/nlp/classify/CrossValidator.class */
/**
 * Runs k-fold cross-validation over a {@link GeneralDataset}.
 *
 * <p>The dataset is cut into {@code kFold} contiguous slices. Fold {@code i} is obtained by calling
 * {@code split(size * i / kFold, size * (i + 1) / kFold)} on the original data. Each fold is paired
 * with a {@link SavedState} object that is created once per validator and reused on every pass.
 * Callers can therefore stash arbitrary per-fold state between repeated calls to
 * {@link #computeAverage}.
 *
 * @param <L> label type of the dataset
 * @param <F> feature type of the dataset
 */
public class CrossValidator<L, F> {
    private final GeneralDataset<L, F> originalTrainData;
    private final int kFold;
    private final SavedState[] savedStates;

    /**
     * Iterates over the {@code kFold} folds in order.
     *
     * <p>Each element is a triple whose first two components are the two datasets returned by
     * {@code GeneralDataset.split(start, end)} for that fold. Presumably these are (training data,
     * held-out slice), but verify against {@code GeneralDataset.split}. The third component is the
     * fold's persistent {@link SavedState}.
     */
    public class CrossValidationIterator
            implements Iterator<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>> {
        /** Index of the next fold to produce; equals {@code kFold} once exhausted. */
        int iter = 0;

        CrossValidationIterator() {
        }

        @Override
        public boolean hasNext() {
            return iter < kFold;
        }

        /**
         * Not supported.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void remove() {
            throw new UnsupportedOperationException("CrossValidationIterator doesn't support remove()");
        }

        /**
         * Returns the split and saved state for the next fold.
         *
         * @throws NoSuchElementException if all folds have already been returned
         */
        @Override
        public Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState> next() {
            if (!hasNext()) {
                // The Iterator contract requires an exception here, not a null return.
                throw new NoSuchElementException("All " + kFold + " folds have already been returned");
            }
            int size = originalTrainData.size();
            // Compute in long: size * (iter + 1) can exceed Integer.MAX_VALUE for large datasets.
            // Each quotient is <= size, so narrowing back to int is safe.
            int start = (int) ((long) size * iter / kFold);
            int end = (int) ((long) size * (iter + 1) / kFold);
            Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split = originalTrainData.split(start, end);
            SavedState state = savedStates[iter];
            iter++;
            return new Triple<>(split.first(), split.second(), state);
        }
    }

    /** Mutable per-fold holder that persists across passes over the folds. */
    public static class SavedState {
        /** Arbitrary caller-defined state for one fold; {@code null} until a caller sets it. */
        public Object state;
    }

    /**
     * Creates a 10-fold cross-validator.
     *
     * @param generalDataset the data to partition into folds
     */
    public CrossValidator(GeneralDataset<L, F> generalDataset) {
        this(generalDataset, 10);
    }

    /**
     * Creates a {@code i}-fold cross-validator.
     *
     * @param generalDataset the data to partition into folds
     * @param i number of folds; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public CrossValidator(GeneralDataset<L, F> generalDataset, int i) {
        if (i < 1) {
            throw new IllegalArgumentException("kFold must be at least 1, got " + i);
        }
        this.originalTrainData = generalDataset;
        this.kFold = i;
        this.savedStates = new SavedState[i];
        for (int fold = 0; fold < i; fold++) {
            this.savedStates[fold] = new SavedState();
        }
    }

    /** Returns a fresh iterator over all folds; the saved states are shared between iterators. */
    private Iterator<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>> iterator() {
        return new CrossValidationIterator();
    }

    /**
     * Applies {@code function} to every fold and returns the arithmetic mean of the results.
     *
     * @param function scorer invoked once per fold with that fold's split and saved state
     * @return sum of the per-fold scores divided by the number of folds
     */
    public double computeAverage(Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>, Double> function) {
        double total = 0.0d;
        Iterator<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>> it = iterator();
        while (it.hasNext()) {
            total += function.apply(it.next()).doubleValue();
        }
        return total / kFold;
    }

    /**
     * Smoke test: reads an SVMLight-format file and materializes the first fold.
     *
     * @param strArr {@code strArr[0]} is the path of the SVMLight file
     */
    public static void main(String[] strArr) {
        if (strArr.length < 1) {
            System.err.println("Usage: CrossValidator svmLightFile");
            return;
        }
        // Class type parameters cannot be used in a static context, so concrete types are needed here.
        Dataset<String, String> data = Dataset.readSVMLightFormat(strArr[0]);
        CrossValidator<String, String> validator = new CrossValidator<>(data);
        Iterator<Triple<GeneralDataset<String, String>, GeneralDataset<String, String>, SavedState>> it =
                validator.iterator();
        if (it.hasNext()) {
            it.next();
        }
    }
}
