/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.datavec;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.nd4j.linalg.dataset.DataSet;

/**
 * Re-batches an RDD of {@link DataSet}s by merging consecutive elements of each partition into
 * minibatches of a fixed size.
 *
 * <p>Merging is done per partition (via {@code mapPartitions}), so a minibatch never spans two
 * partitions. The last minibatch emitted for a partition may therefore hold fewer than the
 * requested number of elements.
 */
public class RDDMiniBatches implements Serializable {
    private final int miniBatches;
    private final JavaRDD<DataSet> toSplitJava;

    /**
     * @param miniBatches number of {@link DataSet} elements to merge into each output minibatch
     * @param toSplit     the RDD whose elements are to be merged
     */
    public RDDMiniBatches(int miniBatches, JavaRDD<DataSet> toSplit) {
        this.miniBatches = miniBatches;
        this.toSplitJava = toSplit;
    }

    /**
     * Builds an RDD whose elements are the merged minibatches of the wrapped RDD.
     *
     * @return a new RDD produced by applying {@link MiniBatchFunction} to every partition
     */
    public JavaRDD<DataSet> miniBatchesJava() {
        return this.toSplitJava.mapPartitions((FlatMapFunction)new MiniBatchFunction(this.miniBatches));
    }

    /**
     * Partition-level worker: consumes the elements of one partition and emits merged minibatches.
     */
    static class MiniBatchFunctionAdapter
    implements FlatMapFunctionAdapter<Iterator<DataSet>, DataSet> {
        private final int batchSize;

        /**
         * @param batchSize number of DataSets per merged minibatch. A non-positive value never
         *                  triggers the in-loop merge, so each partition is merged into a single
         *                  DataSet.
         */
        public MiniBatchFunctionAdapter(int batchSize) {
            this.batchSize = batchSize;
        }

        /**
         * Groups the partition's DataSets into chunks of {@code batchSize} and merges each chunk.
         * A trailing partial chunk, even one holding a single DataSet, is merged and emitted too,
         * so no examples are dropped.
         *
         * @param dataSetIterator the elements of one partition
         * @return the merged minibatches, in the order their elements were read
         */
        public Iterable<DataSet> call(Iterator<DataSet> dataSetIterator) throws Exception {
            List<DataSet> merged = new ArrayList<>();
            List<DataSet> pending = new ArrayList<>();
            while (dataSetIterator.hasNext()) {
                // NOTE(review): each element is copied before buffering -- presumably so the merge
                // never aliases objects Spark may reuse or cache; confirm before removing.
                pending.add(dataSetIterator.next().copy());
                if (pending.size() == this.batchSize) {
                    merged.add(DataSet.merge(pending));
                    pending.clear();
                }
            }
            // Flush whatever is left, including a lone DataSet; a "> 1" check here would silently
            // discard one example whenever the partition size is 1 more than a multiple of batchSize.
            if (!pending.isEmpty()) {
                merged.add(DataSet.merge(pending));
            }
            return merged;
        }
    }

    /**
     * Spark-facing function that delegates each partition to a {@link MiniBatchFunctionAdapter}.
     */
    public static class MiniBatchFunction
    extends BaseFlatMapFunctionAdaptee<Iterator<DataSet>, DataSet> {
        /**
         * @param batchSize number of DataSets merged into each output minibatch
         */
        public MiniBatchFunction(int batchSize) {
            super((FlatMapFunctionAdapter)new MiniBatchFunctionAdapter(batchSize));
        }
    }
}

