/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.statetracker.hazelcast;

import com.hazelcast.client.HazelcastClient;
import com.hazelcast.client.config.ClientConfig;
import com.hazelcast.config.Config;
import com.hazelcast.config.JoinConfig;
import com.hazelcast.config.ListConfig;
import com.hazelcast.config.MapConfig;
import com.hazelcast.config.SetConfig;
import com.hazelcast.core.Hazelcast;
import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.core.IAtomicLong;
import com.hazelcast.core.IAtomicReference;
import com.hazelcast.core.IList;
import com.hazelcast.core.IMap;
import com.hazelcast.core.ISet;
import com.hazelcast.core.MemberAttributeEvent;
import com.hazelcast.core.MembershipEvent;
import com.hazelcast.core.MembershipListener;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.scaleout.actor.util.PortTaken;
import org.deeplearning4j.scaleout.aggregator.JobAggregator;
import org.deeplearning4j.scaleout.api.statetracker.IterateAndUpdate;
import org.deeplearning4j.scaleout.api.statetracker.NewUpdateListener;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.api.statetracker.UpdateSaver;
import org.deeplearning4j.scaleout.api.statetracker.WorkRetriever;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.scaleout.statetracker.hazelcast.StateTrackerDropWizardResource;
import org.deeplearning4j.scaleout.statetracker.workretriever.LocalWorkRetriever;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/**
 * Base {@link StateTracker} that keeps its shared training state in Hazelcast.
 *
 * <p>A tracker constructed with type {@code "master"} on a free port starts an embedded Hazelcast
 * member and seeds the training flags and counters. Any other type connects to
 * {@code connectionString} as a Hazelcast client. Jobs, workers, updates, flags and counters are
 * held in Hazelcast lists, sets, maps and atomic references, looked up by the name constants
 * declared below.
 *
 * <p>Subclasses provide {@link #createUpdateSaver()} and {@link #updates()}.
 */
public abstract class BaseHazelCastStateTracker
implements StateTracker {
    private static final long serialVersionUID = -7374372180080957334L;
    // Names of the Hazelcast structures shared between the master and the workers.
    public static final String JOBS = "org.deeplearning4j.jobs";
    public static final String NUM_TIMES_PRETRAIN_RAN = "pretrainran";
    public static final String WORKERS = "org.deeplearning4j.workers";
    public static final String AVAILABLE_WORKERS = "AVAILABLE_WORKERS";
    public static final String NUM_TIMES_RUN_PRETRAIN = "PRETRAIN";
    public static final String TOPICS = "topics";
    public static final String RESULT = "RESULT";
    public static final String DONE = "done";
    public static final String UPDATES = "updates";
    public static final String REPLICATE_WEIGHTS = "replicate";
    public static final String HEART_BEAT = "heartbeat";
    public static final String WORKER_ENABLED = "workerenabled";
    public static final String INPUT_SPLIT = "inputsplit";
    public static final String IS_PRETRAIN = "ispretrain";
    public static final String BEST_LOSS = "bestloss";
    public static final String IMPROVEMENT_THRESHOLD = "improvementthreshold";
    public static final String EARLY_STOP = "earlystop";
    public static final String PATIENCE = "patience";
    public static final String BEGUN = "begun";
    public static final String NUM_BATCHES_SO_FAR_RAN = "numbatches";
    public static final String GLOBAL_REFERENCE = "globalreference";
    public static final String RECENTLY_CLEARED = "recentlycleared";
    private volatile transient IAtomicReference<Serializable> master;
    private volatile transient IList<Job> jobs;
    private volatile transient IAtomicReference<Integer> numTimesPretrain;
    private volatile transient IAtomicReference<Integer> numTimesPretrainRan;
    private volatile transient IAtomicReference<Double> bestLoss;
    private volatile transient IAtomicReference<Integer> numBatches;
    private volatile transient ISet<String> recentlyClearedJobs;
    private volatile transient IAtomicReference<Boolean> earlyStop;
    private volatile transient IMap<String, Serializable> references;
    private volatile transient IAtomicReference<Boolean> done;
    private volatile transient IList<String> replicate;
    private volatile transient IMap<String, Boolean> workerEnabled;
    private volatile transient IList<String> workers;
    private volatile transient IList<String> topics;
    private volatile transient IList<String> updates;
    private volatile IAtomicReference<Double> patience;
    private volatile IAtomicReference<Boolean> begunTraining;
    private volatile IAtomicReference<Integer> miniBatchSize;
    private WorkRetriever workRetriever = new LocalWorkRetriever();
    protected UpdateSaver saver;
    private volatile IAtomicReference<Boolean> isPretrain;
    // NOTE(review): logs under the subclass HazelCastStateTracker's category rather than this class.
    private static final Logger log = LoggerFactory.getLogger(HazelCastStateTracker.class);
    private transient Config config;
    public static final int DEFAULT_HAZELCAST_PORT = 2510;
    private transient HazelcastInstance h;
    private String type = "master";
    private int hazelCastPort = -1;
    private String connectionString;
    private Map<String, Long> heartbeat;
    private StateTrackerDropWizardResource resource;
    protected JobAggregator jobAggregator;
    protected Serializable cachedCurrent;
    public static final String HAZELCAST_HOST = "hazelcast.host";
    private List<NewUpdateListener> listeners = new ArrayList<NewUpdateListener>();
    /** Pause between scans for a free worker while {@link #addJobToCurrent(Job)} redirects a job. */
    private static final long JOB_REDIRECT_POLL_MILLIS = 100L;

    /**
     * Creates a master tracker on {@link #DEFAULT_HAZELCAST_PORT}.
     *
     * @throws IllegalStateException if that port is already taken
     * @throws Exception if Hazelcast cannot be started
     */
    public BaseHazelCastStateTracker() throws Exception {
        this(DEFAULT_HAZELCAST_PORT);
    }

    /**
     * Creates a master tracker whose embedded Hazelcast member listens on {@code stateTrackerPort}.
     *
     * @param stateTrackerPort port for the embedded Hazelcast member
     * @throws IllegalStateException if the port is already taken
     * @throws Exception if Hazelcast cannot be started
     */
    public BaseHazelCastStateTracker(int stateTrackerPort) throws Exception {
        this("master", "master", stateTrackerPort);
    }

    /**
     * Creates a worker tracker that connects to an existing cluster as a Hazelcast client.
     *
     * @param connectionString address of the master, e.g. the {@code host:port} string a master
     *     exposes through {@link #connectionString()}
     * @throws Exception if the client cannot connect
     */
    public BaseHazelCastStateTracker(String connectionString) throws Exception {
        this(connectionString, "worker", DEFAULT_HAZELCAST_PORT);
    }

    /**
     * Creates a tracker and binds all shared Hazelcast structures.
     *
     * <p>For {@code type == "master"}, an embedded Hazelcast member is started on
     * {@code stateTrackerPort}, and the training flags and counters are reset. If
     * {@code connectionString} is {@code "master"}, the advertised connection string becomes
     * {@code <host>:<port>}. For any other type, the tracker connects as a client to
     * {@code connectionString}.
     *
     * @param connectionString address to connect to, or {@code "master"} for a master tracker
     * @param type {@code "master"} or a worker type
     * @param stateTrackerPort port for the embedded member (master only)
     * @throws IllegalStateException if {@code type} is master and the port is taken
     * @throws Exception if Hazelcast cannot be started or reached
     */
    public BaseHazelCastStateTracker(String connectionString, String type, int stateTrackerPort) throws Exception {
        log.info("Setting up hazelcast with type " + type + " connection string " + connectionString + " and port " + stateTrackerPort);
        boolean isMaster = type.equals("master");
        if (isMaster) {
            if (PortTaken.portTaken(stateTrackerPort)) {
                throw new IllegalStateException("Specified type was master and the port specified was taken, please specify a different port");
            }
            if (connectionString.equals("master")) {
                this.connectionString = BaseHazelCastStateTracker.advertisedHost() + ":" + stateTrackerPort;
            }
            this.hazelCastPort = stateTrackerPort;
            this.config = this.hazelcast();
            this.h = Hazelcast.newHazelcastInstance(this.config);
            this.h.getCluster().addMembershipListener(new LoggingMembershipListener());
        } else {
            this.setConnectionString(connectionString);
            log.info("Connecting to hazelcast on " + connectionString);
            ClientConfig client = new ClientConfig();
            client.getNetworkConfig().addAddress(connectionString);
            this.h = HazelcastClient.newHazelcastClient(client);
        }
        this.type = type;
        this.jobs = this.h.getList(JOBS);
        this.workers = this.h.getList(WORKERS);
        this.recentlyClearedJobs = this.h.getSet(RECENTLY_CLEARED);
        this.begunTraining = this.h.getAtomicReference(BEGUN);
        this.miniBatchSize = this.h.getAtomicReference(INPUT_SPLIT);
        this.workerEnabled = this.h.getMap(WORKER_ENABLED);
        this.replicate = this.h.getList(REPLICATE_WEIGHTS);
        this.topics = this.h.getList(TOPICS);
        this.updates = this.h.getList(UPDATES);
        this.heartbeat = this.h.getMap(HEART_BEAT);
        this.master = this.h.getAtomicReference(RESULT);
        this.isPretrain = this.h.getAtomicReference(IS_PRETRAIN);
        this.numTimesPretrain = this.h.getAtomicReference(NUM_TIMES_RUN_PRETRAIN);
        this.numTimesPretrainRan = this.h.getAtomicReference(NUM_TIMES_PRETRAIN_RAN);
        this.done = this.h.getAtomicReference(DONE);
        this.bestLoss = this.h.getAtomicReference(BEST_LOSS);
        this.earlyStop = this.h.getAtomicReference(EARLY_STOP);
        this.patience = this.h.getAtomicReference(PATIENCE);
        this.numBatches = this.h.getAtomicReference(NUM_BATCHES_SO_FAR_RAN);
        this.references = this.h.getMap(GLOBAL_REFERENCE);
        if (isMaster) {
            // Only the master seeds the shared state; workers observe what it wrote.
            this.begunTraining.set(Boolean.FALSE);
            this.saver = this.createUpdateSaver();
            this.numTimesPretrainRan.set(0);
            this.numTimesPretrain.set(1);
            this.isPretrain.set(Boolean.TRUE);
            this.done.set(Boolean.FALSE);
            this.resource = new StateTrackerDropWizardResource(this);
            this.bestLoss.set(Double.POSITIVE_INFINITY);
            this.earlyStop.set(Boolean.TRUE);
            this.numBatches.set(0);
        }
        this.workRetriever = new LocalWorkRetriever(this.h);
    }

    /**
     * Returns the host name a master advertises.
     *
     * <p>This is the {@value #HAZELCAST_HOST} system property if set, else the local host name.
     * It falls back to {@code "0.0.0.0"} if the local host cannot be resolved.
     */
    private static String advertisedHost() {
        try {
            return System.getProperty(HAZELCAST_HOST, InetAddress.getLocalHost().getHostName());
        }
        catch (Exception e) {
            return "0.0.0.0";
        }
    }

    /**
     * Logs cluster membership changes.
     *
     * <p>It is a static nested class so that the registered listener does not hold a reference
     * to the tracker.
     */
    private static final class LoggingMembershipListener implements MembershipListener {

        public void memberAdded(MembershipEvent membershipEvent) {
            log.info("Member added " + membershipEvent.toString());
        }

        public void memberRemoved(MembershipEvent membershipEvent) {
            log.info("Member removed " + membershipEvent.toString());
        }

        public void memberAttributeChanged(MemberAttributeEvent memberAttributeEvent) {
            log.info("Member changed " + memberAttributeEvent.toString());
        }
    }

    /** Stores {@code o} in the cluster-wide reference map under {@code key}. */
    public <E extends Serializable> void define(String key, E o) {
        this.references.put(key, o);
    }

    /**
     * Returns the value stored under {@code key} in the cluster-wide reference map.
     *
     * <p>The cast to {@code E} is unchecked.
     */
    @SuppressWarnings("unchecked")
    public <E extends Serializable> E get(String key) {
        return (E) this.references.get(key);
    }

    /** Returns the current value of the Hazelcast atomic long named {@code key}. */
    public double count(String key) {
        IAtomicLong counter = this.h.getAtomicLong(key);
        return counter.get();
    }

    /**
     * Adds {@code by} to the Hazelcast atomic long named {@code key}.
     *
     * <p>The fractional part of {@code by} is truncated.
     */
    public void increment(String key, double by) {
        IAtomicLong counter = this.h.getAtomicLong(key);
        counter.addAndGet((long) by);
    }

    /** Unregisters a local listener previously added with {@link #addUpdateListener}. */
    public void removeUpdateListener(NewUpdateListener listener) {
        this.listeners.remove(listener);
    }

    /** Registers a local listener notified from {@link #setCurrent} and {@link #finish}. */
    public void addUpdateListener(NewUpdateListener listener) {
        this.listeners.add(listener);
    }

    /** Returns the shared count of mini batches run so far. */
    public int numBatchesRan() {
        return this.numBatches.get();
    }

    /**
     * Adds {@code numBatchesRan} to the shared batch counter.
     *
     * <p>NOTE(review): this is a read-then-write, not atomic across cluster members.
     */
    public void incrementBatchesRan(int numBatchesRan) {
        this.numBatches.set(numBatchesRan + this.numBatches.get());
    }

    /**
     * Starts the Dropwizard REST API when the {@code startapi} system property is {@code true}.
     *
     * <p>The bundled {@code /hazelcast/dropwizard.yml} is copied to a local file, which is passed
     * to the server. Any failure, including {@link Error}s such as missing classes, is logged and
     * otherwise ignored.
     */
    public void startRestApi() {
        if (!Boolean.parseBoolean(System.getProperty("startapi", "false"))) {
            return;
        }
        InputStream in = null;
        OutputStream out = null;
        try {
            // NOTE(review): presumably the HTTP ports configured in dropwizard.yml — confirm.
            if (PortTaken.portTaken(8080) || PortTaken.portTaken(8180)) {
                log.warn("Port taken for rest api");
                return;
            }
            in = new ClassPathResource("/hazelcast/dropwizard.yml").getInputStream();
            this.resource = new StateTrackerDropWizardResource(this);
            File tmpConfig = new File("hazelcast/dropwizard.yml");
            File parent = tmpConfig.getParentFile();
            if (!parent.exists()) {
                parent.mkdirs();
            }
            tmpConfig.deleteOnExit();
            out = new BufferedOutputStream(new FileOutputStream(tmpConfig));
            IOUtils.copy(in, out);
            // Close (and thereby flush) the copy before the server reads the file.
            out.close();
            out = null;
            this.resource.run(new String[]{"server", tmpConfig.getAbsolutePath()});
        }
        catch (Error e1) {
            log.warn("Unable to start server", (Throwable) e1);
        }
        catch (Exception e) {
            log.warn("Unable to start server", (Throwable) e);
        }
        finally {
            IOUtils.closeQuietly(out);
            IOUtils.closeQuietly(in);
        }
    }

    /** Returns the configured job aggregator, possibly {@code null}. */
    public JobAggregator jobAggregator() {
        return this.jobAggregator;
    }

    /** Sets the job aggregator. */
    public void setJobAggregator(JobAggregator aggregator) {
        this.jobAggregator = aggregator;
    }

    /** Creates the {@link UpdateSaver} the master uses; called from the constructor. */
    public abstract UpdateSaver createUpdateSaver();

    /** Returns the shared mini batch size. */
    public int miniBatchSize() {
        return this.miniBatchSize.get();
    }

    /** Returns whether training has begun, according to the shared flag. */
    public boolean hasBegun() {
        return this.begunTraining.get();
    }

    /** Clears the work stored for {@code worker} via the work retriever. */
    public void removeWorkerData(String worker) {
        this.workRetriever.clear(worker);
    }

    /** Returns the ids the work retriever currently holds data for. */
    public Collection<String> workerData() {
        return this.workRetriever.workers();
    }

    /** Replaces the work retriever. */
    public void setWorkRetriever(WorkRetriever workRetriever) {
        this.workRetriever = workRetriever;
    }

    /** Returns the live shared list of ids that have posted updates. */
    public Collection<String> workerUpdates() {
        return this.updates;
    }

    /** Replaces the update saver. */
    public void setUpdateSaver(UpdateSaver updateSaver) {
        this.saver = updateSaver;
    }

    /** Returns the update saver, {@code null} on a worker unless one was set. */
    public UpdateSaver updateSaver() {
        return this.saver;
    }

    /** Sets the shared mini batch size. */
    public void setMiniBatchSize(int batchSize) {
        this.miniBatchSize.set(batchSize);
    }

    /**
     * Returns the input split size handed to each worker.
     *
     * <p>If no mini batch size is set, it is initialised to 10.
     *
     * @throws IllegalStateException if no worker has registered
     */
    public int inputSplit() {
        Integer batch = this.miniBatchSize.get();
        if (batch == null) {
            batch = 10;
            this.miniBatchSize.set(batch);
        }
        // Validates that at least one worker is registered (throws otherwise).
        this.numWorkers();
        // NOTE(review): every worker gets the full mini batch. If a per-worker share was intended,
        // this should divide by the worker count — confirm before changing.
        return batch;
    }

    /** Same as {@link #inputSplit()}. */
    public int partition() {
        return this.inputSplit();
    }

    /** Returns {@code true} only if {@code id} is explicitly marked as enabled. */
    public boolean workerEnabled(String id) {
        // A single lookup avoids the containsKey/get race and any unboxing of a missing value.
        return Boolean.TRUE.equals(this.workerEnabled.get(id));
    }

    /** Marks worker {@code id} as enabled. */
    public void enableWorker(String id) {
        this.workerEnabled.put(id, Boolean.TRUE);
    }

    /** Marks worker {@code id} as disabled. */
    public void disableWorker(String id) {
        this.workerEnabled.put(id, Boolean.FALSE);
    }

    /** Removes {@code workerId} from the replication list. */
    public void doneReplicating(String workerId) {
        this.replicate.remove(workerId);
    }

    /** Adds {@code workerId} to the replication list if it is not already there. */
    public void addReplicate(String workerId) {
        if (!this.replicate.contains(workerId)) {
            this.replicate.add(workerId);
        }
    }

    /** Returns whether {@code workerId} is on the replication list. */
    public boolean needsReplicate(String workerId) {
        return this.replicate.contains(workerId);
    }

    /**
     * Persists {@code update} via the update saver and records {@code id} in the shared updates list.
     *
     * <p>After saving, the job's work and result are cleared. A {@code null} update is ignored.
     *
     * @throws RuntimeException wrapping any failure from the saver
     */
    public void addUpdate(String id, Job update) {
        if (update == null) {
            return;
        }
        try {
            this.updateSaver().save(id, update);
            update.setWork(null);
            update.setResult(null);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        this.updates.add(id);
    }

    /** Returns the strategy for iterating over and applying the accumulated updates. */
    public abstract IterateAndUpdate updates();

    /** Sets the connection string this tracker reports. */
    public void setConnectionString(String connectionString) {
        this.connectionString = connectionString;
    }

    /** Returns the connection string this tracker reports. */
    public String connectionString() {
        return this.connectionString;
    }

    /**
     * Builds the embedded member's configuration.
     *
     * <p>The configuration sets the fixed port, the join strategy and an optional interface, and
     * declares the named lists, set and maps the tracker uses. The join is multicast by default, or
     * AWS when {@code hazelcast.aws=true}. The interface comes from {@code hazelcast.interface}.
     */
    private Config hazelcast() {
        Config conf = new Config();
        conf.getNetworkConfig().setPort(this.hazelCastPort);
        conf.getNetworkConfig().setPortAutoIncrement(false);
        conf.setProperty("hazelcast.initial.min.cluster.size", "1");
        conf.setProperty("hazelcast.shutdownhook.enabled", "false");
        JoinConfig join = conf.getNetworkConfig().getJoin();
        boolean isAws = System.getProperty("hazelcast.aws", "false").equals("true");
        log.info("Setting up Joiner with this being " + (isAws ? "AWS" : "Multicast"));
        join.getAwsConfig().setEnabled(isAws);
        if (isAws) {
            join.getAwsConfig().setAccessKey(System.getProperty("hazelcast.access-key"));
            join.getAwsConfig().setSecretKey(System.getProperty("hazelcast.access-secret"));
        }
        join.getMulticastConfig().setEnabled(!isAws);
        String interf = System.getProperty("hazelcast.interface");
        if (interf != null) {
            conf.getNetworkConfig().getInterfaces().setEnabled(true).addInterface(interf);
        }
        for (String listName : new String[]{JOBS, REPLICATE_WEIGHTS, TOPICS, UPDATES, AVAILABLE_WORKERS}) {
            ListConfig listConfig = new ListConfig();
            listConfig.setName(listName);
            conf.addListConfig(listConfig);
        }
        SetConfig cleared = new SetConfig();
        cleared.setName(RECENTLY_CLEARED);
        conf.addSetConfig(cleared);
        for (String mapName : new String[]{GLOBAL_REFERENCE, HEART_BEAT, WORKER_ENABLED, "updatesaver", "workretriever"}) {
            MapConfig mapConfig = new MapConfig();
            mapConfig.setName(mapName);
            conf.addMapConfig(mapConfig);
        }
        return conf;
    }

    /**
     * Assigns {@code j} to its worker and adds it to the shared job list.
     *
     * <p>If that worker already has a job, {@code j} goes to the first registered worker for
     * which {@link #jobFor} returns {@code null}. The method polls until one becomes available,
     * and rewrites {@code j}'s worker id accordingly.
     *
     * @return always {@code true}
     * @throws InterruptedException if interrupted while waiting for a free worker
     */
    public boolean addJobToCurrent(Job j) throws Exception {
        IAtomicReference<Job> r = this.h.getAtomicReference("job-" + j.workerId());
        if (!r.isNull()) {
            String target = null;
            while (target == null) {
                for (String s : this.workers()) {
                    if (this.jobFor(s) == null) {
                        target = s;
                        break;
                    }
                }
                if (target == null) {
                    // Back off instead of hammering the cluster while every worker is busy.
                    Thread.sleep(JOB_REDIRECT_POLL_MILLIS);
                }
            }
            log.info("Redirecting worker " + j.workerId() + " to " + target + " due to work already being allocated");
            r = this.h.getAtomicReference("job-" + target);
            j.setWorkerId(target);
        }
        r.set(j);
        this.jobs.add(j);
        return true;
    }

    /** Sets the port used for the embedded member. */
    public void setServerPort(int port) {
        this.hazelCastPort = port;
    }

    /** Returns the port used for the embedded member, or -1 if none was set. */
    public int getServerPort() {
        return this.hazelCastPort;
    }

    /** Returns the live shared list of current jobs. */
    public List<Job> currentJobs() throws Exception {
        return this.jobs;
    }

    /** Returns the live shared set of ids passed to {@link #clearJob}. */
    public Set<String> recentlyCleared() {
        return this.recentlyClearedJobs;
    }

    /** Overwrites the job reference for {@code j}'s worker with {@code j}. */
    public void updateJob(Job j) {
        IAtomicReference<Job> jRef = this.h.getAtomicReference("job-" + j.workerId());
        jRef.set(j);
    }

    /**
     * Clears the job assigned to worker {@code id}.
     *
     * <p>{@code id} is recorded as recently cleared. The job reference is emptied and the first
     * matching job is removed from the shared list. A {@code null} id is logged and ignored.
     */
    public void clearJob(String id) throws Exception {
        if (id == null) {
            log.warn("No job to clear; was null, returning");
            return;
        }
        this.recentlyClearedJobs.add(id);
        IAtomicReference<Job> jRef = this.h.getAtomicReference("job-" + id);
        if (jRef.isNull()) {
            return;
        }
        jRef.clear();
        log.info("Destroyed job ref " + id);
        Job remove = null;
        for (Job j : this.jobs) {
            if (j.workerId().equals(id)) {
                remove = j;
                break;
            }
        }
        if (remove != null) {
            this.jobs.remove(remove);
        }
    }

    /**
     * Shuts down the Hazelcast instance and the REST resource.
     *
     * <p>The REST resource is shut down even if stopping Hazelcast fails.
     */
    public void shutdown() {
        try {
            if (this.h != null) {
                // Per the Hazelcast API, shutdown() is equivalent to getLifecycleService().shutdown().
                this.h.shutdown();
            }
        }
        finally {
            if (this.resource != null) {
                this.resource.shutdown();
            }
        }
    }

    /** Appends {@code topic} to the shared topic list. */
    public void addTopic(String topic) throws Exception {
        this.topics.add(topic);
    }

    /** Returns the live shared topic list. */
    public List<String> topics() throws Exception {
        return this.topics;
    }

    /**
     * Returns the current master result.
     *
     * <p>The value cached by {@link #finish()} is preferred; otherwise the shared reference is
     * read, which may be {@code null}.
     */
    public Serializable getCurrent() throws Exception {
        if (this.cachedCurrent != null) {
            return this.cachedCurrent;
        }
        return this.master.get();
    }

    /** Notifies local listeners, then publishes {@code e} as the master result; ignores {@code null}. */
    public void setCurrent(Serializable e) throws Exception {
        if (e == null) {
            log.warn("Not setting a null update");
            return;
        }
        for (NewUpdateListener listener : this.listeners) {
            listener.onUpdate(e);
        }
        this.master.set(e);
    }

    /**
     * Returns the job assigned to worker {@code id}.
     *
     * <p>Returns {@code null} if training is done, no job is assigned, or
     * {@link #isCurrentlyJob} matches.
     */
    public Job jobFor(String id) {
        if (Boolean.TRUE.equals(this.done.get())) {
            return null;
        }
        IAtomicReference<Job> ref = this.h.getAtomicReference("job-" + id);
        Job job = ref.get();
        if (job == null || this.isCurrentlyJob(id)) {
            return null;
        }
        return job;
    }

    /**
     * Returns whether any job in the shared list equals {@code id}.
     *
     * <p>NOTE(review): this compares a {@link Job} against a {@code String}, which likely never
     * matches. Changing it to compare worker ids would make {@link #jobFor} return {@code null}
     * for every job added by {@link #addJobToCurrent}, so it is left as-is pending confirmation.
     */
    private boolean isCurrentlyJob(String id) {
        for (Job j : this.jobs) {
            if (j.equals((Object) id)) {
                return true;
            }
        }
        return false;
    }

    /** Adds {@code id} to the shared worker list if absent. */
    public void availableForWork(String id) {
        if (!this.workers.contains(id)) {
            this.workers.add(id);
        }
    }

    /** Returns a snapshot of the worker ids of all current jobs. */
    public List<String> jobIds() {
        List<String> ret = new ArrayList<String>();
        for (Job j : this.jobs) {
            ret.add(j.workerId());
        }
        return ret;
    }

    /** Records a heartbeat for {@code worker} and registers it if it is new. */
    public void addWorker(String worker) {
        this.heartbeat.put(worker, System.currentTimeMillis());
        if (!this.workers.contains(worker)) {
            log.info("Adding worker " + worker);
            this.workers.add(worker);
            log.info("Number of workers is now " + this.workers.size());
        }
    }

    /** Unregisters {@code worker} and clears its job, if it has one; clear failures are logged. */
    public void removeWorker(String worker) {
        this.workers.remove(worker);
        if (this.jobFor(worker) != null) {
            try {
                this.clearJob(worker);
            }
            catch (Exception e) {
                log.warn("Unable to clear job for worker with id" + worker, (Throwable) e);
            }
        }
    }

    /** Returns the live shared worker list. */
    public List<String> workers() {
        return this.workers;
    }

    /**
     * Returns the number of registered workers.
     *
     * @throws IllegalStateException if there are none
     */
    public int numWorkers() {
        int num = this.workers.size();
        if (num < 1) {
            throw new IllegalStateException("There appears to have been an issue during initialization. No workers found.");
        }
        return num;
    }

    /** Returns the underlying Hazelcast instance. */
    public synchronized HazelcastInstance getH() {
        return this.h;
    }

    /** Replaces the underlying Hazelcast instance reference (already-bound structures are unchanged). */
    public synchronized void setH(HazelcastInstance h) {
        this.h = h;
    }

    /** Returns the live shared map of worker id to last heartbeat time (epoch millis). */
    public Map<String, Long> getHeartBeats() {
        return this.heartbeat;
    }

    /** Sets how many pretrain iterations to run. */
    public void runPreTrainIterations(int numTimes) {
        this.numTimesPretrain.set(numTimes);
    }

    /** Returns how many pretrain iterations to run. */
    public int runPreTrainIterations() {
        return this.numTimesPretrain.get();
    }

    /** Returns how many pretrain iterations have run. */
    public int numTimesPreTrainRun() {
        return this.numTimesPretrainRan.get();
    }

    /**
     * Increments the pretrain-run counter.
     *
     * <p>NOTE(review): this is a read-then-write, not atomic across cluster members.
     */
    public void incrementNumTimesPreTrainRan() {
        this.numTimesPretrainRan.set(this.numTimesPreTrainRun() + 1);
    }

    /** Returns the shared done flag, or {@code true} if it cannot be read (e.g. after shutdown). */
    public boolean isDone() {
        try {
            return this.done.get();
        }
        catch (Exception e) {
            log.warn("Hazelcast already shutdown...returning true on isDone()");
            return true;
        }
    }

    /**
     * Marks training as finished.
     *
     * <p>Any current result is cached locally and sent to the listeners. The done flag is then
     * set, and the update saver is cleaned up. Failures are logged and otherwise ignored.
     */
    public void finish() {
        try {
            Serializable current = this.getCurrent();
            if (current != null) {
                this.cachedCurrent = current;
                for (NewUpdateListener listener : this.listeners) {
                    listener.onUpdate(current);
                }
            }
            this.done.set(Boolean.TRUE);
            this.updateSaver().cleanup();
        }
        catch (Exception e) {
            log.warn("Hazelcast already shutdown...done() being called is pointless");
        }
    }
}

