package org.deeplearning4j.optimize.solvers.accumulation;

import java.util.Collection;
import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/accumulation/FancyBlockingQueue.class */
/**
 * {@link BlockingQueue} wrapper that lets several consumer threads each observe every element.
 *
 * <p>Producers insert into the backing queue as usual. When bypass mode is off and more than one
 * consumer is registered, every consumer calling {@link #poll()} receives the current head
 * element. That element is removed from the backing queue only after the last registered consumer
 * has polled it. Consumers are kept in lock-step by a spin barrier ({@link #synchronize(int)}).
 * Most other methods delegate directly to the backing queue.
 */
public class FancyBlockingQueue<E> implements BlockingQueue<E>, Registerable {
    private static final Logger log = LoggerFactory.getLogger(FancyBlockingQueue.class);

    /** Queue that actually stores the elements. */
    protected BlockingQueue<E> backingQueue;
    /** Consumer count from the constructor or the most recent {@link #registerConsumers(int)}. */
    protected volatile int consumers;
    /** Per-thread count of poll rounds this consumer has completed; created lazily in poll(). */
    protected ThreadLocal<AtomicLong> currentStep = new ThreadLocal<>();
    /** Number of poll rounds completed (incremented when the last consumer removes the head). */
    protected final AtomicLong step = new AtomicLong(0L);
    /** Number of consumers that have polled the current head element in this round. */
    protected final AtomicInteger state = new AtomicInteger(0);
    /** Consumer count used by poll() to size the barrier and decide when to remove the head. */
    protected final AtomicInteger currentConsumers = new AtomicInteger(0);
    // Flags and counters backing the two-phase spin barrier in synchronize(int).
    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    protected AtomicBoolean isDone = new AtomicBoolean(true);
    protected AtomicInteger barrier = new AtomicInteger(0);
    protected AtomicInteger secondary = new AtomicInteger(0);
    /** Backing-queue size captured by the last registerConsumers call. */
    protected AtomicInteger numElementsReady = new AtomicInteger(0);
    /** Elements removed by multi-consumer poll() since the last registerConsumers call. */
    protected AtomicInteger numElementsDrained = new AtomicInteger(0);
    /** When true, poll() and isEmpty() behave exactly like the plain backing queue. */
    protected AtomicBoolean bypassMode = new AtomicBoolean(false);
    /** Enables per-thread info logging in isEmpty() and synchronize(). */
    protected boolean isDebug = false;
    /** Read lock is held by put(); write lock by registerConsumers() while it snapshots the size. */
    protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * Wraps {@code blockingQueue} with an initial consumer count of -1.
     *
     * <p>NOTE(review): with a count of -1, a multi-consumer poll() never reaches the barrier count
     * in synchronize(). Call {@link #registerConsumers(int)} or enable bypass mode before polling.
     *
     * @param blockingQueue queue that stores the elements
     * @throws NullPointerException if {@code blockingQueue} is null
     */
    public FancyBlockingQueue(@NonNull BlockingQueue<E> blockingQueue) {
        // The delegated constructor performs the null check.
        this(blockingQueue, -1);
    }

    /**
     * Wraps {@code blockingQueue} and sets the initial consumer count.
     *
     * @param blockingQueue queue that stores the elements
     * @param consumers     initial number of consumers taking part in poll()
     * @throws NullPointerException if {@code blockingQueue} is null
     */
    public FancyBlockingQueue(@NonNull BlockingQueue<E> blockingQueue, int consumers) {
        if (blockingQueue == null) {
            throw new NullPointerException("queue");
        }
        this.backingQueue = blockingQueue;
        this.consumers = consumers;
        this.currentConsumers.set(consumers);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean add(E e) {
        return backingQueue.add(e);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean offer(E e) {
        return backingQueue.offer(e);
    }

    /** Removes the head of the backing queue directly, ignoring multi-consumer coordination. */
    @Override
    public E remove() {
        return backingQueue.remove();
    }

    /**
     * Turns bypass mode on or off. While it is on, poll() and isEmpty() delegate directly to the
     * backing queue.
     *
     * @param z {@code true} to enable bypass mode
     */
    @Override
    public void fallbackToSingleConsumerMode(boolean z) {
        bypassMode.set(z);
    }

    /**
     * Sets the number of consumers and starts a new drain cycle.
     *
     * <p>The current backing-queue size is recorded as the number of elements ready and the
     * drained counter is reset. This happens under the write lock, so no put() can interleave.
     *
     * @param i number of consumers that will call poll()
     */
    @Override
    public void registerConsumers(int i) {
        lock.writeLock().lock();
        try {
            numElementsReady.set(backingQueue.size());
            numElementsDrained.set(0);
            consumers = i;
            currentConsumers.set(i);
        } finally {
            // Always release, so a failure here cannot block every later put() forever.
            lock.writeLock().unlock();
        }
    }

    /**
     * Inserts into the backing queue, waiting for space if necessary.
     *
     * <p>Holds the read lock so the insertion cannot race with the size snapshot in
     * registerConsumers().
     *
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public void put(E e) throws InterruptedException {
        lock.readLock().lock();
        try {
            backingQueue.put(e);
        } finally {
            // Release even if put() is interrupted; otherwise registerConsumers() would deadlock.
            lock.readLock().unlock();
        }
    }

    /**
     * In bypass mode, reports the backing queue's emptiness.
     *
     * <p>Otherwise returns {@code true} once the number of elements drained by poll() reaches
     * the number that was ready at the last registerConsumers() call.
     */
    @Override
    public boolean isEmpty() {
        if (bypassMode.get()) {
            return backingQueue.isEmpty();
        }
        boolean empty = numElementsDrained.get() >= numElementsReady.get();
        if (isDebug) {
            log.info("thread {} queries isEmpty: {}", Thread.currentThread().getId(), empty);
        }
        return empty;
    }

    /**
     * Spin barrier: blocks the caller until {@code i} threads have entered.
     *
     * <p>It is a no-op when {@code i == 1} or bypass mode is on. The barrier has two phases.
     * In the first phase, threads count in on {@code barrier/isDone}. In the second, they count
     * in on {@code secondary/isFirst}. The last arrival of each phase resets state and releases
     * the others. Waiting threads park for about 1µs per iteration.
     *
     * @param i number of threads expected at the barrier
     */
    protected void synchronize(int i) {
        if (i == 1 || bypassMode.get()) {
            return;
        }
        if (isDebug) {
            log.info("thread {} locking at FBQ", Thread.currentThread().getId());
        }
        // Phase 1: the last arrival resets the counters and releases everyone via isDone.
        isDone.compareAndSet(true, false);
        if (barrier.incrementAndGet() == i) {
            secondary.set(0);
            barrier.set(0);
            isFirst.set(false);
            isDone.set(true);
        } else {
            while (!isDone.get()) {
                LockSupport.parkNanos(1000L);
            }
        }
        // Phase 2: stops early leavers from re-entering phase 1 before everyone has left it.
        if (secondary.incrementAndGet() == i) {
            isFirst.set(true);
        } else {
            while (!isFirst.get()) {
                LockSupport.parkNanos(1000L);
            }
        }
        if (isDebug) {
            log.info("thread {} unlocking at FBQ", Thread.currentThread().getId());
        }
    }

    /**
     * Returns the head element. Behavior depends on the mode.
     *
     * <p>In bypass mode, this is the backing queue's poll(). Otherwise every registered consumer
     * receives the same head element (via peek(), so it may be {@code null} if the backing queue
     * is empty). The last consumer of the round removes it and advances the global step.
     */
    @Override
    public E poll() {
        if (bypassMode.get()) {
            return backingQueue.poll();
        }
        AtomicLong localStep = currentStep.get();
        if (localStep == null) {
            localStep = new AtomicLong(-1L);
            currentStep.set(localStep);
        }
        // Wait until the global round counter differs from the last round this thread finished.
        while (step.get() == localStep.get()) {
            LockSupport.parkNanos(1000L);
        }
        E element = peek();
        synchronize(currentConsumers.get());
        localStep.incrementAndGet();
        // The final consumer of this round removes the shared head and opens the next round.
        if (state.incrementAndGet() == currentConsumers.get()) {
            remove();
            numElementsDrained.incrementAndGet();
            state.set(0);
            step.incrementAndGet();
        }
        synchronize(currentConsumers.get());
        return element;
    }

    /** Delegates to the backing queue. */
    @Override
    public E element() {
        return backingQueue.element();
    }

    /**
     * Clears the backing queue and resets the global step counter to zero.
     *
     * <p>NOTE(review): per-thread step counters are not reset here — confirm callers expect that.
     */
    @Override
    public void clear() {
        backingQueue.clear();
        step.set(0L);
    }

    /** Returns the backing queue's size. */
    @Override
    public int size() {
        return backingQueue.size();
    }

    /** Delegates to the backing queue. */
    @Override
    public E peek() {
        return backingQueue.peek();
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean offer(E e, long j, TimeUnit timeUnit) throws InterruptedException {
        return backingQueue.offer(e, j, timeUnit);
    }

    /**
     * Takes the head of the backing queue, blocking until one is available.
     *
     * <p>Like {@link #poll(long, TimeUnit)}, this delegates directly and does not take part in
     * the multi-consumer coordination of {@link #poll()}. It never returns {@code null}.
     *
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public E take() throws InterruptedException {
        return backingQueue.take();
    }

    /** Delegates to the backing queue, bypassing multi-consumer coordination. */
    @Override
    public E poll(long j, TimeUnit timeUnit) throws InterruptedException {
        return backingQueue.poll(j, timeUnit);
    }

    /** Delegates to the backing queue. */
    @Override
    public int remainingCapacity() {
        return backingQueue.remainingCapacity();
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean remove(Object obj) {
        return backingQueue.remove(obj);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean containsAll(Collection<?> collection) {
        return backingQueue.containsAll(collection);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean addAll(Collection<? extends E> collection) {
        return backingQueue.addAll(collection);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean removeAll(Collection<?> collection) {
        return backingQueue.removeAll(collection);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean retainAll(Collection<?> collection) {
        return backingQueue.retainAll(collection);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean contains(Object obj) {
        return backingQueue.contains(obj);
    }

    /** Not supported. */
    @Override
    public Iterator<E> iterator() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public <T> T[] toArray(T[] tArr) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public int drainTo(Collection<? super E> collection) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public int drainTo(Collection<? super E> collection, int i) {
        throw new UnsupportedOperationException();
    }
}
