/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.paramavg;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StatsStorageRouterProvider;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter;
import org.deeplearning4j.spark.impl.paramavg.BaseTrainingWorker;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingResult;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingWorkerStats;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Training worker used by parameter-averaging Spark training.
 *
 * <p>A worker rebuilds a {@link MultiLayerNetwork} or {@link ComputationGraph} from the broadcast
 * {@link NetBroadcastTuple}. The tuple carries the configuration, the parameters and, optionally,
 * the updater state. The worker then fits the network one minibatch at a time, calling any
 * registered {@link TrainingHook}s before and after each fit. When the last minibatch has been
 * processed, it packages the following into a {@link ParameterAveragingTrainingResult}:
 * <ul>
 *   <li>the parameters;</li>
 *   <li>the updater state, if {@code saveUpdater} is set;</li>
 *   <li>the score;</li>
 *   <li>any listener data gathered by a {@link VanillaStatsStorageRouter}.</li>
 * </ul>
 */
public class ParameterAveragingTrainingWorker
extends BaseTrainingWorker<ParameterAveragingTrainingResult> {
    private final Broadcast<NetBroadcastTuple> broadcast;
    private final boolean saveUpdater;
    /** Hooks run around every fit call; may be null until {@link #addHook} is first called. */
    private Collection<TrainingHook> trainingHooks;
    private final WorkerConfiguration configuration;
    /**
     * Timing helper. It is (re)created by {@link #getInitialModel()} / {@link #getInitialModelGraph()}
     * only when training-stats collection is enabled in the worker configuration.
     */
    private ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper stats = null;
    /** Listeners attached to every model this worker builds; may be null. */
    private final Collection<IterationListener> iterationListeners;
    /** Supplies the storage router handed to {@link RoutingIterationListener}s; may be null. */
    private final StatsStorageRouterProvider listenerRouterProvider;

    /**
     * @param broadcast      broadcast holding the network configuration, parameters and optional updater state
     * @param saveUpdater    whether the final result should include the updater state
     * @param configuration  worker configuration (controls training-stats collection)
     * @param trainingHooks  hooks to run before/after each fit; may be null
     * @param listeners      iteration listeners to attach to the model; may be null
     * @param routerProvider provider of the stats storage router for routing listeners; may be null
     */
    public ParameterAveragingTrainingWorker(Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration, Collection<TrainingHook> trainingHooks, Collection<IterationListener> listeners, StatsStorageRouterProvider routerProvider) {
        this.broadcast = broadcast;
        this.saveUpdater = saveUpdater;
        this.configuration = configuration;
        this.trainingHooks = trainingHooks;
        this.iterationListeners = listeners;
        this.listenerRouterProvider = routerProvider;
    }

    @Override
    public void removeHook(TrainingHook trainingHook) {
        if (trainingHooks != null) {
            trainingHooks.remove(trainingHook);
        }
    }

    @Override
    public void addHook(TrainingHook trainingHook) {
        if (trainingHooks == null) {
            trainingHooks = new ArrayList<TrainingHook>();
        }
        trainingHooks.add(trainingHook);
    }

    /**
     * Builds a {@link MultiLayerNetwork} from the broadcast configuration, parameters and
     * (if present) updater state.
     */
    @Override
    public MultiLayerNetwork getInitialModel() {
        NetBroadcastTuple tuple = fetchBroadcastTuple();
        MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration().clone());
        // Initialise from duplicates of the broadcast arrays rather than the broadcast arrays themselves.
        net.init(tuple.getParameters().unsafeDuplication(), false);
        if (tuple.getUpdaterState() != null) {
            net.setUpdater(new MultiLayerUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
        }
        finishInitialization(net, tuple);
        return net;
    }

    /**
     * Builds a {@link ComputationGraph} from the broadcast graph configuration, parameters and
     * (if present) updater state.
     */
    @Override
    public ComputationGraph getInitialModelGraph() {
        NetBroadcastTuple tuple = fetchBroadcastTuple();
        ComputationGraph net = new ComputationGraph(tuple.getGraphConfiguration().clone());
        // Initialise from duplicates of the broadcast arrays rather than the broadcast arrays themselves.
        net.init(tuple.getParameters().unsafeDuplication(), false);
        if (tuple.getUpdaterState() != null) {
            net.setUpdater(new ComputationGraphUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
        }
        finishInitialization(net, tuple);
        return net;
    }

    /** Creates the stats helper (if enabled) and reads the broadcast value, timing the read. */
    private NetBroadcastTuple fetchBroadcastTuple() {
        boolean collectStats = configuration.isCollectTrainingStats();
        if (collectStats) {
            stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
            stats.logBroadcastGetValueStart();
        }
        NetBroadcastTuple tuple = broadcast.getValue();
        if (collectStats) {
            stats.logBroadcastGetValueEnd();
        }
        return tuple;
    }

    /** Commits queued ND4J work, attaches listeners and records the end of initialisation. */
    private void finishInitialization(Model net, NetBroadcastTuple tuple) {
        Nd4j.getExecutioner().commit();
        // The broadcast counter gives each model built from this broadcast its own index.
        configureListeners(net, tuple.getCounter().getAndIncrement());
        if (configuration.isCollectTrainingStats()) {
            stats.logInitEnd();
        }
    }

    /**
     * Attaches the configured iteration listeners to {@code m}. Routing listeners receive the
     * provider's storage router and a worker ID of the form {@code <JVM UID>_<counter>}.
     */
    private void configureListeners(Model m, int counter) {
        if (iterationListeners == null) {
            return;
        }
        List<IterationListener> list = new ArrayList<IterationListener>(iterationListeners.size());
        for (IterationListener l : iterationListeners) {
            if (listenerRouterProvider != null && l instanceof RoutingIterationListener) {
                RoutingIterationListener rl = (RoutingIterationListener) l;
                rl.setStorageRouter(listenerRouterProvider.getRouter());
                rl.setWorkerID(UIDProvider.getJVMUID() + "_" + counter);
            }
            list.add(l);
        }
        if (m instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) m).setListeners(list);
        } else {
            ((ComputationGraph) m).setListeners(list);
        }
    }

    /**
     * Fits {@code network} on one minibatch, running training hooks around the fit.
     *
     * @return the final result when {@code isLast} is true, otherwise null
     */
    @Override
    public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
        if (configuration.isCollectTrainingStats()) {
            stats.logFitStart();
        }
        if (trainingHooks != null) {
            for (TrainingHook hook : trainingHooks) {
                hook.preUpdate(dataSet, network);
            }
        }
        network.fit(dataSet);
        if (trainingHooks != null) {
            for (TrainingHook hook : trainingHooks) {
                hook.postUpdate(dataSet, network);
            }
        }
        if (configuration.isCollectTrainingStats()) {
            stats.logFitEnd(dataSet.numExamples());
        }
        Nd4j.getExecutioner().commit();
        return isLast ? getFinalResult(network) : null;
    }

    @Override
    public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast) {
        return processMinibatch(ComputationGraphUtil.toMultiDataSet(dataSet), graph, isLast);
    }

    /**
     * Fits {@code graph} on one multi-input minibatch, running training hooks around the fit.
     *
     * @return the final result when {@code isLast} is true, otherwise null
     */
    @Override
    public ParameterAveragingTrainingResult processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast) {
        if (configuration.isCollectTrainingStats()) {
            stats.logFitStart();
        }
        if (trainingHooks != null) {
            for (TrainingHook hook : trainingHooks) {
                hook.preUpdate(dataSet, graph);
            }
        }
        graph.fit(dataSet);
        if (trainingHooks != null) {
            for (TrainingHook hook : trainingHooks) {
                hook.postUpdate(dataSet, graph);
            }
        }
        if (configuration.isCollectTrainingStats()) {
            // Example count is taken from the first dimension of the first feature array.
            stats.logFitEnd(dataSet.getFeatures(0).size(0));
        }
        Nd4j.getExecutioner().commit();
        return isLast ? getFinalResult(graph) : null;
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
        return withStats(processMinibatch(dataSet, network, isLast));
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast) {
        return processMinibatchWithStats(ComputationGraphUtil.toMultiDataSet(dataSet), graph, isLast);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast) {
        return withStats(processMinibatch(dataSet, graph, isLast));
    }

    /** Packages the network's parameters, score and (optionally) updater state. */
    @Override
    public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
        INDArray updaterState = null;
        if (saveUpdater) {
            Updater updater = network.getUpdater();
            if (updater != null) {
                updaterState = updater.getStateViewArray();
            }
        }
        return buildResult(network, updaterState);
    }

    /** Packages the graph's parameters, score and (optionally) updater state. */
    @Override
    public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
        INDArray updaterState = null;
        if (saveUpdater) {
            ComputationGraphUpdater updater = network.getUpdater();
            if (updater != null) {
                updaterState = updater.getStateViewArray();
            }
        }
        return buildResult(network, updaterState);
    }

    /**
     * Commits queued ND4J work and assembles the result. If the router is a
     * {@link VanillaStatsStorageRouter}, the listener metadata, static info and updates it
     * holds are also included.
     */
    private ParameterAveragingTrainingResult buildResult(Model network, INDArray updaterState) {
        Nd4j.getExecutioner().commit();
        List<StorageMetaData> storageMetaData = null;
        List<Persistable> listenerStaticInfo = null;
        List<Persistable> listenerUpdates = null;
        if (listenerRouterProvider != null) {
            StatsStorageRouter router = listenerRouterProvider.getRouter();
            if (router instanceof VanillaStatsStorageRouter) {
                VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) router;
                storageMetaData = ssr.getStorageMetaData();
                listenerStaticInfo = ssr.getStaticInfo();
                listenerUpdates = ssr.getUpdates();
            }
        }
        return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData, listenerStaticInfo, listenerUpdates);
    }

    /** Result used when a worker received no data: no parameters, no updater state, score 0. */
    @Override
    public ParameterAveragingTrainingResult getFinalResultNoData() {
        return new ParameterAveragingTrainingResult(null, null, 0.0, null, null, null);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultNoDataWithStats() {
        return new Pair<ParameterAveragingTrainingResult, SparkTrainingStats>(getFinalResultNoData(), null);
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network) {
        return withStats(getFinalResult(network));
    }

    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph) {
        return withStats(getFinalResult(graph));
    }

    /**
     * Pairs {@code result} with the stats collected so far.
     *
     * @return null if {@code result} is null. Otherwise a typed pair whose stats element is null
     *         when stats collection was not enabled.
     */
    private Pair<ParameterAveragingTrainingResult, SparkTrainingStats> withStats(ParameterAveragingTrainingResult result) {
        if (result == null) {
            return null;
        }
        SparkTrainingStats workerStats = stats != null ? stats.build() : null;
        return new Pair<ParameterAveragingTrainingResult, SparkTrainingStats>(result, workerStats);
    }

    @Override
    public WorkerConfiguration getDataConfiguration() {
        return configuration;
    }
}

