package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A {@link Queue} of {@link DataSet}s that can split its contents across several per-device buckets.
 *
 * <p>With one bucket (or fewer), this is a thin wrapper around a single {@link LinkedBlockingQueue}.
 * With more than one bucket, {@link #add(DataSet)} distributes datasets round-robin across the
 * buckets. Each bucket is fed by a daemon {@link QueueHandler} thread that is attached to device
 * {@code i} through the Nd4j affinity manager. The consumer operations ({@link #poll()},
 * {@link #poll(long, TimeUnit)}, {@link #peek()}, {@link #remove()}, {@link #element()}) read from
 * the bucket whose index equals the device of the calling thread.
 *
 * <p>{@link #contains}, {@link #iterator}, {@link #toArray()}, {@link #toArray(Object[])},
 * {@link #remove(Object)}, {@link #containsAll}, {@link #removeAll} and {@link #retainAll} are not
 * supported and throw {@link UnsupportedOperationException}.
 */
public class MagicQueue implements Queue<DataSet> {
    /** Bucket count requested at construction; values {@code <= 1} mean a single plain queue. */
    protected final int numberOfBuckets;
    /** Index of the bucket the next multi-bucket {@link #add(DataSet)} targets. */
    protected final AtomicInteger nextBucket = new AtomicInteger(0);
    /** Consumer-facing queues: one per bucket, or exactly one in single-bucket mode. */
    protected final List<LinkedBlockingQueue<DataSet>> backingQueues = new ArrayList<>();
    /** Feeder threads, one per bucket; empty in single-bucket mode. */
    protected final List<QueueHandler> handlers = new ArrayList<>();

    /** Fluent builder for {@link MagicQueue}. */
    public static class Builder {
        private int numberOfBuckets = -1;

        /**
         * Sets the number of buckets.
         *
         * @param numberOfBuckets desired bucket count; values below 1 make {@link #build()} use the
         *     number of devices reported by the Nd4j affinity manager
         * @return this builder
         */
        public Builder setNumberOfBuckets(int numberOfBuckets) {
            this.numberOfBuckets = numberOfBuckets;
            return this;
        }

        /**
         * Builds the queue, substituting the device count when no valid bucket count was set.
         *
         * @return a new {@link MagicQueue}
         */
        public MagicQueue build() {
            if (this.numberOfBuckets < 1) {
                this.numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
            }
            return new MagicQueue(this.numberOfBuckets);
        }
    }

    /**
     * Daemon thread that moves datasets from a private buffer into one bucket's backing queue.
     *
     * <p>Before handing a dataset over, it calls the Nd4j affinity manager's {@code touch} on the
     * features, the labels and, when present, the mask arrays. The thread exits when interrupted.
     */
    public static class QueueHandler extends Thread {
        private final Queue<DataSet> targetQueue;
        private final LinkedBlockingQueue<DataSet> bufferQueue = new LinkedBlockingQueue<>();

        /**
         * @param queue the bucket this handler delivers into
         */
        public QueueHandler(Queue<DataSet> queue) {
            this.targetQueue = queue;
            setDaemon(true);
        }

        /**
         * Buffers a dataset for later transfer to the target queue.
         *
         * @param dataSet dataset to deliver; {@code null} is rejected by the underlying queue
         */
        public void put(DataSet dataSet) {
            this.bufferQueue.add(dataSet);
        }

        /** Discards datasets that were buffered but not yet picked up by this thread. */
        private void clearBuffer() {
            this.bufferQueue.clear();
        }

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                final DataSet dataSet;
                try {
                    dataSet = this.bufferQueue.poll(1L, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    // Restore the flag and stop; swallowing it would make this thread unstoppable.
                    Thread.currentThread().interrupt();
                    return;
                }
                if (dataSet != null) {
                    transfer(dataSet);
                }
            }
        }

        /** Touches the dataset's arrays for this thread's device and delivers it to the bucket. */
        private void transfer(DataSet dataSet) {
            try {
                if (dataSet.getFeaturesMaskArray() != null) {
                    Nd4j.getAffinityManager().touch(dataSet.getFeaturesMaskArray());
                }
                if (dataSet.getLabelsMaskArray() != null) {
                    Nd4j.getAffinityManager().touch(dataSet.getLabelsMaskArray());
                }
                Nd4j.getAffinityManager().touch(dataSet.getFeatures());
                Nd4j.getAffinityManager().touch(dataSet.getLabels());
                this.targetQueue.add(dataSet);
            } catch (RuntimeException ignored) {
                // Deliberate best-effort (unchanged behavior): a failure drops this one dataset but
                // keeps the feeder alive, so later datasets for this bucket are still delivered.
            }
        }
    }

    /**
     * Creates a queue with the given number of buckets.
     *
     * @param numberOfBuckets bucket count. When greater than 1, one {@link QueueHandler} per bucket
     *     is started and attached to devices {@code 0..numberOfBuckets-1}. Otherwise a single plain
     *     queue is used and no threads are started.
     */
    protected MagicQueue(int numberOfBuckets) {
        this.numberOfBuckets = numberOfBuckets;
        if (numberOfBuckets > 1) {
            for (int device = 0; device < numberOfBuckets; device++) {
                LinkedBlockingQueue<DataSet> bucket = new LinkedBlockingQueue<>();
                this.backingQueues.add(bucket);
                QueueHandler handler = new QueueHandler(bucket);
                Nd4j.getAffinityManager().attachThreadToDevice(handler, Integer.valueOf(device));
                handler.start();
                this.handlers.add(handler);
            }
        } else {
            this.backingQueues.add(new LinkedBlockingQueue<DataSet>());
        }
    }

    /**
     * Returns the number of datasets ready for consumption <em>per bucket</em>.
     *
     * <p>In multi-bucket mode, this is the floor of the average backing-queue size. Datasets still
     * sitting in a handler's buffer are not counted. Consequently {@link #isEmpty()} reports
     * {@code true} whenever the total is below the bucket count, even if some bucket holds data.
     *
     * @return the per-bucket average queue size (the single queue's size in single-bucket mode)
     */
    @Override
    public int size() {
        if (this.numberOfBuckets <= 1) {
            return this.backingQueues.get(0).size();
        }
        long total = 0;
        for (LinkedBlockingQueue<DataSet> bucket : this.backingQueues) {
            total += bucket.size();
        }
        return (int) (total / this.numberOfBuckets);
    }

    /**
     * Returns the number of datasets ready in a single bucket.
     *
     * @param deviceId bucket (device) index
     * @return that bucket's queue size
     * @throws IllegalArgumentException if {@code deviceId} is negative or not below the number of
     *     backing queues
     */
    protected int size(int deviceId) {
        if (deviceId < 0 || deviceId >= this.backingQueues.size()) {
            throw new IllegalArgumentException("DeviceID exceeds number of actual backing queues");
        }
        return this.backingQueues.get(deviceId).size();
    }

    /** Returns {@code true} when {@link #size()} is zero (see its caveat for multi-bucket mode). */
    @Override
    public boolean isEmpty() {
        return size() < 1;
    }

    @Override
    public boolean contains(Object obj) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Iterator<DataSet> iterator() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    @Override
    public <T> T[] toArray(T[] tArr) {
        throw new UnsupportedOperationException();
    }

    /**
     * Enqueues a dataset. In multi-bucket mode, the dataset goes to the next bucket's handler in
     * round-robin order. This method may be called from several producer threads concurrently.
     *
     * @param dataSet dataset to enqueue; {@code null} is rejected by the underlying queues
     * @return always {@code true}
     */
    @Override
    public boolean add(DataSet dataSet) {
        if (this.numberOfBuckets <= 1) {
            this.backingQueues.get(0).add(dataSet);
            return true;
        }
        this.handlers.get(nextBucketIndex()).put(dataSet);
        return true;
    }

    /**
     * Atomically claims the next round-robin bucket index and advances the cursor.
     *
     * <p>A CAS loop replaces the former check-then-increment. That pattern let two concurrent
     * producers both pass the bounds check, so one of them indexed past the last handler.
     * An out-of-range cursor value is treated as 0.
     */
    private int nextBucketIndex() {
        while (true) {
            int current = this.nextBucket.get();
            int index = (current >= 0 && current < this.numberOfBuckets) ? current : 0;
            int next = index + 1 >= this.numberOfBuckets ? 0 : index + 1;
            if (this.nextBucket.compareAndSet(current, next)) {
                return index;
            }
        }
    }

    @Override
    public boolean remove(Object obj) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean containsAll(Collection<?> collection) {
        throw new UnsupportedOperationException();
    }

    /**
     * Adds every dataset in {@code collection} via {@link #add(DataSet)}.
     *
     * @return {@code true} if at least one dataset was added
     */
    @Override
    public boolean addAll(Collection<? extends DataSet> collection) {
        boolean changed = false;
        for (DataSet dataSet : collection) {
            changed |= add(dataSet);
        }
        return changed;
    }

    @Override
    public boolean removeAll(Collection<?> collection) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean retainAll(Collection<?> collection) {
        throw new UnsupportedOperationException();
    }

    /**
     * Discards buffered and ready datasets in every bucket.
     *
     * <p>Handler buffers are cleared first, so that pending datasets do not land in freshly cleared
     * buckets. A dataset a handler has already taken from its buffer may still arrive afterwards.
     */
    @Override
    public void clear() {
        for (QueueHandler handler : this.handlers) {
            handler.clearBuffer();
        }
        for (LinkedBlockingQueue<DataSet> bucket : this.backingQueues) {
            bucket.clear();
        }
    }

    /**
     * Enqueues a dataset. The backing queues are unbounded, so this behaves like
     * {@link #add(DataSet)}; previously it silently discarded the dataset.
     *
     * @return always {@code true}
     */
    @Override
    public boolean offer(DataSet dataSet) {
        return add(dataSet);
    }

    /**
     * Removes and returns the head of the calling thread's bucket.
     *
     * @throws NoSuchElementException if that bucket is empty
     */
    @Override
    public DataSet remove() {
        DataSet head = poll();
        if (head == null) {
            throw new NoSuchElementException();
        }
        return head;
    }

    /**
     * Removes the head of the calling thread's bucket, waiting up to the given timeout.
     *
     * @return the dataset, or {@code null} if none arrived in time
     * @throws InterruptedException if interrupted while waiting
     */
    public DataSet poll(long j, TimeUnit timeUnit) throws InterruptedException {
        return queueForCurrentThread().poll(j, timeUnit);
    }

    /**
     * Removes the head of the calling thread's bucket without waiting.
     *
     * @return the dataset, or {@code null} if the bucket is empty
     */
    @Override
    public DataSet poll() {
        return queueForCurrentThread().poll();
    }

    /**
     * Returns, without removing, the head of the calling thread's bucket.
     *
     * @throws NoSuchElementException if that bucket is empty
     */
    @Override
    public DataSet element() {
        DataSet head = peek();
        if (head == null) {
            throw new NoSuchElementException();
        }
        return head;
    }

    /**
     * Returns, without removing, the head of the calling thread's bucket.
     *
     * @return the dataset, or {@code null} if the bucket is empty
     */
    @Override
    public DataSet peek() {
        return queueForCurrentThread().peek();
    }

    /**
     * Selects the bucket consumers on this thread read from. This is the single queue in
     * single-bucket mode, otherwise the bucket indexed by the current thread's device.
     */
    private LinkedBlockingQueue<DataSet> queueForCurrentThread() {
        if (this.numberOfBuckets <= 1) {
            return this.backingQueues.get(0);
        }
        return this.backingQueues.get(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue());
    }
}
