/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.data;

import java.io.OutputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.primitives.Pair;

public class BatchAndExportMultiDataSetsFunction
implements Function2<Integer, Iterator<org.nd4j.linalg.dataset.api.MultiDataSet>, Iterator<String>> {
    private static final Configuration conf = new Configuration();
    private final int minibatchSize;
    private final String exportBaseDirectory;
    private final String jvmuid;

    public BatchAndExportMultiDataSetsFunction(int minibatchSize, String exportBaseDirectory) {
        this.minibatchSize = minibatchSize;
        this.exportBaseDirectory = exportBaseDirectory;
        String fullUID = UIDProvider.getJVMUID();
        this.jvmuid = fullUID.length() <= 8 ? fullUID : fullUID.substring(0, 8);
    }

    public Iterator<String> call(Integer partitionIdx, Iterator<org.nd4j.linalg.dataset.api.MultiDataSet> iterator) throws Exception {
        ArrayList<String> outputPaths = new ArrayList<String>();
        LinkedList<org.nd4j.linalg.dataset.api.MultiDataSet> tempList = new LinkedList<org.nd4j.linalg.dataset.api.MultiDataSet>();
        int count = 0;
        while (iterator.hasNext()) {
            org.nd4j.linalg.dataset.api.MultiDataSet next = iterator.next();
            if (next.getFeatures(0).size(0) == this.minibatchSize) {
                outputPaths.add(this.export(next, partitionIdx, count++));
                continue;
            }
            tempList.add(next);
            Pair<Integer, List<String>> countAndPaths = this.processList(tempList, partitionIdx, count, false);
            if (countAndPaths.getSecond() != null && ((List)countAndPaths.getSecond()).size() > 0) {
                outputPaths.addAll((Collection)countAndPaths.getSecond());
            }
            count = (Integer)countAndPaths.getFirst();
        }
        Pair<Integer, List<String>> countAndPaths = this.processList(tempList, partitionIdx, count, true);
        if (countAndPaths.getSecond() != null && ((List)countAndPaths.getSecond()).size() > 0) {
            outputPaths.addAll((Collection)countAndPaths.getSecond());
        }
        return outputPaths.iterator();
    }

    private Pair<Integer, List<String>> processList(LinkedList<org.nd4j.linalg.dataset.api.MultiDataSet> tempList, int partitionIdx, int countBefore, boolean finalExport) throws Exception {
        int numExamples = 0;
        for (org.nd4j.linalg.dataset.api.MultiDataSet ds : tempList) {
            numExamples += ds.getFeatures(0).size(0);
        }
        if (tempList.size() == 0 || numExamples < this.minibatchSize && !finalExport) {
            return new Pair((Object)countBefore, Collections.emptyList());
        }
        ArrayList<String> exportPaths = new ArrayList<String>();
        int countAfter = countBefore;
        int countSoFar = 0;
        ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet> tempToMerge = new ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet>();
        while (tempList.size() > 0 && countSoFar != this.minibatchSize) {
            org.nd4j.linalg.dataset.api.MultiDataSet next = tempList.removeFirst();
            if (countSoFar + next.getFeatures(0).size(0) <= this.minibatchSize) {
                tempToMerge.add(next);
                countSoFar += next.getFeatures(0).size(0);
                continue;
            }
            List examples = next.asList();
            for (org.nd4j.linalg.dataset.api.MultiDataSet ds : examples) {
                tempList.addFirst(ds);
            }
        }
        MultiDataSet toExport = MultiDataSet.merge(tempToMerge);
        exportPaths.add(this.export((org.nd4j.linalg.dataset.api.MultiDataSet)toExport, partitionIdx, countAfter++));
        return new Pair((Object)countAfter, exportPaths);
    }

    private String export(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, int partitionIdx, int outputCount) throws Exception {
        String filename = "mds_" + partitionIdx + this.jvmuid + "_" + outputCount + ".bin";
        URI uri = new URI(this.exportBaseDirectory + (this.exportBaseDirectory.endsWith("/") || this.exportBaseDirectory.endsWith("\\") ? "" : "/") + filename);
        FileSystem file = FileSystem.get((URI)uri, (Configuration)conf);
        try (FSDataOutputStream out = file.create(new Path(uri));){
            dataSet.save((OutputStream)out);
        }
        return uri.getPath();
    }
}

