/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.paramavg;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Random;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StatsStorageRouterProvider;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerMultiDataSetFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSMDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathMDSFlatMap;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.BaseTrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingResult;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingWorker;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties(value={"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath", "trainingMasterUID"})
public class ParameterAveragingTrainingMaster
extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker>
implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
    private static final Logger log = LoggerFactory.getLogger(ParameterAveragingTrainingMaster.class);
    /**
     * Coalescing multiplier: the executeTraining(..., JavaPairRDD&lt;String, PortableDataStream&gt;) overloads
     * coalesce their input down to {@code numWorkers} partitions once it has at least
     * {@code COALESCE_THRESHOLD * numWorkers} partitions.
     */
    protected static final int COALESCE_THRESHOLD = 3;
    /** Passed through to every ParameterAveragingTrainingWorker; presumably controls whether updater state is kept/averaged — confirm in the worker. */
    protected boolean saveUpdater;
    /** Number of workers; {@code null} means "use the Spark context's default parallelism" (resolved lazily by the executeTraining* methods). */
    protected Integer numWorkers;
    /** Number of examples contained in each DataSet/MultiDataSet object of the input RDD; used to size splits. */
    protected int rddDataSetNumExamples;
    /** Each worker receives roughly {@code batchSizePerWorker * averagingFrequency} examples per split (see numObjectsEachWorker). */
    protected int averagingFrequency;
    /** Tree aggregation depth (validated >= 1 by the full constructor); not read in the visible part of this file — presumably used when aggregating results. */
    protected int aggregationDepth;
    /** Passed to each WorkerConfiguration; presumably the number of minibatches to prefetch asynchronously on workers. */
    protected int prefetchNumBatches;
    /** Excluded from JSON via @JsonIgnoreProperties; not referenced in the visible part of this file. */
    protected int iterationCount = 0;
    /** Hooks handed to every worker; lazily created by addHook and may be null. */
    protected Collection<TrainingHook> trainingHookList;

    protected ParameterAveragingTrainingMaster() {
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.rng = new Random();
    }

    /**
     * Creates a training master from the settings accumulated in a {@link Builder}.
     * No argument validation is performed here, and {@code collectTrainingStats}/{@code stats} are not set.
     * The RNG is seeded from {@code builder.rngSeed} when one was supplied, otherwise it is unseeded.
     *
     * @param builder source of all configuration values
     */
    protected ParameterAveragingTrainingMaster(Builder builder) {
        // Core parameter-averaging settings
        this.saveUpdater = builder.saveUpdater;
        this.numWorkers = builder.numWorkers;
        this.rddDataSetNumExamples = builder.rddDataSetNumExamples;
        this.batchSizePerWorker = builder.batchSizePerWorker;
        this.averagingFrequency = builder.averagingFrequency;
        this.aggregationDepth = builder.aggregationDepth;
        this.prefetchNumBatches = builder.prefetchNumBatches;
        // Data layout / caching / export settings
        this.repartition = builder.repartition;
        this.repartitionStrategy = builder.repartitionStrategy;
        this.storageLevel = builder.storageLevel;
        this.storageLevelStreams = builder.storageLevelStreams;
        this.rddTrainingApproach = builder.rddTrainingApproach;
        this.exportDirectory = builder.exportDirectory;
        this.trainingHookList = builder.trainingHooks;

        if (builder.rngSeed != null) {
            this.rng = new Random(builder.rngSeed);
        } else {
            this.rng = new Random();
        }

        String jvmUid = UIDProvider.getJVMUID();
        String shortJvmUid = jvmUid.length() > 8 ? jvmUid.substring(0, 8) : jvmUid;
        this.trainingMasterUID = System.currentTimeMillis() + "_" + shortJvmUid;
    }

    /**
     * Convenience constructor. It defaults the tree aggregation depth to 2, repartitioning to
     * {@link Repartition#Always} with {@link RepartitionStrategy#Balanced}, and disables training statistics.
     *
     * @param saveUpdater           passed through to the workers
     * @param numWorkers            number of workers
     * @param rddDataSetNumExamples number of examples in each DataSet object of the input RDD
     * @param batchSizePerWorker    minibatch size used on each worker
     * @param averagingFrequency    averaging frequency
     * @param prefetchNumBatches    number of batches to prefetch on workers
     */
    public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples,
                    int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches) {
        this(saveUpdater, numWorkers, rddDataSetNumExamples, batchSizePerWorker, averagingFrequency,
                        2, // aggregationDepth
                        prefetchNumBatches,
                        Repartition.Always,
                        RepartitionStrategy.Balanced,
                        false); // collectTrainingStats
    }

    /**
     * Constructor that caches the input RDD with {@link StorageLevel#MEMORY_ONLY_SER()} and otherwise
     * forwards every argument unchanged to the full constructor.
     *
     * @param saveUpdater           passed through to the workers
     * @param numWorkers            number of workers
     * @param rddDataSetNumExamples number of examples in each DataSet object of the input RDD
     * @param batchSizePerWorker    minibatch size used on each worker
     * @param averagingFrequency    averaging frequency
     * @param aggregationDepth      tree aggregation depth
     * @param prefetchNumBatches    number of batches to prefetch on workers
     * @param repartition           when to repartition the data
     * @param repartitionStrategy   how to repartition the data
     * @param collectTrainingStats  whether to collect training statistics
     */
    public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples,
                    int batchSizePerWorker, int averagingFrequency, int aggregationDepth, int prefetchNumBatches,
                    Repartition repartition, RepartitionStrategy repartitionStrategy, boolean collectTrainingStats) {
        this(saveUpdater, numWorkers, rddDataSetNumExamples, batchSizePerWorker, averagingFrequency,
                        aggregationDepth, prefetchNumBatches, repartition, repartitionStrategy,
                        StorageLevel.MEMORY_ONLY_SER(), // default storage level for the input RDD
                        collectTrainingStats);
    }

    /**
     * Full constructor.
     *
     * @param saveUpdater           passed through to the workers
     * @param numWorkers            number of workers, or {@code null} to use the Spark context's default
     *                              parallelism (resolved when training starts)
     * @param rddDataSetNumExamples number of examples in each DataSet object of the input RDD (must be >= 1)
     * @param batchSizePerWorker    minibatch size used on each worker
     * @param averagingFrequency    averaging frequency (must be >= 1)
     * @param aggregationDepth      tree aggregation depth (must be >= 1)
     * @param prefetchNumBatches    number of batches to prefetch on workers
     * @param repartition           when to repartition the data
     * @param repartitionStrategy   how to repartition the data
     * @param storageLevel          storage level used to persist the input RDD, or {@code null} to skip persisting
     * @param collectTrainingStats  whether to collect training statistics
     * @throws IllegalArgumentException if any of the validated arguments is out of range
     */
    public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples,
                    int batchSizePerWorker, int averagingFrequency, int aggregationDepth, int prefetchNumBatches,
                    Repartition repartition, RepartitionStrategy repartitionStrategy, StorageLevel storageLevel,
                    boolean collectTrainingStats) {
        // null is a supported value: every executeTraining* method replaces it with the Spark default
        // parallelism. Unboxing it here unconditionally used to throw a NullPointerException instead.
        Preconditions.checkArgument(numWorkers == null || numWorkers > 0,
                        "Invalid number of workers: " + numWorkers + " (must be >= 1)");
        Preconditions.checkArgument(rddDataSetNumExamples > 0,
                        "Invalid rdd data set size: " + rddDataSetNumExamples + " (must be >= 1)");
        Preconditions.checkArgument(averagingFrequency > 0, "Invalid input: averaging frequency must be >= 1");
        Preconditions.checkArgument(aggregationDepth > 0, "Invalid input: tree aggregation depth must be >= 1");
        this.saveUpdater = saveUpdater;
        this.numWorkers = numWorkers;
        this.rddDataSetNumExamples = rddDataSetNumExamples;
        this.batchSizePerWorker = batchSizePerWorker;
        this.averagingFrequency = averagingFrequency;
        this.aggregationDepth = aggregationDepth;
        this.prefetchNumBatches = prefetchNumBatches;
        this.collectTrainingStats = collectTrainingStats;
        this.repartition = repartition;
        this.repartitionStrategy = repartitionStrategy;
        this.storageLevel = storageLevel;
        if (collectTrainingStats) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.rng = new Random();
    }

    /**
     * Removes a previously added training hook. Does nothing if no hooks have been registered.
     *
     * @param trainingHook the hook to remove
     */
    @Override
    public void removeHook(TrainingHook trainingHook) {
        if (trainingHookList != null) {
            trainingHookList.remove(trainingHook);
        }
    }

    /**
     * Registers a training hook that will be handed to every worker created afterwards.
     * The hook collection is created lazily on first use.
     *
     * @param trainingHook the hook to add
     */
    @Override
    public void addHook(TrainingHook trainingHook) {
        Collection<TrainingHook> hooks = trainingHookList;
        if (hooks == null) {
            hooks = new ArrayList<>();
            trainingHookList = hooks;
        }
        hooks.add(trainingHook);
    }

    /**
     * Serializes this training master to JSON using the JSON mapper.
     *
     * @return the JSON representation
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    @Override
    public String toJson() {
        ObjectMapper jsonMapper = getJsonMapper();
        try {
            return jsonMapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing JSON representation for ParameterAveragingTrainingMaster", e);
        }
    }

    /**
     * Serializes this training master to YAML using the YAML mapper.
     *
     * @return the YAML representation
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    @Override
    public String toYaml() {
        ObjectMapper yamlMapper = getYamlMapper();
        try {
            return yamlMapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing YAML representation for ParameterAveragingTrainingMaster", e);
        }
    }

    /**
     * Restores a training master from its JSON representation (see {@link #toJson()}).
     *
     * @param jsonStr JSON produced by {@link #toJson()}
     * @return the deserialized training master
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static ParameterAveragingTrainingMaster fromJson(String jsonStr) {
        ObjectMapper jsonMapper = getJsonMapper();
        try {
            return jsonMapper.readValue(jsonStr, ParameterAveragingTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse JSON", e);
        }
    }

    /**
     * Restores a training master from its YAML representation (see {@link #toYaml()}).
     *
     * @param yamlStr YAML produced by {@link #toYaml()}
     * @return the deserialized training master
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static ParameterAveragingTrainingMaster fromYaml(String yamlStr) {
        ObjectMapper yamlMapper = getYamlMapper();
        try {
            return yamlMapper.readValue(yamlStr, ParameterAveragingTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse YAML", e);
        }
    }

    /**
     * Creates a worker for a {@link SparkDl4jMultiLayer}. The network's configuration, parameters and updater
     * state are broadcast to the cluster (timed when stats are collected) and bundled with a
     * {@link WorkerConfiguration} whose first flag is {@code false} (presumably "is graph" — this is the
     * MultiLayerNetwork variant).
     *
     * @param network the network being trained
     * @return a new worker wrapping the broadcast network state
     */
    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) {
        MultiLayerNetwork net = network.getNetwork();
        NetBroadcastTuple tuple = new NetBroadcastTuple(net.getLayerWiseConfigurations(), net.params(),
                        net.getUpdater().getStateViewArray());

        if (collectTrainingStats) {
            stats.logBroadcastStart();
        }
        Broadcast<NetBroadcastTuple> broadcast = network.getSparkContext().broadcast(tuple);
        if (collectTrainingStats) {
            stats.logBroadcastEnd();
        }

        WorkerConfiguration workerConf = new WorkerConfiguration(false, rddDataSetNumExamples, batchSizePerWorker,
                        averagingFrequency, prefetchNumBatches, collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, workerConf, trainingHookList, listeners,
                        getRouterProvider());
    }

    /**
     * Creates a worker for a {@link SparkComputationGraph}. The graph's configuration, parameters and updater
     * state are broadcast to the cluster (timed when stats are collected) and bundled with a
     * {@link WorkerConfiguration} whose first flag is {@code true} (presumably "is graph").
     *
     * @param graph the computation graph being trained
     * @return a new worker wrapping the broadcast graph state
     */
    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph graph) {
        ComputationGraph cg = graph.getNetwork();
        NetBroadcastTuple tuple = new NetBroadcastTuple(cg.getConfiguration(), cg.params(),
                        cg.getUpdater().getStateViewArray());

        if (collectTrainingStats) {
            stats.logBroadcastStart();
        }
        Broadcast<NetBroadcastTuple> broadcast = graph.getSparkContext().broadcast(tuple);
        if (collectTrainingStats) {
            stats.logBroadcastEnd();
        }

        WorkerConfiguration workerConf = new WorkerConfiguration(true, rddDataSetNumExamples, batchSizePerWorker,
                        averagingFrequency, prefetchNumBatches, collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, workerConf, trainingHookList, listeners,
                        getRouterProvider());
    }

    /**
     * Number of RDD objects each worker needs per split: {@code batchSizePerWorker * averagingFrequency}
     * examples, divided (integer division, rounding down) by the number of examples per RDD object.
     * The result may be 0 when a single object holds more examples than one worker needs.
     *
     * @param numExamplesEachRddObject examples contained in each RDD object (must be non-zero)
     * @return objects required by each worker per split
     */
    protected int numObjectsEachWorker(int numExamplesEachRddObject) {
        int examplesPerWorker = batchSizePerWorker * averagingFrequency;
        return examplesPerWorker / numExamplesEachRddObject;
    }

    /**
     * Number of RDD objects that make up one training split across all workers.
     * When each object holds a single example, this is {@code numWorkers * batchSizePerWorker * averagingFrequency}.
     * Otherwise, each worker is given at least one object.
     *
     * @param numExamplesEachRddObject examples contained in each RDD object
     * @return objects per split
     */
    protected int getNumDataSetObjectsPerSplit(int numExamplesEachRddObject) {
        if (numExamplesEachRddObject == 1) {
            return numWorkers * batchSizePerWorker * averagingFrequency;
        }
        int objectsPerWorker = Math.max(1, numObjectsEachWorker(numExamplesEachRddObject));
        return objectsPerWorker * numWorkers;
    }

    /**
     * Trains a {@link SparkDl4jMultiLayer} on an RDD of {@link DataSet}s. If {@code numWorkers} is unset, it
     * first defaults to the Spark context's default parallelism. With {@link RDDTrainingApproach#Direct} the RDD
     * is trained on directly; any other approach first exports it (via {@code exportIfRequired}) and then
     * trains from the exported paths, treating each exported object as holding {@code batchSizePerWorker}
     * examples.
     *
     * @param network      the network to train
     * @param trainingData the training data
     */
    @Override
    public void executeTraining(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
        if (numWorkers == null) {
            numWorkers = network.getSparkContext().defaultParallelism();
        }
        if (rddTrainingApproach != RDDTrainingApproach.Direct) {
            JavaRDD<String> exportedPaths = exportIfRequired(network.getSparkContext(), trainingData);
            executeTrainingPathsHelper(network, exportedPaths, batchSizePerWorker);
            return;
        }
        executeTrainingDirect(network, trainingData);
    }

    /**
     * Counts the objects in the training RDD (a Spark action). When stats are collected, the count is
     * bracketed by count-start/count-end events.
     *
     * @param trainingData RDD to count
     * @return number of objects in the RDD
     */
    protected <T, Repr extends JavaRDDLike<T, Repr>> long getTotalDataSetObjectCount(JavaRDDLike<T, Repr> trainingData) {
        boolean timed = collectTrainingStats;
        if (timed) {
            stats.logCountStart();
        }
        long count = trainingData.count();
        if (collectTrainingStats) {
            stats.logCountEnd();
        }
        return count;
    }

    /**
     * Randomly splits a pair RDD into balanced splits of
     * {@code getNumDataSetObjectsPerSplit(rddDataSetNumExamples)} objects each, seeded from this master's RNG.
     * When stats are collected, the split is bracketed by split-start/split-end events.
     *
     * @param trainingData            pair RDD to split
     * @param totalDataSetObjectCount total number of objects in {@code trainingData}
     * @return the splits
     */
    protected <K, V> JavaPairRDD<K, V>[] getSplitRDDs(JavaPairRDD<K, V> trainingData, int totalDataSetObjectCount) {
        int objectsPerSplit = getNumDataSetObjectsPerSplit(rddDataSetNumExamples);
        if (collectTrainingStats) {
            stats.logSplitStart();
        }
        long seed = rng.nextLong();
        JavaPairRDD<K, V>[] result = SparkUtils.balancedRandomSplit(totalDataSetObjectCount, objectsPerSplit, trainingData, seed);
        if (collectTrainingStats) {
            stats.logSplitEnd();
        }
        return result;
    }

    /**
     * Randomly splits an RDD into balanced splits of {@code getNumDataSetObjectsPerSplit(examplesPerDataSetObject)}
     * objects each, seeded from this master's RNG. When stats are collected, the split is bracketed by
     * split-start/split-end events.
     *
     * @param trainingData             RDD to split
     * @param totalDataSetObjectCount  total number of objects in {@code trainingData}
     * @param examplesPerDataSetObject examples contained in each RDD object
     * @return the splits
     */
    protected <T> JavaRDD<T>[] getSplitRDDs(JavaRDD<T> trainingData, int totalDataSetObjectCount, int examplesPerDataSetObject) {
        int objectsPerSplit = getNumDataSetObjectsPerSplit(examplesPerDataSetObject);
        if (collectTrainingStats) {
            stats.logSplitStart();
        }
        long seed = rng.nextLong();
        JavaRDD<T>[] result = SparkUtils.balancedRandomSplit(totalDataSetObjectCount, objectsPerSplit, trainingData, seed);
        if (collectTrainingStats) {
            stats.logSplitEnd();
        }
        return result;
    }

    /**
     * Trains a network directly on an RDD of {@link DataSet}s. The RDD is persisted (if a storage level is
     * configured), counted, and split, and each split is run through {@code doIteration} in order.
     *
     * @param network      the network to train
     * @param trainingData the training data
     */
    protected void executeTrainingDirect(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        if (storageLevel != null) {
            trainingData.persist(storageLevel);
        }
        long total = getTotalDataSetObjectCount(trainingData);
        JavaRDD<DataSet>[] splits = getSplitRDDs(trainingData, (int) total, rddDataSetNumExamples);
        for (int i = 0; i < splits.length; i++) {
            doIteration(network, splits[i], i + 1, splits.length);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) total);
        }
    }

    /**
     * Trains a {@link SparkDl4jMultiLayer} on serialized {@link DataSet}s supplied as
     * {@link PortableDataStream}s. If {@code numWorkers} is unset, it first defaults to the Spark context's
     * default parallelism. When the input has at least {@code COALESCE_THRESHOLD * numWorkers} partitions,
     * it is coalesced to {@code numWorkers} partitions. The input is then persisted (if
     * {@code storageLevelStreams} is set), counted and split, and each split's streams go through
     * {@code doIterationPDS}.
     *
     * @param network      the network to train
     * @param trainingData (path, stream) pairs of serialized DataSets
     * @deprecated retained for compatibility; prefer the path-based training methods
     */
    @Override
    @Deprecated
    public void executeTraining(SparkDl4jMultiLayer network, JavaPairRDD<String, PortableDataStream> trainingData) {
        if (numWorkers == null) {
            numWorkers = network.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        // Use the class constant rather than a duplicated magic number for the coalescing threshold
        int origNumPartitions = trainingData.partitions().size();
        if (origNumPartitions >= COALESCE_THRESHOLD * numWorkers) {
            log.info("Coalescing PortableDataStreams from {} to {} partitions", origNumPartitions, numWorkers);
            trainingData = trainingData.coalesce(numWorkers.intValue());
        }
        if (storageLevelStreams != null) {
            trainingData.persist(storageLevelStreams);
        }
        long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);
        JavaPairRDD<String, PortableDataStream>[] splits = getSplitRDDs(trainingData, (int) totalDataSetObjectCount);
        int splitNum = 1;
        for (JavaPairRDD<String, PortableDataStream> split : splits) {
            JavaRDD<PortableDataStream> streams = split.values();
            doIterationPDS(network, null, streams, splitNum++, splits.length);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) totalDataSetObjectCount);
        }
    }

    /**
     * Trains a {@link SparkDl4jMultiLayer} from paths to serialized {@link DataSet}s. Each file is assumed to
     * hold {@code rddDataSetNumExamples} examples.
     *
     * @param network           the network to train
     * @param trainingDataPaths paths to serialized DataSets
     */
    @Override
    public void executeTrainingPaths(SparkDl4jMultiLayer network, JavaRDD<String> trainingDataPaths) {
        int examplesPerFile = rddDataSetNumExamples;
        executeTrainingPathsHelper(network, trainingDataPaths, examplesPerFile);
    }

    /**
     * Shared implementation of path-based training for {@link SparkDl4jMultiLayer}. It resolves the worker count,
     * persists the path RDD (if {@code storageLevelStreams} is set), counts and splits it, and runs
     * {@code doIterationPaths} on every split.
     *
     * @param network                   the network to train
     * @param trainingDataPaths         paths to serialized DataSets
     * @param dataSetObjectsNumExamples examples contained in each serialized DataSet
     */
    protected void executeTrainingPathsHelper(SparkDl4jMultiLayer network, JavaRDD<String> trainingDataPaths, int dataSetObjectsNumExamples) {
        if (numWorkers == null) {
            numWorkers = network.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        if (storageLevelStreams != null) {
            trainingDataPaths.persist(storageLevelStreams);
        }
        long total = getTotalDataSetObjectCount(trainingDataPaths);
        JavaRDD<String>[] splits = getSplitRDDs(trainingDataPaths, (int) total, dataSetObjectsNumExamples);
        for (int i = 0; i < splits.length; i++) {
            doIterationPaths(network, null, splits[i], i + 1, splits.length, dataSetObjectsNumExamples);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) total);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} on an RDD of {@link DataSet}s. Each DataSet is converted to a
     * {@link MultiDataSet} with {@link DataSetToMultiDataSetFn}, and the result is delegated to
     * {@link #executeTrainingMDS(SparkComputationGraph, JavaRDD)}.
     *
     * @param graph        the graph to train
     * @param trainingData the training data
     */
    @Override
    public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
        if (numWorkers == null) {
            numWorkers = graph.getSparkContext().defaultParallelism();
        }
        JavaRDD<MultiDataSet> asMultiDataSets = trainingData.map(new DataSetToMultiDataSetFn());
        executeTrainingMDS(graph, asMultiDataSets);
    }

    /**
     * Trains a {@link SparkComputationGraph} on an RDD of {@link MultiDataSet}s. With
     * {@link RDDTrainingApproach#Direct} the RDD is trained on directly; any other approach first exports it
     * (via {@code exportIfRequiredMDS}) and trains from the exported paths, treating each exported object as
     * holding {@code batchSizePerWorker} examples.
     *
     * @param graph        the graph to train
     * @param trainingData the training data
     */
    @Override
    public void executeTrainingMDS(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
        if (numWorkers == null) {
            numWorkers = graph.getSparkContext().defaultParallelism();
        }
        if (rddTrainingApproach != RDDTrainingApproach.Direct) {
            JavaRDD<String> exportedPaths = exportIfRequiredMDS(graph.getSparkContext(), trainingData);
            executeTrainingPathsMDSHelper(graph, exportedPaths, batchSizePerWorker);
            return;
        }
        executeTrainingDirect(graph, trainingData);
    }

    /**
     * Trains a graph directly on an RDD of {@link MultiDataSet}s. The RDD is persisted (if a storage level is
     * configured), counted, and split, and each split is run through {@code doIteration} in order.
     *
     * @param graph        the graph to train
     * @param trainingData the training data
     */
    protected void executeTrainingDirect(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        if (storageLevel != null) {
            trainingData.persist(storageLevel);
        }
        long total = getTotalDataSetObjectCount(trainingData);
        JavaRDD<MultiDataSet>[] splits = getSplitRDDs(trainingData, (int) total, rddDataSetNumExamples);
        for (int i = 0; i < splits.length; i++) {
            doIteration(graph, splits[i], i + 1, splits.length);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) total);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} on serialized {@link DataSet}s supplied as
     * {@link PortableDataStream}s. If {@code numWorkers} is unset, it first defaults to the Spark context's
     * default parallelism. When the input has at least {@code COALESCE_THRESHOLD * numWorkers} partitions,
     * it is coalesced to {@code numWorkers} partitions. The input is then persisted (if
     * {@code storageLevelStreams} is set), counted and split, and each split's streams go through
     * {@code doIterationPDS}.
     *
     * @param graph        the graph to train
     * @param trainingData (path, stream) pairs of serialized DataSets
     */
    @Override
    public void executeTraining(SparkComputationGraph graph, JavaPairRDD<String, PortableDataStream> trainingData) {
        if (numWorkers == null) {
            numWorkers = graph.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        // Use the class constant rather than a duplicated magic number for the coalescing threshold
        int origNumPartitions = trainingData.partitions().size();
        if (origNumPartitions >= COALESCE_THRESHOLD * numWorkers) {
            log.info("Coalescing streams from {} to {} partitions", origNumPartitions, numWorkers);
            trainingData = trainingData.coalesce(numWorkers.intValue());
        }
        if (storageLevelStreams != null) {
            trainingData.persist(storageLevelStreams);
        }
        long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);
        JavaPairRDD<String, PortableDataStream>[] splits = getSplitRDDs(trainingData, (int) totalDataSetObjectCount);
        int splitNum = 1;
        for (JavaPairRDD<String, PortableDataStream> split : splits) {
            JavaRDD<PortableDataStream> streams = split.values();
            doIterationPDS(null, graph, streams, splitNum++, splits.length);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) totalDataSetObjectCount);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} on serialized {@link MultiDataSet}s supplied as
     * {@link PortableDataStream}s. The input is persisted (if {@code storageLevelStreams} is set), counted and
     * split. Each split's streams are repartitioned here (timed when stats are collected) and passed to
     * {@code doIterationPDS_MDS}.
     * NOTE(review): unlike the executeTraining(..., JavaPairRDD) overloads, this method does not coalesce its input.
     *
     * @param graph        the graph to train
     * @param trainingData (path, stream) pairs of serialized MultiDataSets
     */
    @Override
    public void executeTrainingMDS(SparkComputationGraph graph, JavaPairRDD<String, PortableDataStream> trainingData) {
        if (numWorkers == null) {
            numWorkers = graph.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        if (storageLevelStreams != null) {
            trainingData.persist(storageLevelStreams);
        }
        long total = getTotalDataSetObjectCount(trainingData);
        JavaPairRDD<String, PortableDataStream>[] splits = getSplitRDDs(trainingData, (int) total);
        for (int i = 0; i < splits.length; i++) {
            JavaRDD<PortableDataStream> streams = splits[i].values();
            if (collectTrainingStats) {
                stats.logRepartitionStart();
            }
            JavaRDD<PortableDataStream> repartitioned = SparkUtils.repartition(streams, repartition, repartitionStrategy,
                            numObjectsEachWorker(rddDataSetNumExamples), numWorkers);
            // No repartition event is recorded when repartitioning is disabled
            if (collectTrainingStats && repartition != Repartition.Never) {
                stats.logRepartitionEnd();
            }
            doIterationPDS_MDS(graph, repartitioned, i + 1, splits.length);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) total);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} from paths to serialized {@link DataSet}s. Each file is assumed
     * to hold {@code rddDataSetNumExamples} examples. The path RDD is persisted (if {@code storageLevelStreams}
     * is set), counted and split, and each split runs through {@code doIterationPaths}.
     *
     * @param network           the graph to train
     * @param trainingDataPaths paths to serialized DataSets
     */
    @Override
    public void executeTrainingPaths(SparkComputationGraph network, JavaRDD<String> trainingDataPaths) {
        if (numWorkers == null) {
            numWorkers = network.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        if (storageLevelStreams != null) {
            trainingDataPaths.persist(storageLevelStreams);
        }
        long total = getTotalDataSetObjectCount(trainingDataPaths);
        JavaRDD<String>[] splits = getSplitRDDs(trainingDataPaths, (int) total, rddDataSetNumExamples);
        for (int i = 0; i < splits.length; i++) {
            doIterationPaths(null, network, splits[i], i + 1, splits.length, rddDataSetNumExamples);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) total);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} from paths to serialized {@link MultiDataSet}s. Each file is
     * assumed to hold {@code rddDataSetNumExamples} examples.
     *
     * @param network                the graph to train
     * @param trainingMultiDataPaths paths to serialized MultiDataSets
     */
    @Override
    public void executeTrainingPathsMDS(SparkComputationGraph network, JavaRDD<String> trainingMultiDataPaths) {
        int examplesPerFile = rddDataSetNumExamples;
        executeTrainingPathsMDSHelper(network, trainingMultiDataPaths, examplesPerFile);
    }

    /**
     * Shared implementation of path-based {@link MultiDataSet} training. It resolves the worker count,
     * persists the path RDD (if {@code storageLevelStreams} is set), counts and splits it, and runs
     * {@code doIterationPathsMDS} on every split.
     *
     * @param network                   the graph to train
     * @param trainingMultiDataPaths    paths to serialized MultiDataSets
     * @param dataSetObjectsNumExamples examples contained in each serialized MultiDataSet
     */
    protected void executeTrainingPathsMDSHelper(SparkComputationGraph network, JavaRDD<String> trainingMultiDataPaths, int dataSetObjectsNumExamples) {
        if (numWorkers == null) {
            numWorkers = network.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        if (storageLevelStreams != null) {
            trainingMultiDataPaths.persist(storageLevelStreams);
        }
        long total = getTotalDataSetObjectCount(trainingMultiDataPaths);
        JavaRDD<String>[] splits = getSplitRDDs(trainingMultiDataPaths, (int) total, dataSetObjectsNumExamples);
        for (int i = 0; i < splits.length; i++) {
            doIterationPathsMDS(network, splits[i], i + 1, splits.length, dataSetObjectsNumExamples);
        }
        if (collectTrainingStats) {
            stats.logFitEnd((int) total);
        }
    }

    /**
     * Enables or disables training-statistics collection. Enabling creates a stats helper if none exists, so any
     * previously collected stats are kept. Disabling discards the helper.
     *
     * @param collectTrainingStats whether to collect statistics
     */
    @Override
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.collectTrainingStats = collectTrainingStats;
        if (!collectTrainingStats) {
            stats = null;
        } else if (stats == null) {
            stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
    }

    /**
     * Reports whether training statistics are currently being collected.
     *
     * @return {@code true} if statistics collection is enabled
     */
    @Override
    public boolean getIsCollectTrainingStats() {
        boolean enabled = collectTrainingStats;
        return enabled;
    }

    /**
     * Builds the collected training statistics.
     *
     * @return the statistics, or {@code null} if no stats helper exists (collection disabled)
     */
    @Override
    public SparkTrainingStats getTrainingStats() {
        return stats == null ? null : stats.build();
    }

    /**
     * Sets the iteration listeners without a stats storage router (the router is set to {@code null}).
     *
     * @param listeners listeners to pass to workers
     */
    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        StatsStorageRouter noRouter = null;
        setListeners(noRouter, listeners);
    }

    /**
     * Sets the iteration listeners and the stats storage router, replacing any previous values.
     *
     * @param statsStorage router for stats storage; may be {@code null}
     * @param listeners    listeners to pass to workers
     */
    @Override
    public void setListeners(StatsStorageRouter statsStorage, Collection<IterationListener> listeners) {
        this.listeners = listeners;
        this.statsStorage = statsStorage;
    }

    /**
     * Trains the network on one split of {@link DataSet}s:
     * <ol>
     * <li>repartitions the split via {@code SparkUtils.repartition};</li>
     * <li>runs an {@link ExecuteWorkerFlatMap} over its partitions;</li>
     * <li>hands the worker results to {@code processResults}.</li>
     * </ol>
     * When stats are collected, the map-partitions timing brackets everything, including result processing.
     *
     * @param network   the network to train
     * @param split     data for this split
     * @param splitNum  1-based split index (for logging)
     * @param numSplits total number of splits (for logging)
     */
    protected void doIteration(SparkDl4jMultiLayer network, JavaRDD<DataSet> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }
        JavaRDD<DataSet> repartitioned = SparkUtils.repartition(split, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers);
        int nPartitions = repartitioned.partitions().size();
        // No repartition event is recorded when repartitioning is disabled
        if (collectTrainingStats && repartition != Repartition.Never) {
            stats.logRepartitionEnd();
        }
        ExecuteWorkerFlatMap<ParameterAveragingTrainingResult> workerFn =
                        new ExecuteWorkerFlatMap<>(getWorkerInstance(network));
        JavaRDD<ParameterAveragingTrainingResult> results = repartitioned.mapPartitions(workerFn);
        processResults(network, null, results, splitNum, numSplits);
        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Trains on one split of serialized DataSet streams. After repartitioning the split, it runs an
     * {@link ExecuteWorkerPDSFlatMap} built for {@code network} when that is non-null, otherwise for
     * {@code graph}, and hands the results to {@code processResults}.
     *
     * @param network   network to train, or {@code null} when training {@code graph}
     * @param graph     graph to train (used only when {@code network} is {@code null})
     * @param split     streams for this split
     * @param splitNum  1-based split index (for logging)
     * @param numSplits total number of splits (for logging)
     */
    protected void doIterationPDS(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<PortableDataStream> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }
        JavaRDD<PortableDataStream> repartitioned = SparkUtils.repartition(split, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers);
        int nPartitions = repartitioned.partitions().size();
        // No repartition event is recorded when repartitioning is disabled
        if (collectTrainingStats && repartition != Repartition.Never) {
            stats.logRepartitionEnd();
        }
        ExecuteWorkerPDSFlatMap<ParameterAveragingTrainingResult> workerFn;
        if (network != null) {
            workerFn = new ExecuteWorkerPDSFlatMap<>(getWorkerInstance(network));
        } else {
            workerFn = new ExecuteWorkerPDSFlatMap<>(getWorkerInstance(graph));
        }
        JavaRDD<ParameterAveragingTrainingResult> results = repartitioned.mapPartitions(workerFn);
        processResults(network, graph, results, splitNum, numSplits);
        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Trains on one split of paths to serialized DataSets. After repartitioning the split (sized with
     * {@code dataSetObjectNumExamples}), it runs an {@link ExecuteWorkerPathFlatMap} built for {@code network}
     * when that is non-null, otherwise for {@code graph}, and hands the results to {@code processResults}.
     *
     * @param network                  network to train, or {@code null} when training {@code graph}
     * @param graph                    graph to train (used only when {@code network} is {@code null})
     * @param split                    paths for this split
     * @param splitNum                 1-based split index (for logging)
     * @param numSplits                total number of splits (for logging)
     * @param dataSetObjectNumExamples examples contained in each serialized DataSet
     */
    protected void doIterationPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> split, int splitNum, int numSplits, int dataSetObjectNumExamples) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }
        JavaRDD<String> repartitioned = SparkUtils.repartition(split, repartition, repartitionStrategy,
                        numObjectsEachWorker(dataSetObjectNumExamples), numWorkers);
        int nPartitions = repartitioned.partitions().size();
        // No repartition event is recorded when repartitioning is disabled
        if (collectTrainingStats && repartition != Repartition.Never) {
            stats.logRepartitionEnd();
        }
        ExecuteWorkerPathFlatMap<ParameterAveragingTrainingResult> workerFn;
        if (network != null) {
            workerFn = new ExecuteWorkerPathFlatMap<>(getWorkerInstance(network));
        } else {
            workerFn = new ExecuteWorkerPathFlatMap<>(getWorkerInstance(graph));
        }
        JavaRDD<ParameterAveragingTrainingResult> results = repartitioned.mapPartitions(workerFn);
        processResults(network, graph, results, splitNum, numSplits);
        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Trains a graph on one split of paths to serialized MultiDataSets. After repartitioning the split (sized
     * with {@code dataSetObjectNumExamples}), it runs an {@link ExecuteWorkerPathMDSFlatMap} and hands the
     * results to {@code processResults}.
     *
     * @param graph                    graph to train
     * @param split                    paths for this split
     * @param splitNum                 1-based split index (for logging)
     * @param numSplits                total number of splits (for logging)
     * @param dataSetObjectNumExamples examples contained in each serialized MultiDataSet
     */
    protected void doIterationPathsMDS(SparkComputationGraph graph, JavaRDD<String> split, int splitNum, int numSplits, int dataSetObjectNumExamples) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }
        JavaRDD<String> repartitioned = SparkUtils.repartition(split, repartition, repartitionStrategy,
                        numObjectsEachWorker(dataSetObjectNumExamples), numWorkers);
        int nPartitions = repartitioned.partitions().size();
        // No repartition event is recorded when repartitioning is disabled
        if (collectTrainingStats && repartition != Repartition.Never) {
            stats.logRepartitionEnd();
        }
        ExecuteWorkerPathMDSFlatMap<ParameterAveragingTrainingResult> workerFn =
                        new ExecuteWorkerPathMDSFlatMap<>(getWorkerInstance(graph));
        JavaRDD<ParameterAveragingTrainingResult> results = repartitioned.mapPartitions(workerFn);
        processResults(null, graph, results, splitNum, numSplits);
        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} on one split of in-memory {@code MultiDataSet} objects.
     * The split is repartitioned according to {@code repartition}/{@code repartitionStrategy}, each
     * partition is trained by a worker, and the worker results are averaged via {@link #processResults}.
     *
     * @param graph     the graph being trained
     * @param split     the MultiDataSets making up this split
     * @param splitNum  index of this split (logging only)
     * @param numSplits total number of splits (logging only)
     */
    protected void doIteration(SparkComputationGraph graph, JavaRDD<MultiDataSet> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers", new Object[]{splitNum, numSplits, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
        }
        JavaRDD<MultiDataSet> splitData = split;
        // Bracket the repartition step with stats timing, as the sibling doIteration* methods do
        if (this.collectTrainingStats) {
            this.stats.logRepartitionStart();
        }
        splitData = SparkUtils.repartition(splitData, this.repartition, this.repartitionStrategy, this.numObjectsEachWorker(this.rddDataSetNumExamples), this.numWorkers);
        // Count the partitions of the RDD that is actually trained on (after repartitioning),
        // not the input split, so the reported stats match the work done
        int nPartitions = splitData.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        ExecuteWorkerMultiDataSetFlatMap<ParameterAveragingTrainingResult> function = new ExecuteWorkerMultiDataSetFlatMap<ParameterAveragingTrainingResult>(this.getWorkerInstance(graph));
        JavaRDD result = splitData.mapPartitions(function);
        this.processResults(null, graph, (JavaRDD<ParameterAveragingTrainingResult>)result, splitNum, numSplits);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} on one split of {@link PortableDataStream}s
     * (consumed by {@code ExecuteWorkerPDSMDSFlatMap}). The split is repartitioned according to
     * {@code repartition}/{@code repartitionStrategy}, each partition is trained by a worker, and the
     * worker results are averaged via {@link #processResults}.
     *
     * @param graph     the graph being trained
     * @param split     the portable data streams making up this split
     * @param splitNum  index of this split (logging only)
     * @param numSplits total number of splits (logging only)
     */
    protected void doIterationPDS_MDS(SparkComputationGraph graph, JavaRDD<PortableDataStream> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers", new Object[]{splitNum, numSplits, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers});
        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }
        int objectsPerWorker = numObjectsEachWorker(rddDataSetNumExamples);
        JavaRDD<PortableDataStream> partitioned = SparkUtils.repartition(split, repartition, repartitionStrategy, objectsPerWorker, numWorkers);
        int partitionCount = partitioned.partitions().size();
        boolean repartitioningEnabled = repartition != Repartition.Never;
        if (collectTrainingStats && repartitioningEnabled) {
            stats.logRepartitionEnd();
        }
        ExecuteWorkerPDSMDSFlatMap<ParameterAveragingTrainingResult> workerFn = new ExecuteWorkerPDSMDSFlatMap<ParameterAveragingTrainingResult>(getWorkerInstance(graph));
        JavaRDD trainingResults = partitioned.mapPartitions(workerFn);
        // Only a ComputationGraph is trained here, so the multi-layer network argument is null
        processResults(null, graph, (JavaRDD<ParameterAveragingTrainingResult>)trainingResults, splitNum, numSplits);
        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(partitionCount);
        }
    }

    /**
     * Aggregates the per-worker results of one training split and applies them on the driver.
     * <p>
     * Results are combined with {@code treeAggregate} (depth {@code aggregationDepth}). The summed
     * parameters (and updater state, if present) are divided by the aggregation count to form an
     * average. That average is copied into {@code network} when it is non-null, otherwise into
     * {@code graph}, together with the average score. Listener metadata, static info and updates
     * carried by the aggregation tuple are forwarded to {@code statsStorage} when it is set. When
     * parameters were produced, the configuration's iteration count is advanced by
     * {@code numIterations * averagingFrequency}.
     *
     * @param network     the multi-layer network being trained, or {@code null} when training a graph
     * @param graph       the computation graph being trained; used only when {@code network} is null
     * @param results     per-worker training results for this split
     * @param splitNum    index of the split just trained (logging only)
     * @param totalSplits total number of splits (logging only)
     */
    protected void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
        if (this.collectTrainingStats) {
            this.stats.logAggregateStartTime();
        }
        ParameterAveragingAggregationTuple tuple = (ParameterAveragingAggregationTuple)results.treeAggregate(null, (Function2)new ParameterAveragingElementAddFunction(), (Function2)new ParameterAveragingElementCombineFunction(), this.aggregationDepth);
        INDArray params = tuple.getParametersSum();
        int aggCount = tuple.getAggregationsCount();
        SparkTrainingStats aggregatedStats = tuple.getSparkTrainingStats();
        if (this.collectTrainingStats) {
            this.stats.logAggregationEndTime();
            this.stats.logProcessParamsUpdaterStart();
        }
        if (params != null) {
            // Sums -> averages (in place)
            params.divi((Number)aggCount);
            INDArray updaterState = tuple.getUpdaterStateSum();
            if (updaterState != null) {
                updaterState.divi((Number)aggCount);
            }
            double averageScore = tuple.getScoreSum() / (double)aggCount;
            if (network != null) {
                MultiLayerNetwork net = network.getNetwork();
                net.setParameters(params);
                if (updaterState != null) {
                    net.getUpdater().setStateViewArray(null, updaterState, false);
                }
                network.setScore(averageScore);
            } else {
                ComputationGraph g = graph.getNetwork();
                g.setParams(params);
                if (updaterState != null) {
                    g.getUpdater().setStateViewArray(updaterState);
                }
                graph.setScore(averageScore);
            }
        } else {
            log.info("Skipping imbalanced split with no data for all executors");
        }
        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterEnd();
            this.stats.addWorkerStats(aggregatedStats);
        }
        if (this.statsStorage != null) {
            Collection<StorageMetaData> meta = tuple.getListenerMetaData();
            if (meta != null && meta.size() > 0) {
                this.statsStorage.putStorageMetaData(meta);
            }
            Collection<Persistable> staticInfo = tuple.getListenerStaticInfo();
            if (staticInfo != null && staticInfo.size() > 0) {
                this.statsStorage.putStaticInfo(staticInfo);
            }
            Collection<Persistable> updates = tuple.getListenerUpdates();
            if (updates != null && updates.size() > 0) {
                this.statsStorage.putUpdate(updates);
            }
        }
        Nd4j.getExecutioner().commit();
        log.info("Completed training of split {} of {}", (Object)splitNum, (Object)totalSplits);
        // Only advance the iteration counter when parameters were actually produced by the workers
        if (params != null) {
            if (network != null) {
                MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations();
                int numUpdates = network.getNetwork().conf().getNumIterations() * this.averagingFrequency;
                conf.setIterationCount(conf.getIterationCount() + numUpdates);
            } else {
                // getConfiguration() returns the graph's own configuration type, which is NOT a
                // MultiLayerConfiguration; it must not be stored in the variable used above
                ComputationGraph g = graph.getNetwork();
                int numUpdates = g.conf().getNumIterations() * this.averagingFrequency;
                g.getConfiguration().setIterationCount(g.getConfiguration().getIterationCount() + numUpdates);
            }
        }
    }

    /**
     * Returns a router provider for workers, or {@code null} when no stats storage is configured.
     *
     * @return a new {@code VanillaStatsStorageRouterProvider}, or {@code null} if {@code statsStorage} is unset
     */
    protected StatsStorageRouterProvider getRouterProvider() {
        return statsStorage == null ? null : new VanillaStatsStorageRouterProvider();
    }

    /**
     * @return the value of the {@code saveUpdater} flag
     */
    public boolean isSaveUpdater() {
        return saveUpdater;
    }

    /**
     * @return the configured number of workers; may be {@code null} (the builder permits it)
     */
    public Integer getNumWorkers() {
        return numWorkers;
    }

    /**
     * @return the number of examples per RDD element, as passed to the builder
     */
    public int getRddDataSetNumExamples() {
        return rddDataSetNumExamples;
    }

    /**
     * @return the parameter-averaging frequency
     */
    public int getAveragingFrequency() {
        return averagingFrequency;
    }

    /**
     * @return the depth passed to {@code treeAggregate} when combining worker results
     */
    public int getAggregationDepth() {
        return aggregationDepth;
    }

    /**
     * @return the configured worker prefetch batch count
     */
    public int getPrefetchNumBatches() {
        return prefetchNumBatches;
    }

    /**
     * @return the current value of the {@code iterationCount} field
     */
    public int getIterationCount() {
        return iterationCount;
    }

    /**
     * @return the training hooks collection itself (not a copy); may be {@code null}
     */
    public Collection<TrainingHook> getTrainingHookList() {
        return trainingHookList;
    }

    /**
     * @param saveUpdater new value for the {@code saveUpdater} flag
     */
    public void setSaveUpdater(final boolean saveUpdater) {
        this.saveUpdater = saveUpdater;
    }

    /**
     * @param numWorkers new worker count; stored as-is without validation
     */
    public void setNumWorkers(final Integer numWorkers) {
        this.numWorkers = numWorkers;
    }

    /**
     * @param rddDataSetNumExamples new examples-per-RDD-element value; stored as-is without validation
     */
    public void setRddDataSetNumExamples(final int rddDataSetNumExamples) {
        this.rddDataSetNumExamples = rddDataSetNumExamples;
    }

    /**
     * @param averagingFrequency new averaging frequency; stored as-is without validation
     */
    public void setAveragingFrequency(final int averagingFrequency) {
        this.averagingFrequency = averagingFrequency;
    }

    /**
     * @param aggregationDepth new {@code treeAggregate} depth; stored as-is without validation
     */
    public void setAggregationDepth(final int aggregationDepth) {
        this.aggregationDepth = aggregationDepth;
    }

    /**
     * @param prefetchNumBatches new worker prefetch batch count
     */
    public void setPrefetchNumBatches(final int prefetchNumBatches) {
        this.prefetchNumBatches = prefetchNumBatches;
    }

    /**
     * @param iterationCount new value for the {@code iterationCount} field
     */
    public void setIterationCount(final int iterationCount) {
        this.iterationCount = iterationCount;
    }

    /**
     * @param trainingHookList hooks collection to store; kept by reference, not copied
     */
    public void setTrainingHookList(final Collection<TrainingHook> trainingHookList) {
        this.trainingHookList = trainingHookList;
    }

    /**
     * Describes the averaging-related settings of this training master.
     *
     * @return a {@code ParameterAveragingTrainingMaster(...)} summary string
     */
    public String toString() {
        StringBuilder sb = new StringBuilder("ParameterAveragingTrainingMaster(");
        sb.append("saveUpdater=").append(isSaveUpdater());
        sb.append(", numWorkers=").append(getNumWorkers());
        sb.append(", rddDataSetNumExamples=").append(getRddDataSetNumExamples());
        sb.append(", averagingFrequency=").append(getAveragingFrequency());
        sb.append(", aggregationDepth=").append(getAggregationDepth());
        sb.append(", prefetchNumBatches=").append(getPrefetchNumBatches());
        sb.append(", iterationCount=").append(getIterationCount());
        sb.append(", trainingHookList=").append(getTrainingHookList());
        return sb.append(')').toString();
    }

    /**
     * Value equality over saveUpdater, numWorkers, rddDataSetNumExamples, averagingFrequency,
     * aggregationDepth, prefetchNumBatches and the training hooks. Note that
     * {@code iterationCount} does not participate (consistent with {@link #hashCode()}).
     *
     * @param o the object to compare with
     * @return {@code true} if {@code o} is an equal training master
     */
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ParameterAveragingTrainingMaster)) {
            return false;
        }
        ParameterAveragingTrainingMaster that = (ParameterAveragingTrainingMaster)o;
        if (!that.canEqual(this) || this.isSaveUpdater() != that.isSaveUpdater()) {
            return false;
        }
        Integer myWorkers = this.getNumWorkers();
        Integer theirWorkers = that.getNumWorkers();
        boolean sameWorkers = (myWorkers == null) ? theirWorkers == null : myWorkers.equals(theirWorkers);
        if (!sameWorkers) {
            return false;
        }
        boolean sameScalars = this.getRddDataSetNumExamples() == that.getRddDataSetNumExamples()
                && this.getAveragingFrequency() == that.getAveragingFrequency()
                && this.getAggregationDepth() == that.getAggregationDepth()
                && this.getPrefetchNumBatches() == that.getPrefetchNumBatches();
        if (!sameScalars) {
            return false;
        }
        Collection<TrainingHook> myHooks = this.getTrainingHookList();
        Collection<TrainingHook> theirHooks = that.getTrainingHookList();
        return (myHooks == null) ? theirHooks == null : myHooks.equals(theirHooks);
    }

    /**
     * Symmetry hook for {@link #equals(Object)}: subclasses may narrow which types they accept.
     *
     * @param candidate the object being compared against
     * @return {@code true} if {@code candidate} is a {@code ParameterAveragingTrainingMaster}
     */
    protected boolean canEqual(Object candidate) {
        boolean compatible = candidate instanceof ParameterAveragingTrainingMaster;
        return compatible;
    }

    /**
     * Hash consistent with {@link #equals(Object)}: combines the same fields with multiplier 59.
     *
     * @return the hash code
     */
    public int hashCode() {
        final int prime = 59;
        int h = 1;
        h = h * prime + (isSaveUpdater() ? 79 : 97);
        Integer workers = getNumWorkers();
        h = h * prime + (workers != null ? workers.hashCode() : 43);
        h = h * prime + getRddDataSetNumExamples();
        h = h * prime + getAveragingFrequency();
        h = h * prime + getAggregationDepth();
        h = h * prime + getPrefetchNumBatches();
        Collection<TrainingHook> hooks = getTrainingHookList();
        h = h * prime + (hooks != null ? hooks.hashCode() : 43);
        return h;
    }

    /**
     * Fluent builder for {@link ParameterAveragingTrainingMaster}.
     * <p>
     * Defaults:
     * <ul>
     *   <li>batchSizePerWorker = 16</li>
     *   <li>averagingFrequency = 5</li>
     *   <li>aggregationDepth = 2</li>
     *   <li>prefetchNumBatches = 0</li>
     *   <li>repartition = {@code Repartition.Always}</li>
     *   <li>repartitionStrategy = {@code Balanced}</li>
     *   <li>storageLevel = {@code MEMORY_ONLY_SER}</li>
     *   <li>storageLevelStreams = {@code MEMORY_ONLY}</li>
     *   <li>rddTrainingApproach = {@code Export}</li>
     * </ul>
     */
    public static class Builder {
        protected boolean saveUpdater;
        protected Integer numWorkers;
        protected int rddDataSetNumExamples;
        protected int batchSizePerWorker = 16;
        protected int averagingFrequency = 5;
        protected int aggregationDepth = 2;
        protected int prefetchNumBatches = 0;
        protected Repartition repartition = Repartition.Always;
        protected RepartitionStrategy repartitionStrategy = RepartitionStrategy.Balanced;
        protected StorageLevel storageLevel = StorageLevel.MEMORY_ONLY_SER();
        protected StorageLevel storageLevelStreams = StorageLevel.MEMORY_ONLY();
        protected RDDTrainingApproach rddTrainingApproach = RDDTrainingApproach.Export;
        protected String exportDirectory = null;
        protected Long rngSeed;
        protected Collection<TrainingHook> trainingHooks;

        /**
         * Sets the training hooks. The collection is stored by reference.
         *
         * @param trainingHooks hooks to use; may be {@code null}
         * @return this builder
         */
        public Builder trainingHooks(Collection<TrainingHook> trainingHooks) {
            this.trainingHooks = trainingHooks;
            return this;
        }

        /**
         * Sets the training hooks from varargs.
         *
         * @param hooks hooks to use
         * @return this builder
         */
        public Builder trainingHooks(TrainingHook ... hooks) {
            // Copy into a growable list: Arrays.asList is a fixed-size view, so a later add() on the
            // resulting hook collection would throw UnsupportedOperationException.
            this.trainingHooks = new ArrayList<TrainingHook>(Arrays.asList(hooks));
            return this;
        }

        /**
         * Creates a builder without an explicit worker count ({@code numWorkers == null}).
         *
         * @param rddDataSetNumExamples number of examples per RDD element; must be &gt;= 1
         * @throws IllegalArgumentException if {@code rddDataSetNumExamples < 1}
         */
        public Builder(int rddDataSetNumExamples) {
            this(null, rddDataSetNumExamples);
        }

        /**
         * @param numWorkers            number of workers, or {@code null}; if non-null must be &gt;= 1
         * @param rddDataSetNumExamples number of examples per RDD element; must be &gt;= 1
         * @throws IllegalArgumentException if either argument is out of range
         */
        public Builder(Integer numWorkers, int rddDataSetNumExamples) {
            Preconditions.checkArgument(numWorkers == null || numWorkers > 0, "Invalid number of workers: %s (must be >= 1)", numWorkers);
            Preconditions.checkArgument(rddDataSetNumExamples > 0, "Invalid rdd data set size: %s (must be >= 1)", rddDataSetNumExamples);
            this.numWorkers = numWorkers;
            this.rddDataSetNumExamples = rddDataSetNumExamples;
        }

        /**
         * @param batchSizePerWorker minibatch size used by each worker; must be &gt;= 1 (default 16)
         * @return this builder
         * @throws IllegalArgumentException if {@code batchSizePerWorker < 1}
         */
        public Builder batchSizePerWorker(int batchSizePerWorker) {
            // Validate like the other numeric options; a non-positive minibatch size is never usable
            Preconditions.checkArgument(batchSizePerWorker > 0, "Invalid input: batch size per worker must be >= 1 (got %s)", batchSizePerWorker);
            this.batchSizePerWorker = batchSizePerWorker;
            return this;
        }

        /**
         * @param averagingFrequency how often parameters are averaged; must be &gt;= 1 (default 5)
         * @return this builder
         * @throws IllegalArgumentException if {@code averagingFrequency < 1}
         */
        public Builder averagingFrequency(int averagingFrequency) {
            Preconditions.checkArgument(averagingFrequency > 0, "Invalid input: averaging frequency must be >= 1");
            this.averagingFrequency = averagingFrequency;
            return this;
        }

        /**
         * @param aggregationDepth depth of the tree aggregation of worker results; must be &gt;= 1 (default 2)
         * @return this builder
         * @throws IllegalArgumentException if {@code aggregationDepth < 1}
         */
        public Builder aggregationDepth(int aggregationDepth) {
            Preconditions.checkArgument(aggregationDepth > 0, "Invalid input: tree aggregation depth must be >= 1");
            this.aggregationDepth = aggregationDepth;
            return this;
        }

        /**
         * @param prefetchNumBatches number of batches workers prefetch (default 0; presumably 0 disables prefetching)
         * @return this builder
         */
        public Builder workerPrefetchNumBatches(int prefetchNumBatches) {
            this.prefetchNumBatches = prefetchNumBatches;
            return this;
        }

        /**
         * @param saveUpdater value for the master's {@code saveUpdater} flag
         * @return this builder
         */
        public Builder saveUpdater(boolean saveUpdater) {
            this.saveUpdater = saveUpdater;
            return this;
        }

        /**
         * Sets when training data is repartitioned (default {@code Repartition.Always}).
         * The method name is misspelled; it is kept for backward compatibility. Prefer
         * {@link #repartitionData(Repartition)}.
         *
         * @param repartition repartitioning policy
         * @return this builder
         */
        public Builder repartionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /**
         * Correctly spelled equivalent of {@link #repartionData(Repartition)}.
         *
         * @param repartition repartitioning policy
         * @return this builder
         */
        public Builder repartitionData(Repartition repartition) {
            return repartionData(repartition);
        }

        /**
         * @param repartitionStrategy strategy used when repartitioning (default {@code Balanced})
         * @return this builder
         */
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /**
         * @param storageLevel Spark storage level (default {@code MEMORY_ONLY_SER})
         * @return this builder
         */
        public Builder storageLevel(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /**
         * @param storageLevelStreams Spark storage level for stream-based data (default {@code MEMORY_ONLY})
         * @return this builder
         */
        public Builder storageLevelStreams(StorageLevel storageLevelStreams) {
            this.storageLevelStreams = storageLevelStreams;
            return this;
        }

        /**
         * @param rddTrainingApproach how RDD data is fed to training (default {@code Export})
         * @return this builder
         */
        public Builder rddTrainingApproach(RDDTrainingApproach rddTrainingApproach) {
            this.rddTrainingApproach = rddTrainingApproach;
            return this;
        }

        /**
         * @param exportDirectory directory for exported data (default {@code null})
         * @return this builder
         */
        public Builder exportDirectory(String exportDirectory) {
            this.exportDirectory = exportDirectory;
            return this;
        }

        /**
         * @param rngSeed random seed (unset by default)
         * @return this builder
         */
        public Builder rngSeed(long rngSeed) {
            this.rngSeed = rngSeed;
            return this;
        }

        /**
         * @return a new {@link ParameterAveragingTrainingMaster} configured from this builder
         */
        public ParameterAveragingTrainingMaster build() {
            return new ParameterAveragingTrainingMaster(this);
        }
    }
}

