package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/ParallelWrapper.class */
/**
 * Trains a {@link MultiLayerNetwork} or {@link ComputationGraph} on several worker threads using
 * synchronous parameter averaging.
 *
 * <p>Each {@link Trainer} thread owns a private replica of the model. {@link #fit(DataSetIterator)}
 * hands one DataSet to each worker in turn, waits for the fed workers to finish, and every
 * {@code averagingFrequency} rounds (and always after the last round) averages the replicas'
 * parameters and scores into the wrapped model, merges their updater state, and re-clones every
 * replica from the result.
 *
 * <p>Worker threads are daemon threads started by the constructor; call {@link #shutdown()} to
 * stop them once the wrapper is no longer needed.
 */
public class ParallelWrapper {
    private static final Logger logger = LoggerFactory.getLogger(ParallelWrapper.class);
    private final Model model;
    private final int workers;
    private final int prefetchSize;
    private final Trainer[] zoo;
    /** Rounds between parameter averaging; assigned by {@link Builder#build()}. */
    private int averagingFrequency = 1;
    /** Completed rounds across all fit calls; drives the averaging frequency. */
    private final AtomicLong iterationsCounter = new AtomicLong(0);
    /** Set by {@link #shutdown()}; guarded by {@code this}. */
    private boolean terminated;

    /** Fluent builder for {@link ParallelWrapper}. */
    public static class Builder {
        private Model model;
        private int workers = 2;
        private int prefetchSize = 2;
        private int averagingFrequency = 1;

        /**
         * Creates a builder that will parallelize training of a {@link MultiLayerNetwork}.
         *
         * @param multiLayerNetwork network to train; must not be null
         * @throws NullPointerException if {@code multiLayerNetwork} is null
         */
        public Builder(@NonNull MultiLayerNetwork multiLayerNetwork) {
            if (multiLayerNetwork == null) {
                throw new NullPointerException("mln");
            }
            this.model = multiLayerNetwork;
        }

        /**
         * Creates a builder that will parallelize training of a {@link ComputationGraph}.
         *
         * @param computationGraph graph to train; must not be null
         * @throws NullPointerException if {@code computationGraph} is null
         */
        public Builder(@NonNull ComputationGraph computationGraph) {
            if (computationGraph == null) {
                throw new NullPointerException("graph");
            }
            this.model = computationGraph;
        }

        /**
         * Sets the number of worker threads (default 2).
         *
         * @param i number of workers; must be at least 1
         * @throws RuntimeException if {@code i < 1}
         */
        public Builder workers(int i) {
            if (i < 1) {
                throw new RuntimeException("Number of workers can't be lower then 1!");
            }
            this.workers = i;
            return this;
        }

        /**
         * Sets how many rounds pass between parameter averaging (default 1, i.e. every round).
         * Averaging also always happens after the last round of each fit call.
         *
         * @param i rounds between averaging; must be at least 1 (0 would otherwise fail with
         *     an ArithmeticException inside fit)
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Builder averagingFrequency(int i) {
            if (i < 1) {
                throw new IllegalArgumentException("Averaging frequency can't be lower than 1: " + i);
            }
            this.averagingFrequency = i;
            return this;
        }

        /**
         * Sets the size of the asynchronous prefetch buffer (default 2). Negative values are
         * treated as 0, which disables wrapping the input in an {@link AsyncDataSetIterator}.
         *
         * @param i number of DataSets to prefetch
         */
        public Builder prefetchBuffer(int i) {
            this.prefetchSize = Math.max(i, 0);
            return this;
        }

        /**
         * Creates the wrapper. This starts its worker threads.
         *
         * @return a new ParallelWrapper configured by this builder
         */
        public ParallelWrapper build() {
            ParallelWrapper parallelWrapper = new ParallelWrapper(this.model, this.workers, this.prefetchSize);
            parallelWrapper.averagingFrequency = this.averagingFrequency;
            return parallelWrapper;
        }
    }

    /**
     * Daemon worker thread that owns one model replica and fits it on the DataSets handed over
     * through {@link #feedDataSet(DataSet)}.
     */
    private static class Trainer extends Thread {
        private final Model originalModel;
        private Model replicatedModel;
        private final LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<>();
        /** DataSets fed but not yet finished (queued or in flight). */
        private final AtomicInteger running = new AtomicInteger(0);
        private final int threadId;
        /** First failure thrown by fit() since it was last reported by waitTillRunning(). */
        private volatile Throwable failure;

        public Trainer(int i, Model model) {
            this.threadId = i;
            setDaemon(true);
            this.originalModel = model;
            this.replicatedModel = replicate(model);
        }

        /**
         * Clones {@code source} for this worker. Only worker 0 keeps the listeners, presumably so
         * listener callbacks are not invoked once per worker.
         *
         * @return the clone, or null when {@code source} is neither supported model type
         */
        private Model replicate(Model source) {
            if (source instanceof MultiLayerNetwork) {
                Model copy = ((MultiLayerNetwork) source).m94clone();
                if (threadId != 0) {
                    ((MultiLayerNetwork) copy).setListeners(new ArrayList());
                }
                return copy;
            }
            if (source instanceof ComputationGraph) {
                Model copy = ((ComputationGraph) source).m80clone();
                if (threadId != 0) {
                    ((ComputationGraph) copy).setListeners(new ArrayList());
                }
                return copy;
            }
            return null;
        }

        /** Queues {@code dataSet} for fitting on this worker's replica. */
        public void feedDataSet(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            this.running.incrementAndGet();
            this.queue.add(dataSet);
        }

        public Model getModel() {
            return this.replicatedModel;
        }

        /**
         * Replaces this worker's replica with a fresh clone of {@code model}. Listeners are
         * stripped on workers other than 0, exactly as in the constructor.
         */
        public void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model");
            }
            Model copy = replicate(model);
            if (copy != null) {
                this.replicatedModel = copy;
            }
        }

        /** Returns true while this worker still has queued or in-flight DataSets. */
        public boolean isRunning() {
            return this.running.get() != 0;
        }

        /** Fits queued DataSets until the thread is interrupted (see ParallelWrapper#shutdown). */
        @Override
        public void run() {
            try {
                while (!isInterrupted()) {
                    DataSet dataSet = this.queue.poll(1L, TimeUnit.SECONDS);
                    if (dataSet != null) {
                        fitReplica(dataSet);
                    }
                }
            } catch (InterruptedException e) {
                // Interruption is the shutdown signal: restore the flag and let the thread end.
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Fits one DataSet. A failure is recorded for the coordinating thread instead of killing
         * this worker, and the pending counter is always decremented so waiters cannot hang.
         */
        private void fitReplica(DataSet dataSet) {
            try {
                Model current = this.replicatedModel;
                if (current instanceof MultiLayerNetwork) {
                    ((MultiLayerNetwork) current).fit(dataSet);
                } else if (current instanceof ComputationGraph) {
                    ((ComputationGraph) current).fit(dataSet);
                }
            } catch (Throwable t) {
                logger.error("Trainer thread " + threadId + " failed to fit DataSet", t);
                if (this.failure == null) {
                    this.failure = t;
                }
            } finally {
                this.running.decrementAndGet();
            }
        }

        /**
         * Blocks until every DataSet fed to this worker has been processed.
         *
         * @throws IllegalStateException if the worker thread terminated with work pending, or the
         *     calling thread was interrupted (its interrupt flag is restored)
         * @throws RuntimeException wrapping the failure if a fit call on this worker threw
         */
        public void waitTillRunning() {
            while (this.running.get() != 0) {
                if (!isAlive()) {
                    throw new IllegalStateException("Trainer thread " + threadId + " terminated with pending work");
                }
                try {
                    Thread.sleep(10L);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Interrupted while waiting for trainer thread " + threadId, e);
                }
            }
            Throwable t = this.failure;
            if (t != null) {
                this.failure = null;
                throw new RuntimeException("Trainer thread " + threadId + " failed", t);
            }
        }
    }

    /**
     * Creates the wrapper and starts one daemon {@link Trainer} thread per worker.
     *
     * @param model model to train (a MultiLayerNetwork or ComputationGraph)
     * @param workers number of worker threads
     * @param prefetchSize async prefetch buffer size; values &lt;= 0 disable prefetching
     */
    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        this.model = model;
        this.workers = workers;
        this.prefetchSize = prefetchSize;
        this.zoo = new Trainer[workers];
        for (int id = 0; id < workers; id++) {
            this.zoo[id] = new Trainer(id, model);
            this.zoo[id].start();
        }
    }

    /**
     * Trains the wrapped model on every DataSet produced by {@code dataSetIterator}.
     *
     * <p>The iterator is reset first. Unless it already is an {@link AsyncDataSetIterator} or a
     * {@link ListDataSetIterator}, it is wrapped in an {@link AsyncDataSetIterator} when the
     * prefetch buffer is positive. DataSets are handed to workers 0, 1, ... in turn. After each
     * round the caller blocks until the fed workers are idle. Parameters are averaged every
     * {@code averagingFrequency} rounds and always after the final round.
     *
     * @param dataSetIterator source of training data; must not be null
     * @throws NullPointerException if {@code dataSetIterator} is null
     * @throws IllegalStateException if {@link #shutdown()} was called, a worker thread died, or
     *     the caller was interrupted while waiting
     * @throws RuntimeException wrapping the cause if a worker's fit call threw
     */
    public synchronized void fit(@NonNull DataSetIterator dataSetIterator) {
        if (dataSetIterator == null) {
            throw new NullPointerException("source");
        }
        if (this.terminated) {
            throw new IllegalStateException("ParallelWrapper has been shut down");
        }
        DataSetIterator iterator = dataSetIterator;
        // NOTE(review): ListDataSetIterator is presumably skipped because its data is already in memory.
        if (this.prefetchSize > 0
                && !(dataSetIterator instanceof AsyncDataSetIterator)
                && !(dataSetIterator instanceof ListDataSetIterator)) {
            iterator = new AsyncDataSetIterator(dataSetIterator, this.prefetchSize);
        }
        iterator.reset();

        int pending = 0;      // workers fed in the current round
        int participants = 0; // workers that have trained since the last averaging
        while (iterator.hasNext()) {
            DataSet dataSet = (DataSet) iterator.next();
            this.zoo[pending++].feedDataSet(dataSet);

            boolean exhausted = !iterator.hasNext();
            if (pending == this.workers || exhausted) {
                long round = this.iterationsCounter.incrementAndGet();
                for (int i = 0; i < pending; i++) {
                    this.zoo[i].waitTillRunning();
                }
                participants = Math.max(participants, pending);
                if (round % this.averagingFrequency == 0 || exhausted) {
                    averageAndSync(participants);
                    participants = 0;
                }
                pending = 0;
            }
        }
    }

    /**
     * Averages the parameters and scores of replicas {@code 0..participants-1} into the wrapped
     * model, merges the same replicas' updater state into it, then re-clones every replica.
     *
     * @param participants number of leading workers that trained since the last averaging (&gt;= 1)
     */
    private void averageAndSync(int participants) {
        INDArray paramSum = Nd4j.zeros(this.model.params().shape());
        double scoreSum = 0.0d;
        for (int i = 0; i < participants; i++) {
            Model replica = this.zoo[i].getModel();
            paramSum.addi(replica.params());
            scoreSum += replica.score();
        }
        paramSum.divi(Integer.valueOf(participants));
        this.model.setParams(paramSum);
        double score = scoreSum / participants;
        logger.info("Averaged score: {}", score);

        // getAggregator(false) on worker 0 and merging workers 0..participants-1 with
        // getAggregator(true) mirrors the original call pattern. Presumably `false` yields an
        // aggregator not seeded with worker 0's state; confirm against Updater#getAggregator.
        if (this.model instanceof MultiLayerNetwork) {
            UpdaterAggregator aggregator = ((MultiLayerNetwork) this.zoo[0].getModel()).getUpdater().getAggregator(false);
            for (int i = 0; i < participants; i++) {
                aggregator.merge(((MultiLayerNetwork) this.zoo[i].getModel()).getUpdater().getAggregator(true));
            }
            ((MultiLayerNetwork) this.model).setScore(score);
            ((MultiLayerNetwork) this.model).setUpdater(aggregator.getUpdater());
        } else if (this.model instanceof ComputationGraph) {
            ComputationGraphUpdater.Aggregator aggregator = ((ComputationGraph) this.zoo[0].getModel()).getUpdater().getAggregator(false);
            for (int i = 0; i < participants; i++) {
                aggregator.merge(((ComputationGraph) this.zoo[i].getModel()).getUpdater().getAggregator(true));
            }
            ((ComputationGraph) this.model).setScore(score);
            ((ComputationGraph) this.model).setUpdater(aggregator.getUpdater());
        }

        for (Trainer trainer : this.zoo) {
            trainer.updateModel(this.model);
        }
    }

    /**
     * Stops all worker threads. Subsequent calls to {@link #fit(DataSetIterator)} throw
     * {@link IllegalStateException}. Calling this more than once has no further effect.
     */
    public synchronized void shutdown() {
        this.terminated = true;
        for (Trainer trainer : this.zoo) {
            trainer.interrupt();
        }
    }
}
