/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec;

import akka.actor.ActorSystem;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.Word2VecConfiguration;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.StreamLineIterator;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.UimaTokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Word2Vec
extends WordVectorsImpl {
    protected static final long serialVersionUID = -2367495638286018038L;
    // Configuration snapshot; replaced by the Builder's configuration in Builder.build().
    protected transient Word2VecConfiguration configuration = new Word2VecConfiguration();
    // Factory used to tokenize raw sentences before they are mapped to vocabulary words.
    protected transient TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    // Sentence source consumed by buildVocab() and fit().
    protected transient SentenceIterator sentenceIter;
    // Alternative document source; only checked for null in this class, never iterated here.
    protected transient DocumentIterator docIter;
    // Read by addWords() (via index().numDocuments()) when sub-sampling is enabled.
    protected transient TextVectorizer vectorizer;
    protected transient InvertedIndex invertedIndex;
    // Used by fillVocabulary() to register tokens and bump their counters.
    protected transient VocabularyHolder vocabularyHolder;
    protected transient RandomGenerator g;
    // Number of VectorCalculationsThread instances started per epoch in fit().
    protected transient int workers = Runtime.getRuntime().availableProcessors();
    protected int batchSize = 1000;
    // Sub-sampling parameter; addWords() only applies sub-sampling when this is > 0.
    protected double sample = 0.0;
    // Denominator of the learning-rate decay in doIteration().
    protected long totalWords = 1L;
    // Starting learning rate; decayed towards minLearningRate during training.
    protected AtomicDouble alpha = new AtomicDouble(0.025);
    // Context window size used by trainSentence() and skipGram().
    protected int window = 5;
    protected static final Logger log = LoggerFactory.getLogger(Word2Vec.class);
    // When true, setup() resets the lookup table weights.
    protected boolean shouldReset = true;
    // How many times each sentence is trained on per epoch (see VectorCalculationsThread.run()).
    protected int numIterations = 1;
    public static final String UNK = "UNK";
    protected long seed = 123L;
    protected boolean saveVocab = false;
    // Lower bound applied to the decayed learning rate.
    protected double minLearningRate = 0.01;
    protected int learningRateDecayWords = 10000;
    // NOTE(review): neither assigned nor read anywhere in this file.
    private double negative;
    // Number of passes over the sentence iterator performed by fit().
    private int epochs;
    protected boolean useAdaGrad = false;
    // When true, fit() rebuilds the vocabulary and resets the lookup table weights first.
    protected boolean resetModel = true;
    // Line counter shared between fit() and its worker threads; reset at the start of each epoch.
    final AtomicLong totalLines = new AtomicLong(0L);

    /**
     * Returns the text vectorizer currently attached to this model.
     *
     * @return the vectorizer, or {@code null} if none has been set
     */
    public TextVectorizer getVectorizer() {
        return this.vectorizer;
    }

    /**
     * Replaces the text vectorizer attached to this model.
     *
     * @param vectorizer the vectorizer to store; read by {@code addWords} when sub-sampling is enabled
     */
    public void setVectorizer(TextVectorizer vectorizer) {
        this.vectorizer = vectorizer;
    }

    /**
     * Registers every non-stop-word token in the vocabulary holder.
     *
     * <p>A token seen for the first time is added; a token already present has its
     * counter incremented. Either way the token counts towards the returned total.
     *
     * @param tokens tokens of one sentence
     * @return the number of tokens that were not filtered out as stop words
     */
    protected int fillVocabulary(List<String> tokens) {
        int processed = 0;
        for (String token : tokens) {
            boolean isStopWord = this.stopWords != null && this.stopWords.contains(token);
            if (isStopWord) {
                continue;
            }
            if (this.vocabularyHolder.containsWord(token)) {
                this.vocabularyHolder.incrementWordCounter(token);
            } else {
                this.vocabularyHolder.addWord(token);
            }
            ++processed;
        }
        return processed;
    }

    /**
     * Placeholder for building a dedicated vocabulary from the given iterator.
     *
     * <p>The method is not implemented: it resets the iterator, drains it and returns
     * {@code null}. Each sentence is consumed explicitly so that the loop terminates;
     * polling {@code hasNext()} without advancing would spin forever on any non-empty source.
     *
     * @param iterator sentence source; it is reset and fully consumed
     * @param minWord  currently unused
     * @return always {@code null}
     */
    public VocabCache fillSpecialVocabulary(SentenceIterator iterator, int minWord) {
        iterator.reset();
        while (iterator.hasNext()) {
            // Advance the iterator; the sentence itself is not used yet.
            iterator.nextSentence();
        }
        return null;
    }

    /**
     * Maps the tokens of a sentence to their vocabulary entries.
     *
     * <p>Stop words, {@code null} or empty tokens, and tokens unknown to the
     * vocabulary are dropped; the order of the remaining words is preserved.
     *
     * @param tokens tokens of one sentence
     * @return the vocabulary words for the tokens that survived filtering (possibly empty)
     */
    protected List<VocabWord> digitizeSentence(List<String> tokens) {
        ArrayList<VocabWord> digitized = new ArrayList<VocabWord>(tokens.size());
        for (String token : tokens) {
            if (this.stopWords != null && this.stopWords.contains(token)) {
                continue;
            }
            if (token == null || token.isEmpty()) {
                continue;
            }
            VocabWord known = this.vocab.wordFor(token);
            if (known != null) {
                digitized.add(known);
            }
        }
        return digitized;
    }

    /**
     * Trains the model on the configured sentence iterator.
     *
     * <p>If {@code resetModel} is set, the vocabulary is rebuilt and the lookup table
     * weights are reset first. Each epoch then starts one reader thread that tokenizes
     * and digitizes sentences into a shared queue, plus {@code workers} training threads
     * that drain the queue, and waits for all of them to finish.
     *
     * <p>If the calling thread is interrupted while waiting, the wait for that thread is
     * abandoned (as before), but the interrupt status is restored before this method
     * returns instead of being silently discarded.
     *
     * @throws IllegalStateException if neither a sentence nor a document iterator is set
     * @throws IOException declared for callers; not thrown directly by the code in this method
     */
    public void fit() throws IOException {
        if (this.sentenceIter == null && this.docIter == null) {
            throw new IllegalStateException("At least one iterator is needed for model fit()");
        }
        // NOTE(review): only sentenceIter is consumed below. A model holding just a docIter
        // passes the check above but the reader thread's constructor then calls reset() on a
        // null iterator - confirm whether docIter-only models are meant to be supported.
        LinkedBlockingQueue<List<VocabWord>> sentences = new LinkedBlockingQueue<List<VocabWord>>();
        if (this.resetModel) {
            log.info("Building matrices & resetting weights...");
            this.buildVocab();
            this.lookupTable.resetWeights(true);
        }
        // Total number of word visits over the whole run; drives the learning-rate decay.
        long totalWordsCount = this.vocab.totalWordOccurrences() * (long)this.numIterations * (long)this.epochs;
        log.info("Total number of words in vocab: [" + this.vocab.numWords() + "], word occurencies: [" + this.vocab.totalWordOccurrences() + "], buffed words count: [" + totalWordsCount + "], number of Epochs: [" + this.epochs + "],  number of Iterations:[" + this.numIterations + "]");
        long maxLines = this.totalLines.get();
        boolean interrupted = false;
        for (int epoch = 1; epoch <= this.epochs; ++epoch) {
            log.info("Starting async iterator...");
            this.totalLines.set(0L);
            AtomicLong wordsCounter = new AtomicLong(0L);
            AsyncIteratorDigitizer roller = new AsyncIteratorDigitizer(this.sentenceIter, sentences, this.totalLines);
            roller.start();
            log.info("Starting vectorization process...");
            VectorCalculationsThread[] threads = new VectorCalculationsThread[this.workers];
            for (int x = 0; x < this.workers; ++x) {
                threads[x] = new VectorCalculationsThread(x, maxLines, epoch, wordsCounter, totalWordsCount, this.totalLines, sentences, roller);
                threads[x].start();
            }
            try {
                roller.join();
            }
            catch (InterruptedException e) {
                interrupted = true;
                log.warn("Interrupted while waiting for the sentence reader thread", e);
            }
            for (int x = 0; x < this.workers; ++x) {
                try {
                    threads[x].join();
                }
                catch (InterruptedException e) {
                    interrupted = true;
                    log.warn("Interrupted while waiting for " + threads[x].getName(), e);
                }
            }
            log.info("Epoch: " + epoch + "; Lines vectorized so far: " + this.totalLines.get());
        }
        log.info("Vectorization accomplished.");
        if (interrupted) {
            // Restore the flag only now so the remaining joins above were not cut short by it.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Trains on a batch of sentences in parallel.
     *
     * <p>For every sentence the learning rate is decayed linearly with the number of words
     * processed so far (bounded below by {@code minLearningRate}), the sentence is trained,
     * and the shared word counter is advanced. Progress is logged at most about once per second.
     *
     * @param batch2        sentences to train on
     * @param numWordsSoFar shared counter of words processed; updated atomically
     * @param nextRandom    random state handed to {@link #trainSentence}
     * @param actorSystem   actor system passed through to {@code Parallelization.iterateInParallel}
     */
    private void doIteration(Collection<List<VocabWord>> batch2, final AtomicLong numWordsSoFar, final AtomicLong nextRandom, ActorSystem actorSystem) {
        final AtomicLong lastReported = new AtomicLong(System.currentTimeMillis());
        Parallelization.iterateInParallel(batch2, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<List<VocabWord>>(){

            public void run(List<VocabWord> sentence, Object[] args) {
                long wordsSoFar = numWordsSoFar.get();
                double alpha = Math.max(Word2Vec.this.minLearningRate, Word2Vec.this.alpha.get() * (1.0 - 1.0 * (double)wordsSoFar / (double)Word2Vec.this.totalWords));
                long now = System.currentTimeMillis();
                long last = lastReported.get();
                // compareAndSet lets exactly one of the concurrent callers emit the report.
                if (wordsSoFar > 0L && Math.abs(now - last) > 1000L && lastReported.compareAndSet(last, now)) {
                    log.info("Words so far " + wordsSoFar + " with alpha at " + alpha);
                }
                Word2Vec.this.trainSentence(sentence, nextRandom, alpha);
                // Atomic add: a separate get() followed by set() loses updates when callbacks run concurrently.
                numWordsSoFar.addAndGet((long)sentence.size());
            }
        }, (ActorSystem)actorSystem);
    }

    /**
     * Appends the words of a sentence to a mini batch, optionally sub-sampling them.
     *
     * <p>{@code null} entries are skipped. When {@code sample} is positive, a keep-score is
     * computed from the word frequency and the document count reported by the vectorizer's
     * index; the word is dropped if that score is below the value derived from the low
     * 16 bits of {@code nextRandom}. When {@code sample} is not positive every word is kept.
     *
     * @param sentence      words to consider
     * @param nextRandom    random state; only read, never advanced here
     * @param currMiniBatch destination list
     */
    protected void addWords(List<VocabWord> sentence, AtomicLong nextRandom, List<VocabWord> currMiniBatch) {
        for (VocabWord candidate : sentence) {
            if (candidate == null) {
                continue;
            }
            boolean keep = true;
            if (this.sample > 0.0) {
                double numDocs = this.vectorizer.index().numDocuments();
                double frequency = candidate.getWordFrequency();
                double scaled = this.sample * numDocs;
                double ran = (Math.sqrt(frequency / scaled) + 1.0) * scaled / frequency;
                double draw = (double)(nextRandom.get() & 0xFFFFL) / 65536.0;
                // Written as a negated '<' so a NaN score still keeps the word.
                keep = !(ran < draw);
            }
            if (keep) {
                currMiniBatch.add(candidate);
            }
        }
    }

    /**
     * Prepares the model: builds the Huffman tree over the vocabulary and, if
     * {@code shouldReset} is set, resets the lookup table weights.
     */
    public void setup() {
        log.info("Building binary tree");
        buildBinaryTree();
        log.info("Resetting weights");
        if (!this.shouldReset) {
            return;
        }
        resetWeights();
    }

    /**
     * Builds the vocabulary from the sentence iterator into the model's vocab cache,
     * using the configured tokenizer factory, stop words and minimum word frequency.
     *
     * @return always {@code false}
     * @throws IllegalStateException if neither a sentence nor a document iterator is set
     */
    public boolean buildVocab() {
        boolean noSource = this.sentenceIter == null && this.docIter == null;
        if (noSource) {
            throw new IllegalStateException("At least one iterator is needed for model fit()");
        }
        VocabConstructor constructor = new VocabConstructor.Builder()
                .addSource(this.sentenceIter, this.minWordFrequency)
                .setTokenizerFactory(this.tokenizerFactory)
                .setStopWords(this.stopWords)
                .setTargetVocabCache(this.vocab)
                .build();
        constructor.buildJointVocabulary(false, true);
        return false;
    }

    /**
     * Trains the model on one digitized sentence using skip-gram.
     *
     * <p>For every position the random state is advanced and a window shrink offset
     * {@code b} is derived from it and passed to {@link #skipGram}.
     *
     * @param sentence   digitized sentence; {@code null} or empty sentences are ignored
     * @param nextRandom random state, advanced once per word
     * @param alpha      learning rate to apply
     */
    public void trainSentence(List<VocabWord> sentence, AtomicLong nextRandom, double alpha) {
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        for (int i = 0; i < sentence.size(); ++i) {
            nextRandom.set(nextRandom.get() * 25214903917L + 11L);
            int b = (int)nextRandom.get() % this.window;
            if (b < 0) {
                // Java's % keeps the sign of the dividend. A negative offset would make
                // skipGram scan more than 'window' positions on each side, so fold it
                // back into [0, window).
                b += this.window;
            }
            this.skipGram(i, sentence, b, nextRandom, alpha);
        }
    }

    /**
     * Runs one skip-gram step for the word at position {@code i}.
     *
     * <p>Every in-bounds position {@code i - window + a}, for {@code a} in
     * {@code [b, 2 * window + 1 - b)} and {@code a != window} (the centre word itself),
     * is paired with the centre word and passed to {@link #iterate}.
     *
     * @param i          index of the centre word in {@code sentence}
     * @param sentence   digitized sentence; {@code null} or empty sentences are ignored
     * @param b          window shrink offset
     * @param nextRandom random state forwarded to {@link #iterate}
     * @param alpha      learning rate
     */
    public void skipGram(int i, List<VocabWord> sentence, int b, AtomicLong nextRandom, double alpha) {
        // Guard before indexing: checking emptiness after sentence.get(i) can never help.
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        VocabWord word = sentence.get(i);
        if (word == null) {
            return;
        }
        int end = this.window * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == this.window) {
                continue;
            }
            int c = i - this.window + a;
            if (c < 0 || c >= sentence.size()) {
                continue;
            }
            VocabWord lastWord = sentence.get(c);
            this.iterate(word, lastWord, nextRandom, alpha);
        }
    }

    /**
     * Performs one training update for a pair of words by delegating to the lookup table.
     *
     * @param w1         first word of the pair (the centre word when called from {@code skipGram})
     * @param w2         second word of the pair (the context word when called from {@code skipGram})
     * @param nextRandom random state forwarded to the lookup table
     * @param alpha      learning rate
     */
    public void iterate(VocabWord w1, VocabWord w2, AtomicLong nextRandom, double alpha) {
        this.lookupTable.iterateSample(w1, w2, nextRandom, alpha);
    }

    /**
     * Builds a Huffman tree over all words currently in the vocabulary.
     */
    protected void buildBinaryTree() {
        log.info("Constructing priority queue");
        new Huffman(this.vocab().vocabWords()).build();
        log.info("Built tree");
    }

    /**
     * Resets the weights of the lookup table by delegating to {@code lookupTable.resetWeights()}.
     */
    protected void resetWeights() {
        this.lookupTable.resetWeights();
    }

    /**
     * Loads the default stop-word list from {@code StopWords.getStopWords()} unless a
     * stop-word list is already set.
     */
    protected void readStopWords() {
        if (this.stopWords == null) {
            this.stopWords = StopWords.getStopWords();
        }
    }

    /**
     * Replaces the sentence iterator and clears {@code shouldReset}, so a subsequent
     * {@link #setup()} will not reset the weights unless {@link #resetWeightsOnSetup()} is called.
     *
     * @param sentenceIter the new sentence source
     */
    public void setSentenceIter(SentenceIterator sentenceIter) {
        this.sentenceIter = sentenceIter;
        this.shouldReset = false;
    }

    /**
     * Requests that the next {@link #setup()} call resets the lookup table weights.
     */
    public void resetWeightsOnSetup() {
        this.shouldReset = true;
    }

    /**
     * Returns the context window size.
     *
     * @return the window size used by {@code trainSentence} and {@code skipGram}
     */
    public int getWindow() {
        return this.window;
    }

    /**
     * Returns the stop-word list used to filter tokens.
     *
     * @return the stop words; may be {@code null}, in which case no stop-word filtering is applied
     */
    public List<String> getStopWords() {
        return this.stopWords;
    }

    /**
     * Returns the sentence iterator used for vocabulary building and training.
     * The call is synchronized on this model instance.
     *
     * @return the sentence iterator, or {@code null} if none has been set
     */
    public synchronized SentenceIterator getSentenceIter() {
        return this.sentenceIter;
    }

    /**
     * Returns the tokenizer factory used to split sentences into tokens.
     *
     * @return the tokenizer factory
     */
    public TokenizerFactory getTokenizerFactory() {
        return this.tokenizerFactory;
    }

    /**
     * Replaces the tokenizer factory used to split sentences into tokens.
     *
     * @param tokenizerFactory the factory to use from now on
     */
    public void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;
    }

    /**
     * Returns the configuration object associated with this model.
     *
     * @return the configuration; for models created through the Builder this is the
     *         Builder's configuration instance
     */
    public Word2VecConfiguration getConfiguration() {
        return this.configuration;
    }

    private class VectorCalculationsThread
    extends Thread
    implements Runnable {
        private final int threadId;
        private final long linesLimit;
        private final int epochNumber;
        private final AtomicLong wordsCounter;
        private final long totalWordsCount;
        private final AtomicLong totalLines;
        private final LinkedBlockingQueue<List<VocabWord>> sentences;
        private final AsyncIteratorDigitizer digitizer;

        public VectorCalculationsThread(int threadId, long linesLimit, int epoch, AtomicLong wordsCounter, long totalWordsCount, AtomicLong linesCounter, LinkedBlockingQueue<List<VocabWord>> buffer, AsyncIteratorDigitizer digitizer) {
            this.threadId = threadId;
            this.linesLimit = linesLimit;
            this.epochNumber = epoch;
            this.wordsCounter = wordsCounter;
            this.totalWordsCount = totalWordsCount;
            this.totalLines = linesCounter;
            this.sentences = buffer;
            this.digitizer = digitizer;
            this.setName("VectorCalculationsThread " + this.threadId);
        }

        @Override
        public void run() {
            AtomicLong nextRandom = new AtomicLong(5L);
            while (this.digitizer.hasMoreLines() || this.sentences.size() > 0) {
                try {
                    List<VocabWord> sentence = this.sentences.poll(2L, TimeUnit.SECONDS);
                    double alpha = 0.025;
                    if (sentence == null || sentence.isEmpty()) {
                        log.warn("sentence is null");
                        continue;
                    }
                    for (int i = 0; i < Word2Vec.this.numIterations; ++i) {
                        alpha = Math.max(Word2Vec.this.minLearningRate, Word2Vec.this.alpha.get() * (1.0 - 1.0 * (double)this.wordsCounter.get() / (double)this.totalWordsCount));
                        Word2Vec.this.trainSentence(sentence, nextRandom, alpha);
                        this.wordsCounter.addAndGet(sentence.size());
                    }
                    this.totalLines.incrementAndGet();
                    if (this.totalLines.get() % 10000L != 0L) continue;
                    log.info("Epoch: " + this.epochNumber + "; Words vectorized so far: " + this.wordsCounter.get() + ";  Lines vectorized so far: " + this.totalLines.get() + "; learningRate: " + alpha);
                }
                catch (Exception e) {
                    e.printStackTrace();
                    throw new RuntimeException(e);
                }
            }
        }
    }

    /**
     * Reader thread that pulls sentences from the iterator, tokenizes and digitizes them,
     * and places the non-empty results on a shared queue for the training workers.
     *
     * <p>The queue is refilled in batches of up to {@code limitUpper} lines whenever it
     * holds fewer than {@code limitLower} entries; otherwise the thread sleeps briefly.
     */
    private class AsyncIteratorDigitizer
    extends Thread
    implements Runnable {
        private final SentenceIterator iterator;
        private final LinkedBlockingQueue<List<VocabWord>> buffer;
        private final AtomicLong linesCounter;
        // Maximum number of lines read per refill.
        private final int limitUpper = 10000;
        // Refill is triggered when the queue holds fewer entries than this.
        private final int limitLower = 5000;
        private final AtomicBoolean isRunning = new AtomicBoolean(false);

        /**
         * Creates the reader and resets the iterator to its beginning.
         *
         * @param iterator     sentence source
         * @param buffer       queue that receives digitized sentences
         * @param linesCounter stored but not used by this class
         */
        public AsyncIteratorDigitizer(SentenceIterator iterator, LinkedBlockingQueue<List<VocabWord>> buffer, AtomicLong linesCounter) {
            this.iterator = iterator;
            this.buffer = buffer;
            this.linesCounter = linesCounter;
            this.setName("AsyncIteratorReader thread");
            this.iterator.reset();
        }

        /**
         * Marks the reader as running before the thread is scheduled, so that consumers
         * started right after this call cannot see "not running, queue empty" and quit
         * before the first line has been produced.
         */
        @Override
        public void start() {
            this.isRunning.set(true);
            super.start();
        }

        /**
         * Reads the iterator to exhaustion. The running flag is cleared in a
         * {@code finally} block so consumers are released even if reading or tokenizing fails.
         */
        @Override
        public void run() {
            this.isRunning.set(true);
            try {
                while (this.iterator.hasNext()) {
                    if (this.buffer.size() < this.limitLower) {
                        int linesLoaded = 0;
                        while (linesLoaded < this.limitUpper && this.iterator.hasNext()) {
                            String sentence = this.iterator.nextSentence();
                            Tokenizer tokenizer = Word2Vec.this.tokenizerFactory.create(sentence);
                            List<String> tokens = tokenizer.getTokens();
                            List<VocabWord> list = Word2Vec.this.digitizeSentence(tokens);
                            if (list != null && !list.isEmpty()) {
                                this.buffer.add(list);
                            }
                            // Count each line exactly once so a refill really reads up to limitUpper lines.
                            ++linesLoaded;
                        }
                        continue;
                    }
                    try {
                        Thread.sleep(50L);
                    }
                    catch (InterruptedException e) {
                        // Treat interruption as a request to stop reading.
                        Thread.currentThread().interrupt();
                        log.warn("Sentence reader interrupted; stopping before the iterator is exhausted", e);
                        return;
                    }
                }
            }
            finally {
                this.isRunning.set(false);
            }
        }

        /**
         * @return {@code true} while the queue is non-empty or the reader is still running
         */
        public boolean hasMoreLines() {
            return this.buffer.size() > 0 || this.isRunning.get();
        }
    }

    /**
     * Fluent builder for {@link Word2Vec} models.
     *
     * <p>Every setter returns this builder. {@link #build()} copies the collected settings
     * into a new model and into the builder's {@link Word2VecConfiguration}, creating a
     * default tokenizer factory, vocabulary cache and lookup table where none was supplied.
     */
    public static class Builder {
        protected int minWordFrequency = 1;
        protected int layerSize = 50;
        protected SentenceIterator iter;
        protected List<String> stopWords = new ArrayList<String>();
        protected int window = 5;
        protected TokenizerFactory tokenizerFactory;
        protected VocabCache vocabCache;
        // NOTE(review): never assigned by any method of this builder.
        protected DocumentIterator docIter;
        protected double lr = 0.025;
        protected int iterations = 1;
        protected long seed = 123L;
        protected boolean saveVocab = false;
        protected int batchSize = 1000;
        protected int learningRateDecayWords = 10000;
        protected boolean useAdaGrad = false;
        protected TextVectorizer textVectorizer;
        protected double minLearningRate = 0.01;
        protected double negative = 0.0;
        protected double sampling = 1.0E-5;
        protected int workers = Runtime.getRuntime().availableProcessors();
        protected InvertedIndex index;
        protected WeightLookupTable lookupTable;
        protected boolean hugeModelExpected = false;
        private Word2VecConfiguration configuration = new Word2VecConfiguration();
        private boolean resetModel = true;
        private int numEpochs = 1;

        /**
         * Supplies the lookup table to use instead of the default in-memory one.
         *
         * @param lookupTable lookup table; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code lookupTable} is {@code null}
         */
        public Builder lookupTable(@NonNull WeightLookupTable lookupTable) {
            if (lookupTable == null) {
                throw new NullPointerException("lookupTable");
            }
            this.lookupTable = lookupTable;
            return this;
        }

        /**
         * Creates a builder with default settings.
         */
        public Builder() {
        }

        /**
         * Creates a builder pre-populated from an existing configuration. The same
         * configuration instance is kept and updated again by {@link #build()}.
         *
         * @param conf configuration to read the initial settings from; must not be {@code null}
         * @throws NullPointerException if {@code conf} is {@code null}
         */
        public Builder(@NonNull Word2VecConfiguration conf) {
            if (conf == null) {
                throw new NullPointerException("conf");
            }
            this.iterations = conf.getIterations();
            this.hugeModelExpected = conf.isHugeModelExpected();
            this.useAdaGrad = conf.isUseAdaGrad();
            this.minWordFrequency = conf.getMinWordFrequency();
            this.lr = conf.getLearningRate();
            this.learningRateDecayWords = conf.getLearningRateDecayWords();
            this.negative = conf.getNegative();
            this.sampling = conf.getSampling();
            this.minLearningRate = conf.getMinLearningRate();
            this.window = conf.getWindow();
            this.seed = conf.getSeed();
            this.layerSize = conf.getLayersSize();
            this.numEpochs = conf.getEpochs();
            this.configuration = conf;
        }

        /**
         * Sets the inverted index copied into the model.
         *
         * @param index inverted index
         * @return this builder
         */
        @Deprecated
        public Builder index(InvertedIndex index) {
            this.index = index;
            return this;
        }

        /**
         * Sets the number of training threads started per epoch.
         *
         * @param workers number of worker threads
         * @return this builder
         */
        public Builder workers(int workers) {
            this.workers = workers;
            return this;
        }

        /**
         * Sets the sub-sampling parameter copied into the model's {@code sample} field.
         *
         * @param sample sub-sampling value
         * @return this builder
         */
        public Builder sampling(double sample) {
            this.sampling = sample;
            return this;
        }

        /**
         * Sets the negative-sampling value handed to the default lookup table.
         *
         * @param negative negative-sampling value
         * @return this builder
         */
        public Builder negativeSample(double negative) {
            this.negative = negative;
            return this;
        }

        /**
         * Sets the lower bound for the decayed learning rate.
         *
         * @param minLearningRate minimum learning rate
         * @return this builder
         */
        public Builder minLearningRate(double minLearningRate) {
            this.minLearningRate = minLearningRate;
            return this;
        }

        /**
         * Enables or disables AdaGrad; forwarded to the default lookup table.
         *
         * @param useAdaGrad whether to use AdaGrad
         * @return this builder
         */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /**
         * Sets the text vectorizer copied into the model.
         *
         * @param textVectorizer vectorizer
         * @return this builder
         */
        @Deprecated
        public Builder vectorizer(TextVectorizer textVectorizer) {
            this.textVectorizer = textVectorizer;
            return this;
        }

        /**
         * Sets the learning-rate decay word count.
         *
         * @param learningRateDecayWords decay word count
         * @return this builder
         */
        public Builder learningRateDecayWords(int learningRateDecayWords) {
            this.learningRateDecayWords = learningRateDecayWords;
            return this;
        }

        /**
         * Sets the batch size copied into the model and the configuration.
         *
         * @param batchSize batch size
         * @return this builder
         */
        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /**
         * Sets the {@code saveVocab} flag copied into the model.
         *
         * @param saveVocab flag value
         * @return this builder
         */
        public Builder saveVocab(boolean saveVocab) {
            this.saveVocab = saveVocab;
            return this;
        }

        /**
         * Sets the random seed copied into the model and the configuration.
         *
         * @param seed seed value
         * @return this builder
         */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * Sets how many times each sentence is trained on per epoch.
         *
         * @param iterations iterations per sentence
         * @return this builder
         */
        public Builder iterations(int iterations) {
            this.iterations = iterations;
            return this;
        }

        /**
         * Sets the starting learning rate.
         *
         * @param lr learning rate
         * @return this builder
         */
        public Builder learningRate(double lr) {
            this.lr = lr;
            return this;
        }

        /**
         * Chooses whether {@code fit()} rebuilds the vocabulary and resets the weights first.
         *
         * @param reset {@code true} to rebuild and reset
         * @return this builder
         */
        public Builder resetModel(boolean reset) {
            this.resetModel = reset;
            return this;
        }

        /**
         * Uses a document iterator as the training source by wrapping it in a
         * {@code StreamLineIterator} with a fetch size of 100.
         *
         * @param iter document source; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code iter} is {@code null}
         */
        public Builder iterate(@NonNull DocumentIterator iter) {
            if (iter == null) {
                throw new NullPointerException("iter");
            }
            this.iter = new StreamLineIterator.Builder(iter).setFetchSize(100).build();
            return this;
        }

        /**
         * Supplies the vocabulary cache to use instead of the default in-memory one.
         *
         * @param cache vocabulary cache; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code cache} is {@code null}
         */
        public Builder vocabCache(@NonNull VocabCache cache) {
            if (cache == null) {
                throw new NullPointerException("cache");
            }
            this.vocabCache = cache;
            return this;
        }

        /**
         * Sets the minimum word frequency used when building the vocabulary.
         *
         * @param minWordFrequency minimum frequency
         * @return this builder
         */
        public Builder minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        /**
         * Supplies the tokenizer factory to use instead of the default one.
         *
         * @param tokenizerFactory tokenizer factory; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code tokenizerFactory} is {@code null}
         */
        public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            if (tokenizerFactory == null) {
                throw new NullPointerException("tokenizerFactory");
            }
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets the vector length used for the default lookup table.
         *
         * @param layerSize vector length
         * @return this builder
         */
        public Builder layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        /**
         * Sets the stop-word list copied into the model.
         *
         * @param stopWords stop words
         * @return this builder
         */
        public Builder stopWords(List<String> stopWords) {
            this.stopWords = stopWords;
            return this;
        }

        /**
         * Sets the context window size.
         *
         * @param window window size
         * @return this builder
         */
        public Builder windowSize(int window) {
            this.window = window;
            return this;
        }

        /**
         * Uses the given sentence iterator as the training source.
         *
         * @param iter sentence source; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code iter} is {@code null}
         */
        public Builder iterate(@NonNull SentenceIterator iter) {
            if (iter == null) {
                throw new NullPointerException("iter");
            }
            this.iter = iter;
            return this;
        }

        /**
         * Sets the huge-model hint forwarded to the vocabulary holder and the configuration.
         *
         * @param reallyExpected hint value
         * @return this builder
         */
        public Builder hugeModelExpected(boolean reallyExpected) {
            this.hugeModelExpected = reallyExpected;
            return this;
        }

        /**
         * Sets the number of passes {@code fit()} makes over the sentence iterator.
         *
         * @param numEpochs number of epochs
         * @return this builder
         */
        public Builder epochs(int numEpochs) {
            this.numEpochs = numEpochs;
            return this;
        }

        /**
         * Creates the model.
         *
         * <p>Missing collaborators are created on demand and remembered in this builder:
         * a {@code UimaTokenizerFactory}, an {@code InMemoryLookupCache} and an
         * {@code InMemoryLookupTable} backed by that cache. All settings, including
         * {@code learningRateDecayWords} and {@code negative}, are copied into the model
         * and written back into the builder's configuration.
         *
         * @return the configured model
         * @throws RuntimeException if the default tokenizer factory cannot be created
         */
        public Word2Vec build() {
            Word2Vec ret = new Word2Vec();
            ret.alpha.set(this.lr);
            ret.sentenceIter = this.iter;
            ret.window = this.window;
            ret.useAdaGrad = this.useAdaGrad;
            ret.minLearningRate = this.minLearningRate;
            // These two were previously left at the model's defaults regardless of the builder.
            ret.learningRateDecayWords = this.learningRateDecayWords;
            ret.negative = this.negative;
            ret.vectorizer = this.textVectorizer;
            ret.stopWords = this.stopWords;
            ret.minWordFrequency = this.minWordFrequency;
            ret.numIterations = this.iterations;
            ret.seed = this.seed;
            ret.saveVocab = this.saveVocab;
            ret.batchSize = this.batchSize;
            ret.sample = this.sampling;
            ret.workers = this.workers;
            ret.invertedIndex = this.index;
            ret.epochs = this.numEpochs;
            ret.resetModel = this.resetModel;
            try {
                if (this.tokenizerFactory == null) {
                    this.tokenizerFactory = new UimaTokenizerFactory();
                }
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            if (this.vocabCache == null) {
                this.vocabCache = new InMemoryLookupCache();
            }
            ret.setVocab(this.vocabCache);
            if (this.lookupTable == null) {
                this.lookupTable = new InMemoryLookupTable.Builder().negative(this.negative).useAdaGrad(this.useAdaGrad).lr(this.lr).cache(this.vocabCache).vectorLength(this.layerSize).build();
            }
            ret.lookupTable = this.lookupTable;
            ret.tokenizerFactory = this.tokenizerFactory;
            // vocabCache is guaranteed non-null here, so the holder always wraps it.
            ret.vocabularyHolder = new VocabularyHolder.Builder().externalCache(this.vocabCache).hugeModelExpected(this.hugeModelExpected).minWordFrequency(this.minWordFrequency).scavengerActivationThreshold(this.configuration.getScavengerActivationThreshold()).scavengerRetentionDelay(this.configuration.getScavengerRetentionDelay()).build();
            this.configuration.setLearningRate(this.lr);
            this.configuration.setLayersSize(this.layerSize);
            this.configuration.setHugeModelExpected(this.hugeModelExpected);
            this.configuration.setWindow(this.window);
            this.configuration.setMinWordFrequency(this.minWordFrequency);
            this.configuration.setIterations(this.iterations);
            this.configuration.setSeed(this.seed);
            this.configuration.setBatchSize(this.batchSize);
            this.configuration.setLearningRateDecayWords(this.learningRateDecayWords);
            this.configuration.setMinLearningRate(this.minLearningRate);
            this.configuration.setSampling(this.sampling);
            this.configuration.setUseAdaGrad(this.useAdaGrad);
            this.configuration.setNegative(this.negative);
            this.configuration.setEpochs(this.numEpochs);
            ret.configuration = this.configuration;
            return ret;
        }
    }
}

