/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.UntypedActor;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedDeque;
import org.deeplearning4j.nn.conf.Configuration;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.actor.core.actor.MasterActor;
import org.deeplearning4j.scaleout.actor.core.protocol.ResetMessage;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.api.workrouter.WorkRouter;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.job.JobIterator;
import org.deeplearning4j.scaleout.messages.DoneMessage;
import org.deeplearning4j.scaleout.messages.MoreWorkMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Actor that hands out {@link Job}s produced by a {@link JobIterator} to workers, one round at a time.
 *
 * <p>On construction it subscribes itself, through the cluster pub/sub mediator, to the
 * {@link MasterActor#SHUTDOWN} and {@link #BATCH} topics. It reacts to:
 * <ul>
 *   <li>subscribe/unsubscribe acknowledgements: logs and republishes them on the {@code "topics"} topic;</li>
 *   <li>{@link ResetMessage}: resets the iterator and sends itself a {@link MoreWorkMessage};</li>
 *   <li>{@link MoreWorkMessage}: publishes a {@code "save"} request, then either routes the next
 *       round of jobs to workers (and publishes the worker data to the master topic) or, when the
 *       iterator is exhausted, publishes a {@link DoneMessage} to the master topic.</li>
 * </ul>
 * Any other message is passed to {@link #unhandled(Object)}.
 */
public class BatchActor
extends UntypedActor
implements DeepLearningConfigurable {
    /** Pub/sub topic this actor subscribes to. */
    public static final String BATCH = "batch";

    private static final Logger log = LoggerFactory.getLogger(BatchActor.class);

    protected JobIterator iter;
    private final ActorRef mediator = DistributedPubSubExtension.get(this.getContext().system()).mediator();
    private transient StateTracker stateTracker;
    private transient Configuration conf;
    /** Workers queued to receive the next job; refilled on demand by {@link #nextWorker()}. */
    private final Queue<String> workers = new ConcurrentLinkedDeque<String>();
    /** Number of rounds of work dispatched so far (used only for logging). */
    private int numDataSets = 0;
    private WorkRouter workRouter;

    /**
     * Creates the actor and subscribes it to the shutdown and batch topics.
     *
     * @param iter         source of jobs to distribute
     * @param stateTracker tracker consulted for the worker list, input split and per-worker jobs
     * @param conf         configuration (stored, not otherwise used by this class)
     * @param workRouter   router that each produced job is handed to
     */
    public BatchActor(JobIterator iter, StateTracker stateTracker, Configuration conf, WorkRouter workRouter) {
        this.iter = iter;
        this.stateTracker = stateTracker;
        this.conf = conf;
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.SHUTDOWN, this.getSelf()), this.getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(BATCH, this.getSelf()), this.getSelf());
        this.workRouter = workRouter;
    }

    /**
     * Dispatches an incoming message; see the class documentation for the handled types.
     *
     * @param message the received message
     * @throws Exception propagated from the state tracker, iterator or router
     */
    public void onReceive(Object message) throws Exception {
        if (message instanceof DistributedPubSubMediator.SubscribeAck
                || message instanceof DistributedPubSubMediator.UnsubscribeAck) {
            log.info("Susbcribed batch actor");
            this.mediator.tell(new DistributedPubSubMediator.Publish("topics", message), this.getSelf());
        } else if (message instanceof ResetMessage) {
            this.iter.reset();
            this.self().tell(MoreWorkMessage.getInstance(), this.self());
        } else if (message instanceof MoreWorkMessage) {
            log.info("Saving model");
            this.mediator.tell(new DistributedPubSubMediator.Publish("save", MoreWorkMessage.getInstance()), this.mediator);
            if (this.iter.hasNext()) {
                this.distributeNextBatch();
            } else {
                this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, DoneMessage.getInstance()), this.mediator);
            }
        } else {
            // Previously unreachable (nested under the hasNext/!hasNext branches); unknown messages
            // were silently dropped. Report them to Akka instead.
            this.unhandled(message);
        }
    }

    /**
     * Routes one round of jobs (at most one per worker) from the iterator and then publishes the
     * tracker's worker data to the master topic.
     *
     * @throws Exception propagated from the state tracker, iterator or router
     */
    private void distributeNextBatch() throws Exception {
        log.info("Propagating new work to master");
        ++this.numDataSets;
        log.info("Iterating over next dataset " + this.numDataSets);

        // Typed so that iterating as String compiles (the raw List did not).
        List<String> currentWorkers = this.stateTracker.workers();
        for (String s : currentWorkers) {
            log.info("Worker " + s);
        }
        // Clear data left over from the previous round.
        // NOTE(review): assumes workerData() returns a snapshot, not a live view that
        // removeWorkerData mutates - confirm against the StateTracker implementation.
        for (String s : this.stateTracker.workerData()) {
            this.stateTracker.removeWorkerData(s);
        }

        int numWorkers = currentWorkers.size();
        int miniBatchSize = this.stateTracker.inputSplit();
        if (numWorkers == 0) {
            // No registered workers yet: size the round by the local CPU count instead.
            numWorkers = Runtime.getRuntime().availableProcessors();
        }
        log.info("Number of workers " + numWorkers + " and batch size is " + miniBatchSize);
        for (String worker : this.stateTracker.workers()) {
            this.stateTracker.enableWorker(worker);
        }
        int batch = numWorkers * miniBatchSize;
        log.info("Batch size for worker is " + batch);

        for (int i = 0; i < numWorkers && this.iter.hasNext(); ++i) {
            String worker = this.nextWorker();
            log.info("Saving data for worker " + worker);
            if (worker == null) {
                // Retry this slot; defensive, since nextWorker() only returns after the queue is non-empty.
                --i;
                continue;
            }
            Job next = this.iter.next(worker);
            if (next == null) {
                break;
            }
            this.workRouter.routeJob(next);
        }
        this.stateTracker.incrementBatchesRan(currentWorkers.size());
        this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, this.stateTracker.workerData()), this.mediator);
    }

    /**
     * Returns the next queued worker, refilling the queue from the state tracker when it is empty.
     * A worker is queued only if it has no current job ({@code jobFor} returns null) and is not
     * already queued.
     *
     * <p>NOTE(review): this busy-waits, re-polling the state tracker without pausing, until some
     * worker becomes free. It blocks the actor's thread if no worker ever frees up.
     *
     * @return a worker id
     */
    private String nextWorker() {
        while (this.workers.isEmpty()) {
            for (String s : this.stateTracker.workers()) {
                if (this.stateTracker.jobFor(s) != null || this.workers.contains(s)) {
                    continue;
                }
                this.workers.add(s);
            }
            log.info("Refilling queue with size of " + this.workers.size() + " out of " + this.stateTracker.numWorkers());
        }
        return this.workers.poll();
    }

    /** @return the job iterator this actor draws work from */
    public JobIterator getIter() {
        return this.iter;
    }

    /**
     * Intentionally a no-op: configuration is supplied through the constructor.
     *
     * @param conf ignored
     */
    public void setup(Configuration conf) {
    }
}

