package org.deeplearning4j.models.glove;

import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.glove.count.ASCIICoOccurrenceWriter;
import org.deeplearning4j.models.glove.count.BinaryCoOccurrenceReader;
import org.deeplearning4j.models.glove.count.BinaryCoOccurrenceWriter;
import org.deeplearning4j.models.glove.count.CoOccurrenceWeight;
import org.deeplearning4j.models.glove.count.CoOccurrenceWriter;
import org.deeplearning4j.models.glove.count.CountMap;
import org.deeplearning4j.models.glove.count.RoundCount;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
import org.deeplearning4j.util.DL4JFileUtils;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a table of distance-weighted co-occurrence counts for the sequences of a corpus.
 *
 * <p>{@link #fit()} starts {@code workers} calculator threads. Each thread scans sequences from
 * the configured iterator and accumulates pair weights in an in-memory {@link CountMap}. A
 * background {@link ShadowCopyThread} detaches that map whenever its estimated footprint reaches
 * {@link #getMemoryThreshold()}. It then merges the detached map with the previous dump into one
 * of two rotating binary temp files. The final merge is written as text to {@code targetFile},
 * which {@link #iterator()} reads back.
 *
 * <p>Instances are created through {@link Builder}.
 *
 * @param <T> vocabulary element type
 */
public class AbstractCoOccurrences<T extends SequenceElement> implements Serializable {
    protected boolean symmetric;
    protected int windowSize;
    protected VocabCache<T> vocabCache;
    protected SequenceIterator<T> sequenceIterator;
    protected int workers;
    protected File targetFile;
    /**
     * Workers hold the read lock while touching the live map. The shadow thread takes the write
     * lock only to swap the live map for an empty one.
     */
    protected ReentrantReadWriteLock lock;
    /** Memory budget in bytes; {@link #getMemoryThreshold()} uses half of it. */
    protected long memory_threshold;
    private AbstractCoOccurrences<T>.ShadowCopyThread shadowThread;
    private volatile CountMap<T> coOccurrenceCounts;
    private AtomicLong processedSequences;
    protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class);

    /** Fluent builder for {@link AbstractCoOccurrences}. */
    public static class Builder<T extends SequenceElement> {
        protected boolean symmetric;
        protected VocabCache<T> vocabCache;
        protected SequenceIterator<T> sequenceIterator;
        protected File target;
        protected int windowSize = 5;
        protected int workers = Runtime.getRuntime().availableProcessors();
        protected long maxmemory = Runtime.getRuntime().maxMemory();

        /** When true, every counted pair is also counted in the reverse order. */
        public Builder<T> symmetric(boolean z) {
            this.symmetric = z;
            return this;
        }

        /** Number of following tokens paired with each token (default 5). */
        public Builder<T> windowSize(int i) {
            this.windowSize = i;
            return this;
        }

        /** Vocabulary used to resolve labels to elements and indices. */
        public Builder<T> vocabCache(@NonNull VocabCache<T> vocabCache) {
            if (vocabCache == null) {
                throw new NullPointerException("cache is marked @NonNull but is null");
            }
            this.vocabCache = vocabCache;
            return this;
        }

        /** Corpus to scan; wrapped in a {@link SynchronizedSequenceIterator} for the workers. */
        public Builder<T> iterate(@NonNull SequenceIterator<T> sequenceIterator) {
            if (sequenceIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            this.sequenceIterator = new SynchronizedSequenceIterator<>(sequenceIterator);
            return this;
        }

        /** Number of calculator threads (defaults to the number of available processors). */
        public Builder<T> workers(int i) {
            this.workers = i;
            return this;
        }

        /**
         * Sets the memory budget from a size in gigabytes. The budget becomes {@code (i - 1)} GiB,
         * but never less than 1 GiB. Non-positive values leave the current budget unchanged.
         */
        public Builder<T> maxMemory(int i) {
            if (i > 0) {
                // long arithmetic: the all-int product overflowed to a negative value for i >= 3
                this.maxmemory = Math.max(i - 1, 1) * 1024L * 1024L * 1024L;
            }
            return this;
        }

        /** Path of the text file that receives the final counts. */
        public Builder<T> targetFile(@NonNull String str) {
            if (str == null) {
                throw new NullPointerException("path is marked @NonNull but is null");
            }
            return targetFile(new File(str));
        }

        /** File that receives the final counts. */
        public Builder<T> targetFile(@NonNull File file) {
            if (file == null) {
                throw new NullPointerException("file is marked @NonNull but is null");
            }
            this.target = file;
            return this;
        }

        /**
         * Creates the configured instance.
         *
         * <p>A non-positive memory budget falls back to the JVM max heap. Without an explicit
         * target a temp file is created. NOTE: the target file is marked {@code deleteOnExit}
         * even when the caller supplied it.
         *
         * @throws RuntimeException if the temp target file cannot be created
         */
        public AbstractCoOccurrences<T> build() {
            AbstractCoOccurrences<T> result = new AbstractCoOccurrences<>();
            result.sequenceIterator = this.sequenceIterator;
            result.windowSize = this.windowSize;
            result.vocabCache = this.vocabCache;
            result.symmetric = this.symmetric;
            result.workers = this.workers;
            if (this.maxmemory < 1) {
                this.maxmemory = Runtime.getRuntime().maxMemory();
            }
            result.memory_threshold = this.maxmemory;
            logger.info("Actual memory limit: [" + this.maxmemory + "]");
            try {
                if (this.target == null) {
                    this.target = DL4JFileUtils.createTempFile("cooccurrence", "map");
                }
                this.target.deleteOnExit();
                result.targetFile = this.target;
                return result;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Worker that scans sequences and adds a weight of {@code 1 / distance} for every pair
     * {@code (x, j)} with {@code x <= j <= x + windowSize}. Pairs of a word with itself are
     * skipped, as are pairs whose right-hand element is the UNK label.
     */
    public class CoOccurrencesCalculatorThread extends Thread {
        private final SequenceIterator<T> iterator;
        private final AtomicLong sequenceCounter;
        private int threadId;

        public CoOccurrencesCalculatorThread(int i, @NonNull SequenceIterator<T> sequenceIterator, @NonNull AtomicLong atomicLong) {
            if (sequenceIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            if (atomicLong == null) {
                throw new NullPointerException("sequenceCounter is marked @NonNull but is null");
            }
            this.iterator = sequenceIterator;
            this.sequenceCounter = atomicLong;
            this.threadId = i;
            setName("CoOccurrencesCalculatorThread " + i);
        }

        @Override
        public void run() {
            while (iterator.hasMoreSequences()) {
                Sequence<T> sequence = iterator.nextSequence();
                if (sequence == null) {
                    // Every worker wraps the same underlying iterator, so another worker may take
                    // the last sequence between hasMoreSequences() and nextSequence().
                    continue;
                }
                List<String> labels = new ArrayList<>(sequence.asLabels());
                for (int x = 0; x < labels.size(); x++) {
                    int wordIdx = vocabCache.indexOf(labels.get(x));
                    if (wordIdx < 0) {
                        continue;
                    }
                    T tokenX = vocabCache.wordFor(labels.get(x));
                    int windowStop = Math.min(x + windowSize + 1, labels.size());
                    for (int j = x; j < windowStop; j++) {
                        int otherIdx = vocabCache.indexOf(labels.get(j));
                        if (otherIdx < 0 || otherIdx == wordIdx) {
                            continue;
                        }
                        T tokenJ = vocabCache.wordFor(labels.get(j));
                        if (tokenJ.getLabel().equals(WordVectorsImpl.DEFAULT_UNK)) {
                            continue;
                        }
                        double weight = 1.0d / ((j - x) + Nd4j.EPS_THRESHOLD);
                        waitForMemory();
                        accumulate(wordIdx, otherIdx, tokenX, tokenJ, weight);
                    }
                }
                sequenceCounter.incrementAndGet();
            }
        }

        /** Blocks while the live map is over budget, asking the shadow thread to flush it. */
        private void waitForMemory() {
            while (getMemoryFootprint() >= getMemoryThreshold()) {
                shadowThread.invoke();
                if (threadId == 0) {
                    logger.debug("Memory consuimption > threshold: {footrpint: [" + getMemoryFootprint() + "], threshold: [" + getMemoryThreshold() + "] }");
                }
                ThreadUtils.uncheckedSleep(10000L);
            }
        }

        /**
         * Adds {@code weight} to the pair keyed lower-vocab-index first. When symmetric, it also
         * adds the weight to the mirrored pair.
         */
        private void accumulate(int idxX, int idxJ, T tokenX, T tokenJ, double weight) {
            T first = idxX < idxJ ? tokenX : tokenJ;
            T second = idxX < idxJ ? tokenJ : tokenX;
            lock.readLock().lock();
            try {
                // Several workers hold the read lock together; the lock only keeps the shadow
                // thread from swapping the map mid-update. CountMap.incrementCount is presumably
                // safe for concurrent callers — verify.
                coOccurrenceCounts.incrementCount(first, second, weight);
                if (symmetric) {
                    coOccurrenceCounts.incrementCount(second, first, weight);
                }
            } finally {
                lock.readLock().unlock();
            }
        }
    }

    /**
     * Background thread that moves the live map to disk.
     *
     * <p>Each pass detaches the live map and streams the previous binary dump through a
     * {@link BinaryCoOccurrenceReader} that is given the detached map. It writes the result to
     * the other temp file, then appends the remaining pairs of the detached map. The final pass
     * writes text to {@code targetFile} instead.
     */
    public class ShadowCopyThread extends Thread {
        private final AtomicBoolean isFinished = new AtomicBoolean(false);
        private final AtomicBoolean isTerminate = new AtomicBoolean(false);
        private final AtomicBoolean isInvoked = new AtomicBoolean(false);
        private final AtomicBoolean shouldInvoke = new AtomicBoolean(false);
        /** Guarded by this object's monitor; true once the final dump reached targetFile. */
        private boolean finalDumpWritten;
        private File[] tempFiles;
        private RoundCount counter;

        public ShadowCopyThread() {
            try {
                this.counter = new RoundCount(1);
                this.tempFiles = new File[2];
                this.tempFiles[0] = DL4JFileUtils.createTempFile("aco", "tmp");
                this.tempFiles[1] = DL4JFileUtils.createTempFile("aco", "tmp");
                this.tempFiles[0].deleteOnExit();
                this.tempFiles[1].deleteOnExit();
                setName("ACO ShadowCopy thread");
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /** Polls once per second and flushes when over threshold or when a flush was requested. */
        @Override
        public void run() {
            while (!isFinished.get() && !isTerminate.get()) {
                if (getMemoryFootprint() > getMemoryThreshold() || (shouldInvoke.get() && !isInvoked.get())) {
                    shouldInvoke.compareAndSet(true, false);
                    invokeBlocking();
                } else {
                    ThreadUtils.uncheckedSleep(1000L);
                }
            }
        }

        /** Requests a flush; the polling loop in {@link #run()} picks it up. */
        public void invoke() {
            shouldInvoke.set(true);
        }

        /**
         * Performs one merge pass if the live map is over threshold. After {@link #finish()}, it
         * writes the final text dump, at most once.
         *
         * @throws RuntimeException wrapping any I/O failure during the merge
         */
        public synchronized void invokeBlocking() {
            // Read the flag once: finish() may set it while this call waits for the monitor, and
            // a pass must not change its output target halfway through.
            boolean finalPass = isFinished.get();
            if (finalPass ? finalDumpWritten : getMemoryFootprint() < getMemoryThreshold()) {
                return;
            }
            isInvoked.set(true);
            try {
                logger.debug("Memory purge started.");
                counter.tick();
                CountMap<T> localMap = detachCounts();
                int pairsSaved = 0;
                int linesRead = 0;
                try {
                    logger.debug("Saving to: [" + counter.get() + "], Reading from: [" + counter.previous() + "]");
                    BinaryCoOccurrenceReader<T> reader = new BinaryCoOccurrenceReader<>(tempFiles[counter.previous()], vocabCache, localMap);
                    CoOccurrenceWriter<T> writer = finalPass
                            ? new ASCIICoOccurrenceWriter<T>(targetFile)
                            : new BinaryCoOccurrenceWriter<T>(tempFiles[counter.get()]);
                    while (reader.hasMoreObjects()) {
                        CoOccurrenceWeight<T> line = reader.nextObject();
                        if (line != null) {
                            writer.writeObject(line);
                            pairsSaved++;
                            linesRead++;
                        }
                    }
                    reader.finish();
                    logger.debug("Lines read: [" + linesRead + "]");

                    // Write the pairs from the detached map that remain after reading the previous dump.
                    Iterator<Pair<T, T>> pairs = localMap.getPairIterator();
                    while (pairs.hasNext()) {
                        Pair<T, T> pair = pairs.next();
                        CoOccurrenceWeight<T> weight = new CoOccurrenceWeight<>();
                        weight.setElement1(pair.getFirst());
                        weight.setElement2(pair.getSecond());
                        weight.setWeight(localMap.getCount(pair));
                        writer.writeObject(weight);
                        pairsSaved++;
                    }
                    writer.finish();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
                if (finalPass) {
                    finalDumpWritten = true;
                }
                logger.info("Number of word pairs saved so far: [" + pairsSaved + "]");
            } finally {
                isInvoked.set(false);
            }
        }

        /**
         * Swaps an empty map in as the live map under the write lock and returns the old one.
         * After the swap no worker can reach the returned map, so it can be written to disk
         * without holding the lock.
         */
        private CountMap<T> detachCounts() {
            lock.writeLock().lock();
            try {
                CountMap<T> detached = coOccurrenceCounts;
                coOccurrenceCounts = new CountMap<>();
                return detached;
            } finally {
                lock.writeLock().unlock();
            }
        }

        /** Marks the run as finished and writes the final dump; later calls do nothing. */
        public void finish() {
            if (!isFinished.compareAndSet(false, true)) {
                return;
            }
            invokeBlocking();
        }

        /** Asks the polling loop to exit without writing a final dump. */
        public void terminate() {
            isTerminate.set(true);
        }
    }

    private AbstractCoOccurrences() {
        this.workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
        this.lock = new ReentrantReadWriteLock();
        this.memory_threshold = 0L;
        this.coOccurrenceCounts = new CountMap<>();
        this.processedSequences = new AtomicLong(0L);
    }

    /**
     * Returns the weight of the pair in the live in-memory map. Pairs already flushed to disk by
     * the shadow thread are not included.
     */
    public double getCoOccurrenceCount(@NonNull T t, @NonNull T t2) {
        if (t == null) {
            throw new NullPointerException("element1 is marked @NonNull but is null");
        }
        if (t2 == null) {
            throw new NullPointerException("element2 is marked @NonNull but is null");
        }
        return coOccurrenceCounts.getCount(t, t2);
    }

    /** Rough byte estimate of the live map: 120 bytes ({@code 24 * 5}) per entry. */
    protected long getMemoryFootprint() {
        lock.readLock().lock();
        try {
            // widen before multiplying: the int product overflowed past ~17.9M entries
            return (long) coOccurrenceCounts.size() * 24 * 5;
        } finally {
            lock.readLock().unlock();
        }
    }

    /** Flush threshold: half of the configured memory budget. */
    protected long getMemoryThreshold() {
        return memory_threshold / 2;
    }

    /**
     * Scans the whole corpus with {@code workers} threads and writes the final counts to
     * {@code targetFile}.
     *
     * @throws RuntimeException if interrupted while waiting for the workers
     */
    public void fit() {
        shadowThread = new ShadowCopyThread();
        shadowThread.start();
        sequenceIterator.reset();
        List<CoOccurrencesCalculatorThread> calculators = new ArrayList<>();
        for (int i = 0; i < workers; i++) {
            CoOccurrencesCalculatorThread calculator = new CoOccurrencesCalculatorThread(i,
                    new FilteredSequenceIterator<>(new SynchronizedSequenceIterator<>(sequenceIterator), vocabCache),
                    processedSequences);
            calculators.add(calculator);
            calculator.start();
        }
        for (CoOccurrencesCalculatorThread calculator : calculators) {
            try {
                calculator.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                // do not leave the background flusher polling forever
                shadowThread.terminate();
                throw new RuntimeException(e);
            }
        }
        shadowThread.finish();
        logger.info("CoOccurrences map was built.");
    }

    /**
     * Streams the final counts back from {@code targetFile}. Each line is parsed as
     * {@code "<index1> <index2> <weight>"} and resolved through the vocabulary.
     *
     * @throws RuntimeException if the target file cannot be opened
     */
    public Iterator<Pair<Pair<T, T>, Double>> iterator() {
        try {
            final SynchronizedSentenceIterator lines = new SynchronizedSentenceIterator(
                    new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile)).setFetchSize(500000).build());
            return new Iterator<Pair<Pair<T, T>, Double>>() {
                @Override
                public boolean hasNext() {
                    return lines.hasNext();
                }

                @Override
                public Pair<Pair<T, T>, Double> next() {
                    String[] fields = lines.nextSentence().split(" ");
                    T first = vocabCache.elementAtIndex(Integer.parseInt(fields[0]));
                    T second = vocabCache.elementAtIndex(Integer.parseInt(fields[1]));
                    return new Pair<>(new Pair<>(first, second), Double.valueOf(fields[2]));
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
                }
            };
        } catch (Exception e) {
            logger.error("Target file was not found on last stage!");
            throw new RuntimeException(e);
        }
    }
}
