/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.api.worker;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.stats.StatsCalculationHelper;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Flat-map adapter that runs a {@link TrainingWorker} over the {@code DataSet}s supplied by a
 * single iterator and returns exactly one training result for that iterator.
 *
 * <p>Behaviour is driven by the {@link WorkerConfiguration} obtained from the worker: whether the
 * model is a {@link ComputationGraph} or a {@link MultiLayerNetwork}, whether timing statistics
 * are collected, the minibatch size, the number of batches to prefetch and the maximum number of
 * minibatches to process.
 *
 * @param <R> type of training result produced by the worker
 */
class ExecuteWorkerFlatMapAdapter<R extends TrainingResult>
implements FlatMapFunctionAdapter<Iterator<org.nd4j.linalg.dataset.DataSet>, R> {
    private final TrainingWorker<R> worker;

    /**
     * @param worker the worker that supplies the configuration and the initial model, processes
     *               each minibatch and produces the final result
     */
    public ExecuteWorkerFlatMapAdapter(TrainingWorker<R> worker) {
        this.worker = worker;
    }

    /**
     * Feeds the given data to the worker, minibatch by minibatch, and returns a single-element
     * collection holding the worker's result.
     *
     * <ul>
     *   <li>If the iterator is empty, the worker's "no data" result is returned.</li>
     *   <li>If the worker returns a non-null result for a minibatch, that result is returned
     *       immediately and no further minibatches are processed.</li>
     *   <li>Otherwise the worker's final result for the trained model is returned once the data
     *       is exhausted or the configured maximum number of minibatches has been reached.</li>
     * </ul>
     *
     * When statistics collection is enabled, the timings recorded here are combined with the
     * worker's own statistics and attached to the returned result.
     *
     * @param dataSetIterator the examples to train on
     * @return a single-element collection containing the training result
     * @throws Exception declared by the adapter interface; anything thrown by the worker or the
     *                   iterators propagates unchanged
     */
    public Iterable<R> call(Iterator<org.nd4j.linalg.dataset.DataSet> dataSetIterator) throws Exception {
        WorkerConfiguration dataConfig = this.worker.getDataConfiguration();
        final boolean isGraph = dataConfig.isGraphNetwork();
        final boolean stats = dataConfig.isCollectTrainingStats();

        // The helper is only created (and only ever dereferenced) when stats collection is enabled.
        StatsCalculationHelper s = stats ? new StatsCalculationHelper() : null;
        if (stats) {
            s.logMethodStartTime();
        }

        if (!dataSetIterator.hasNext()) {
            // Nothing to train on: ask the worker for its "no data" result.
            if (stats) {
                s.logReturnTime();
                Pair<R, SparkTrainingStats> pair = this.worker.getFinalResultNoDataWithStats();
                pair.getFirst().setStats(s.build(pair.getSecond()));
                return Collections.singletonList(pair.getFirst());
            }
            return Collections.singletonList(this.worker.getFinalResultNoData());
        }

        int batchSize = dataConfig.getBatchSizePerWorker();
        int prefetchCount = dataConfig.getPrefetchNumBatches();

        // Declared as the common interface so it can hold either the plain batching iterator or
        // the asynchronous wrapper around it.
        DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
        if (prefetchCount > 0) {
            batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
        }

        try {
            MultiLayerNetwork net = null;
            ComputationGraph graph = null;
            if (stats) {
                s.logInitialModelBefore();
            }
            // Exactly one of graph/net is populated, selected by isGraph.
            if (isGraph) {
                graph = this.worker.getInitialModelGraph();
            } else {
                net = this.worker.getInitialModel();
            }
            if (stats) {
                s.logInitialModelAfter();
            }

            // A non-positive configured maximum means "no limit".
            int configuredMax = dataConfig.getMaxBatchesPerWorker();
            int maxMinibatches = configuredMax > 0 ? configuredMax : Integer.MAX_VALUE;
            int miniBatchCount = 0;

            while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
                if (stats) {
                    s.logNextDataSetBefore();
                }
                org.nd4j.linalg.dataset.DataSet next = batchedIterator.next();
                if (stats) {
                    s.logNextDataSetAfter(next.numExamples());
                }

                // The third argument passed to the worker is true when the iterator reports
                // that no further batches remain after this one.
                if (stats) {
                    s.logProcessMinibatchBefore();
                    Pair<R, SparkTrainingStats> result = isGraph
                            ? this.worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext())
                            : this.worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                    s.logProcessMinibatchAfter();
                    if (result != null) {
                        // The worker produced a result: stop processing and return it right away.
                        s.logReturnTime();
                        result.getFirst().setStats(s.build(result.getSecond()));
                        return Collections.singletonList(result.getFirst());
                    }
                } else {
                    R result = isGraph
                            ? this.worker.processMinibatch(next, graph, !batchedIterator.hasNext())
                            : this.worker.processMinibatch(next, net, !batchedIterator.hasNext());
                    if (result != null) {
                        // The worker produced a result: stop processing and return it right away.
                        return Collections.singletonList(result);
                    }
                }
            }

            // No minibatch yielded a result: fall back to the worker's final result for the model.
            if (stats) {
                s.logReturnTime();
                Pair<R, SparkTrainingStats> pair = isGraph
                        ? this.worker.getFinalResultWithStats(graph)
                        : this.worker.getFinalResultWithStats(net);
                pair.getFirst().setStats(s.build(pair.getSecond()));
                return Collections.singletonList(pair.getFirst());
            }
            if (isGraph) {
                return Collections.singletonList(this.worker.getFinalResult(graph));
            }
            return Collections.singletonList(this.worker.getFinalResult(net));
        } finally {
            // Nested try/finally so the asynchronous iterator is still shut down even if the
            // commit call itself throws.
            try {
                Nd4j.getExecutioner().commit();
            } finally {
                if (batchedIterator instanceof AsyncDataSetIterator) {
                    ((AsyncDataSetIterator) batchedIterator).shutdown();
                }
            }
        }
    }
}

