package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/MagicQueue.class */
/**
 * A {@link BlockingQueue} of {@link DataSet}s that is internally split into one
 * backing queue ("bucket") per device.
 *
 * <p>When more than one bucket is configured, every bucket gets its own
 * {@link QueueHandler} thread. Producers calling {@link #add(DataSet)} or
 * {@link #put(DataSet)} hand elements to the handlers in round-robin order; each
 * handler passes the element's arrays to the affinity manager and then moves the
 * element into its bucket. With a single bucket no handler threads are created and
 * elements go straight into the one (unbounded) backing queue.
 *
 * <p>How consumers pick a bucket depends on {@link Mode}; see the enum constants.
 *
 * <p>Only a subset of the {@code BlockingQueue} contract is implemented. Bulk
 * inspection/removal operations throw {@link UnsupportedOperationException}, and
 * {@link #remove()}, {@link #element()}, {@link #peek()} and
 * {@link #remainingCapacity()} are placeholders that return a constant.
 */
public class MagicQueue implements BlockingQueue<DataSet> {
    private static final Logger log = LoggerFactory.getLogger(MagicQueue.class);

    /** Number of buckets requested at construction time. */
    protected final int numberOfBuckets;
    /** Capacity of each backing queue and each handler buffer (multi-bucket case only). */
    protected int capacity;
    /** Round-robin cursor used by producers to pick the next handler. */
    protected final AtomicInteger nextBucket = new AtomicInteger(0);
    /** Consumption mode; assigned by {@link Builder#build()}. */
    protected Mode mode = Mode.THREADED;
    /** Round-robin cursor used by consumers when the mode is not {@link Mode#THREADED}. */
    protected AtomicInteger interleavedCounter = new AtomicInteger(0);
    /** Not used by this class; kept because it is part of the protected surface. */
    protected AtomicInteger interleavedPutter = new AtomicInteger(0);
    /** Number of elements successfully inserted since construction or the last {@link #clear()}. */
    protected AtomicLong cntPut = new AtomicLong(0);
    /** Number of elements successfully retrieved since construction or the last {@link #clear()}. */
    protected AtomicLong cntGet = new AtomicLong(0);
    /** One queue per bucket; consumers read from these. */
    protected final List<LinkedBlockingQueue<DataSet>> backingQueues = new ArrayList<>();
    /** One handler per bucket; empty when there is a single bucket. */
    protected final List<QueueHandler> handlers = new ArrayList<>();

    /** Fluent builder for {@link MagicQueue}. */
    public static class Builder {
        private int numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
        private int capacity = 16;
        private Mode mode = Mode.THREADED;

        /**
         * Sets the number of buckets. Values below 1 are replaced by the number of
         * devices reported by the affinity manager when {@link #build()} is called.
         *
         * @param numberOfBuckets requested bucket count
         * @return this builder
         */
        public Builder setNumberOfBuckets(int numberOfBuckets) {
            this.numberOfBuckets = numberOfBuckets;
            return this;
        }

        /**
         * Sets the consumption mode.
         *
         * @param mode mode to use, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code mode} is {@code null}
         */
        public Builder setMode(@NonNull Mode mode) {
            if (mode == null) {
                throw new NullPointerException("mode");
            }
            this.mode = mode;
            return this;
        }

        /**
         * Sets the capacity of every per-bucket queue and handler buffer.
         *
         * @param capacityPerFlow capacity, must be positive
         * @return this builder
         * @throws ND4JIllegalStateException if {@code capacityPerFlow} is not positive
         */
        public Builder setCapacityPerFlow(int capacityPerFlow) {
            if (capacityPerFlow <= 0) {
                throw new ND4JIllegalStateException("Capacity per flow value should be positive value");
            }
            this.capacity = capacityPerFlow;
            return this;
        }

        /**
         * Creates the queue. A bucket count below 1 falls back to the number of devices.
         *
         * @return a new queue configured from this builder
         */
        public MagicQueue build() {
            if (this.numberOfBuckets < 1) {
                this.numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
            }
            MagicQueue queue = new MagicQueue(this.numberOfBuckets, this.capacity);
            queue.mode = this.mode;
            return queue;
        }
    }

    /** Strategy consumers use to select a bucket. */
    public enum Mode {
        /**
         * Consumers read from the bucket whose index is the device reported by the
         * affinity manager for the calling thread (or bucket 0 when there is one bucket).
         */
        THREADED,
        /**
         * Consumers read the buckets one after another in round-robin order, and
         * {@link MagicQueue#size()} is derived from the put/get counters.
         */
        SEQUENTIAL
    }

    /**
     * Daemon thread that drains a bounded staging buffer into one backing queue,
     * passing each element's arrays to the affinity manager first.
     */
    public static class QueueHandler extends Thread {
        private final BlockingQueue<DataSet> targetQueue;
        private final LinkedBlockingQueue<DataSet> bufferQueue;

        /**
         * @param targetQueue    queue that receives processed elements
         * @param bufferCapacity capacity of the staging buffer
         */
        public QueueHandler(BlockingQueue<DataSet> targetQueue, int bufferCapacity) {
            this.targetQueue = targetQueue;
            this.bufferQueue = new LinkedBlockingQueue<>(bufferCapacity);
            setDaemon(true);
        }

        /**
         * Stages an element, blocking while the buffer is full. If the calling thread
         * is interrupted while waiting, the element is NOT staged; the interrupt status
         * is restored so the caller can observe it.
         *
         * @param dataSet element to stage
         */
        public void put(DataSet dataSet) {
            try {
                putInterruptibly(dataSet);
            } catch (InterruptedException e) {
                // Never swallow an interrupt: restore the flag for the caller.
                Thread.currentThread().interrupt();
                MagicQueue.log.warn("Interrupted while staging DataSet; element was not enqueued");
            }
        }

        /** Stages an element, propagating interruption to the caller. */
        private void putInterruptibly(DataSet dataSet) throws InterruptedException {
            this.bufferQueue.put(dataSet);
        }

        /** Drops every element that is staged but not yet moved to the target queue. */
        private void discardPending() {
            this.bufferQueue.clear();
        }

        /**
         * Moves staged elements to the target queue until the thread is interrupted.
         * The one-second poll timeout only bounds each wait; a timeout simply retries.
         */
        @Override
        public void run() {
            while (true) {
                try {
                    DataSet staged = this.bufferQueue.poll(1L, TimeUnit.SECONDS);
                    if (staged == null) {
                        continue;
                    }
                    // touch(...) is opaque here; presumably it relocates the arrays to the
                    // device this thread was attached to - confirm against AffinityManager.
                    if (staged.getFeaturesMaskArray() != null) {
                        Nd4j.getAffinityManager().touch(staged.getFeaturesMaskArray());
                    }
                    if (staged.getLabelsMaskArray() != null) {
                        Nd4j.getAffinityManager().touch(staged.getLabelsMaskArray());
                    }
                    Nd4j.getAffinityManager().touch(staged.getFeatures());
                    Nd4j.getAffinityManager().touch(staged.getLabels());
                    this.targetQueue.put(staged);
                } catch (InterruptedException e) {
                    MagicQueue.log.warn("Got InterruptedException...");
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
    }

    /**
     * Creates the buckets and, when there is more than one, starts one handler thread
     * per bucket attached to the device with the same index.
     *
     * @param numberOfBuckets number of buckets; values of 1 or less yield a single
     *                        unbounded queue without handler threads
     * @param capacity        capacity of each bounded queue/buffer in the multi-bucket case
     */
    protected MagicQueue(int numberOfBuckets, int capacity) {
        this.numberOfBuckets = numberOfBuckets;
        this.capacity = capacity;
        if (numberOfBuckets > 1) {
            for (int device = 0; device < numberOfBuckets; device++) {
                LinkedBlockingQueue<DataSet> bucket = new LinkedBlockingQueue<>(capacity);
                this.backingQueues.add(bucket);
                QueueHandler handler = new QueueHandler(bucket, capacity);
                Nd4j.getAffinityManager().attachThreadToDevice(handler, Integer.valueOf(device));
                handler.start();
                this.handlers.add(handler);
            }
        } else {
            this.backingQueues.add(new LinkedBlockingQueue<DataSet>());
        }
    }

    /**
     * Atomically claims the current slot of a round-robin cursor and advances it,
     * wrapping to 0 at {@code size}. A cursor value outside {@code [0, size)} is
     * treated as 0, so the returned index is always valid.
     *
     * @param cursor shared cursor
     * @param size   number of slots, must be positive
     * @return the claimed index in {@code [0, size)}
     */
    private static int nextRoundRobin(AtomicInteger cursor, int size) {
        while (true) {
            int observed = cursor.get();
            int index = (observed < 0 || observed >= size) ? 0 : observed;
            if (cursor.compareAndSet(observed, (index + 1) % size)) {
                return index;
            }
        }
    }

    /** Picks the handler that receives the next produced element (multi-bucket case). */
    private QueueHandler nextHandler() {
        return this.handlers.get(nextRoundRobin(this.nextBucket, this.handlers.size()));
    }

    /** Returns bucket 0, or with several buckets the one indexed by the calling thread's device. */
    private LinkedBlockingQueue<DataSet> deviceQueue() {
        if (this.numberOfBuckets <= 1) {
            return this.backingQueues.get(0);
        }
        return this.backingQueues.get(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue());
    }

    /**
     * Returns the bucket the current consumer call must read from. Outside
     * {@link Mode#THREADED} this advances the interleaved cursor on every call,
     * whether or not the subsequent read yields an element.
     */
    private LinkedBlockingQueue<DataSet> consumerQueue() {
        if (this.mode != Mode.THREADED) {
            return this.backingQueues.get(nextRoundRobin(this.interleavedCounter, this.backingQueues.size()));
        }
        return deviceQueue();
    }

    /**
     * Reports the queue size. Outside {@link Mode#THREADED} this is the difference
     * between the put and get counters. In threaded mode it is the size of the single
     * bucket, or with several buckets the average bucket size (integer division), so
     * it can be 0 while some bucket still holds elements.
     */
    @Override
    public int size() {
        if (this.mode != Mode.THREADED) {
            return (int) (this.cntPut.get() - this.cntGet.get());
        }
        if (this.numberOfBuckets <= 1) {
            return this.backingQueues.get(0).size();
        }
        long total = 0;
        for (int bucket = 0; bucket < this.numberOfBuckets; bucket++) {
            total += this.backingQueues.get(bucket).size();
        }
        return (int) (total / this.numberOfBuckets);
    }

    /**
     * Returns the size of one bucket.
     *
     * @param deviceId bucket index
     * @return number of elements currently in that bucket
     * @throws RuntimeException if {@code deviceId} is not below the number of buckets
     */
    protected int size(int deviceId) {
        if (deviceId >= this.backingQueues.size()) {
            throw new RuntimeException("DeviceID exceeds number of actual backing queues");
        }
        return this.backingQueues.get(deviceId).size();
    }

    /** Returns {@code true} when {@link #size()} is below 1; see its caveats. */
    @Override
    public boolean isEmpty() {
        return size() < 1;
    }

    /** Not supported. */
    @Override
    public boolean contains(Object obj) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public int drainTo(Collection<? super DataSet> collection) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public int drainTo(Collection<? super DataSet> collection, int maxElements) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public Iterator<DataSet> iterator() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public <T> T[] toArray(T[] array) {
        throw new UnsupportedOperationException();
    }

    /**
     * Inserts an element. With several buckets the element is staged with the next
     * handler in round-robin order, blocking while that handler's buffer is full.
     *
     * @param dataSet element to insert
     * @return {@code true} once the element has been accepted; {@code false} if the
     *         calling thread was interrupted while waiting, in which case the element
     *         was not inserted and the thread's interrupt status is set
     */
    @Override
    public boolean add(DataSet dataSet) {
        if (this.numberOfBuckets <= 1) {
            this.backingQueues.get(0).add(dataSet);
        } else {
            try {
                nextHandler().putInterruptibly(dataSet);
            } catch (InterruptedException e) {
                // add() cannot throw a checked exception: restore the flag and report failure.
                Thread.currentThread().interrupt();
                log.warn("Interrupted while adding DataSet; element was not enqueued");
                return false;
            }
        }
        // Count only after the element was really accepted, as put()/offer() do.
        this.cntPut.incrementAndGet();
        return true;
    }

    /** Not supported. */
    @Override
    public boolean remove(Object obj) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public boolean containsAll(Collection<?> collection) {
        throw new UnsupportedOperationException();
    }

    /**
     * Adds every element via {@link #add(DataSet)}, stopping at the first one that is
     * not accepted.
     *
     * @param collection elements to add
     * @return {@code false} if some element was rejected, {@code true} otherwise
     */
    @Override
    public boolean addAll(Collection<? extends DataSet> collection) {
        for (DataSet dataSet : collection) {
            if (!add(dataSet)) {
                return false;
            }
        }
        return true;
    }

    /** Not supported. */
    @Override
    public boolean removeAll(Collection<?> collection) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public boolean retainAll(Collection<?> collection) {
        throw new UnsupportedOperationException();
    }

    /**
     * Removes all elements and resets the put/get counters. Handler staging buffers
     * are emptied first so staged elements cannot reappear after the reset. An element
     * a handler thread is processing at that very moment may still arrive afterwards.
     */
    @Override
    public void clear() {
        for (QueueHandler handler : this.handlers) {
            handler.discardPending();
        }
        for (LinkedBlockingQueue<DataSet> bucket : this.backingQueues) {
            bucket.clear();
        }
        this.cntPut.set(0L);
        this.cntGet.set(0L);
    }

    /**
     * Inserts directly into bucket 0, or with several buckets into the bucket of the
     * calling thread's device, without going through a handler and without blocking.
     *
     * @param dataSet element to insert
     * @return {@code true} if the element was accepted
     */
    @Override
    public boolean offer(DataSet dataSet) {
        boolean accepted = deviceQueue().offer(dataSet);
        if (accepted) {
            this.cntPut.incrementAndGet();
        }
        return accepted;
    }

    /**
     * Inserts an element, blocking while the chosen handler's buffer is full. With a
     * single bucket the element goes straight into the unbounded backing queue.
     *
     * @param dataSet element to insert
     * @throws InterruptedException if interrupted while waiting; the element is then
     *                              neither inserted nor counted
     */
    @Override
    public void put(DataSet dataSet) throws InterruptedException {
        if (this.numberOfBuckets > 1) {
            nextHandler().putInterruptibly(dataSet);
        } else {
            this.backingQueues.get(0).add(dataSet);
        }
        this.cntPut.incrementAndGet();
    }

    /**
     * Timed variant of {@link #offer(DataSet)}.
     *
     * @param dataSet  element to insert
     * @param timeout  maximum time to wait for space
     * @param timeUnit unit of {@code timeout}
     * @return {@code true} if the element was accepted before the timeout elapsed
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public boolean offer(DataSet dataSet, long timeout, TimeUnit timeUnit) throws InterruptedException {
        boolean accepted = deviceQueue().offer(dataSet, timeout, timeUnit);
        if (accepted) {
            this.cntPut.incrementAndGet();
        }
        return accepted;
    }

    /**
     * Retrieves the next element from the bucket selected by the current mode,
     * waiting until one is available. The get counter is incremented exactly once
     * per element actually returned.
     *
     * @return the retrieved element
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public DataSet take() throws InterruptedException {
        DataSet taken = consumerQueue().take();
        this.cntGet.incrementAndGet();
        return taken;
    }

    /**
     * Placeholder: always returns {@code null} and removes nothing.
     * NOTE(review): this deviates from the {@code Queue.remove()} contract.
     */
    @Override
    public DataSet remove() {
        return null;
    }

    /**
     * Retrieves the next element from the bucket selected by the current mode,
     * waiting up to the given time.
     *
     * @param timeout  maximum time to wait
     * @param timeUnit unit of {@code timeout}
     * @return the element, or {@code null} if none became available in time
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public DataSet poll(long timeout, TimeUnit timeUnit) throws InterruptedException {
        DataSet polled = consumerQueue().poll(timeout, timeUnit);
        if (polled != null) {
            this.cntGet.incrementAndGet();
        }
        return polled;
    }

    /** Placeholder: always returns 0, regardless of the actual free space. */
    @Override
    public int remainingCapacity() {
        return 0;
    }

    /**
     * Retrieves the next element from the bucket selected by the current mode
     * without waiting.
     *
     * @return the element, or {@code null} if the selected bucket is empty
     */
    @Override
    public DataSet poll() {
        DataSet polled = consumerQueue().poll();
        if (polled != null) {
            this.cntGet.incrementAndGet();
        }
        return polled;
    }

    /**
     * Placeholder: always returns {@code null}.
     * NOTE(review): this deviates from the {@code Queue.element()} contract.
     */
    @Override
    public DataSet element() {
        return null;
    }

    /** Placeholder: always returns {@code null}, even when elements are queued. */
    @Override
    public DataSet peek() {
        return null;
    }
}
