/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.glove;

import akka.actor.ActorSystem;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TfidfVectorizer;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.glove.CoOccurrences;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.text.invertedindex.LuceneInvertedIndex;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * GloVe word-vector model trained from word co-occurrence counts.
 *
 * <p>{@link #fit()} creates whatever components were not supplied (vocabulary, co-occurrence
 * table, {@link GloveWeightLookupTable}). It then runs {@code iterations} passes over the
 * co-occurring word pairs, handing mini-batches of pairs to {@code numWorkers} worker threads.
 * {@link #load(InputStream, InputStream)} restores a model from a space-separated text vector file
 * plus a bias array read with {@code Nd4j.read}.
 */
public class Glove extends WordVectorsImpl {
    private transient SentenceIterator sentenceIterator;
    private transient TextVectorizer textVectorizer;
    private transient TokenizerFactory tokenizerFactory;
    private double learningRate = 0.05;
    /** Stored configuration; {@link #fit()} does not currently forward it to the lookup table. */
    private double xMax = 0.75;
    private int windowSize = 15;
    private CoOccurrences coOccurrences;
    private boolean stem = false;
    /** Mini-batches of resolved word pairs, tagged with their iteration number, awaiting the workers. */
    protected Queue<Pair<Integer, List<Pair<VocabWord, VocabWord>>>> jobQueue = new LinkedBlockingDeque<Pair<Integer, List<Pair<VocabWord, VocabWord>>>>();
    private int batchSize = 1000;
    private int minWordFrequency = 5;
    private double maxCount = 100.0;
    public static final String UNK = "UNK";
    private int iterations = 5;
    private static final Logger log = LoggerFactory.getLogger(Glove.class);
    private boolean symmetric = true;
    private transient org.nd4j.linalg.api.rng.Random gen;
    private boolean shuffle = true;
    private transient org.nd4j.linalg.api.rng.Random shuffleRandom;
    private int numWorkers = Runtime.getRuntime().availableProcessors();
    /** Seed for the RNG that shuffles the co-occurrence pairs before each pass. */
    private long seed = 123L;
    /** Created lazily from {@link #seed} by {@link #pairShuffler()}. */
    private transient Random pairShuffler;

    /** Used by {@link #load(InputStream, InputStream)}; every field keeps its default. */
    private Glove() {
    }

    /**
     * Creates a fully configured model; normally invoked through {@link Builder#build()}.
     *
     * <p>{@code cache}, {@code textVectorizer}, {@code lookupTable} and {@code coOccurrences} may be
     * {@code null}, in which case {@link #fit()} builds them. {@code seed} seeds the shuffling of
     * co-occurrence pairs.
     */
    public Glove(VocabCache cache, SentenceIterator sentenceIterator, TextVectorizer textVectorizer, TokenizerFactory tokenizerFactory, GloveWeightLookupTable lookupTable, int layerSize, double learningRate, double xMax, int windowSize, CoOccurrences coOccurrences, List<String> stopWords, boolean stem, int batchSize, int minWordFrequency, double maxCount, int iterations, boolean symmetric, org.nd4j.linalg.api.rng.Random gen, boolean shuffle, long seed, int numWorkers) {
        this.numWorkers = numWorkers;
        this.gen = gen;
        this.vocab = cache;
        this.layerSize = layerSize;
        this.shuffle = shuffle;
        this.sentenceIterator = sentenceIterator;
        this.textVectorizer = textVectorizer;
        this.tokenizerFactory = tokenizerFactory;
        this.lookupTable = lookupTable;
        this.learningRate = learningRate;
        this.xMax = xMax;
        this.windowSize = windowSize;
        this.coOccurrences = coOccurrences;
        this.stopWords = stopWords;
        this.stem = stem;
        this.batchSize = batchSize;
        this.minWordFrequency = minWordFrequency;
        this.maxCount = maxCount;
        this.iterations = iterations;
        this.symmetric = symmetric;
        this.seed = seed;
        this.shuffleRandom = Nd4j.getRandom();
    }

    /**
     * Trains the model.
     *
     * <p>Missing components are created on demand, in this order:
     * <ul>
     * <li>an in-memory vocabulary, populated by a TF-IDF vectorizer, when no vocabulary was supplied;</li>
     * <li>the co-occurrence table, when none was supplied;</li>
     * <li>the weight lookup table, when none was supplied.</li>
     * </ul>
     * Weights are reset when the lookup table has no {@code syn0} yet. The method then runs
     * {@code iterations} passes through {@link #doIteration}.
     */
    public void fit() {
        boolean cacheFresh = false;
        if (this.vocab() == null) {
            cacheFresh = true;
            this.setVocab(new InMemoryLookupCache());
        }
        if (this.textVectorizer == null && cacheFresh) {
            LuceneInvertedIndex index = new LuceneInvertedIndex(this.vocab(), false, "glove-index");
            this.textVectorizer = new TfidfVectorizer.Builder()
                    .tokenize(this.tokenizerFactory)
                    .index(index)
                    .cache(this.vocab())
                    .iterate(this.sentenceIterator)
                    .minWords(this.minWordFrequency)
                    .stopWords(this.stopWords)
                    .stem(this.stem)
                    .build();
            this.textVectorizer.fit();
        }
        if (this.sentenceIterator != null) {
            this.sentenceIterator.reset();
        }
        if (this.coOccurrences == null) {
            this.coOccurrences = new CoOccurrences.Builder()
                    .cache(this.vocab())
                    .iterate(this.sentenceIterator)
                    .symmetric(this.symmetric)
                    .tokenizer(this.tokenizerFactory)
                    .windowSize(this.windowSize)
                    .build();
            this.coOccurrences.fit();
        }
        if (this.lookupTable == null) {
            // A vocabulary supplied without a vectorizer leaves textVectorizer null;
            // fall back to this model's own vocabulary instead of dereferencing null.
            VocabCache cache = this.textVectorizer != null ? this.textVectorizer.vocab() : this.vocab();
            this.lookupTable = new GloveWeightLookupTable.Builder()
                    .cache(cache)
                    .lr(this.learningRate)
                    .vectorLength(this.layerSize)
                    .maxCount(this.maxCount)
                    .build();
        }
        if (this.lookupTable().getSyn0() == null) {
            this.lookupTable().resetWeights();
        }
        // doIteration shuffles at the start of every pass, so no separate shuffle is needed here.
        List<Pair<String, String>> pairList = this.coOccurrences.coOccurrenceList();
        AtomicInteger countUp = new AtomicInteger(0);
        Counter<Integer> errorPerIteration = Util.parallelCounter();
        log.info("Processing # of co occurrences " + this.coOccurrences.numCoOccurrences());
        for (int i = 0; i < this.iterations; ++i) {
            // Workers exit only once this reaches zero. It must therefore equal the number of pairs
            // actually enqueued in this pass, or they spin forever / stop early.
            AtomicInteger processed = new AtomicInteger(pairList.size());
            this.doIteration(i, pairList, errorPerIteration, processed, countUp);
            log.info("Processed " + countUp.doubleValue() + " out of " + pairList.size() * this.iterations + " error was " + errorPerIteration.getCount(i));
        }
    }

    /**
     * Runs one training pass over {@code pairList}.
     *
     * <p>The pass proceeds as follows:
     * <ol>
     * <li>The pairs are shuffled (when enabled) and split into mini-batches of {@code batchSize}.</li>
     * <li>Each batch is resolved to {@link VocabWord}s and enqueued on {@link #jobQueue}.</li>
     * <li>{@code numWorkers} workers drain the queue and call {@code iterateSample} for every pair
     * whose words are known and whose co-occurrence count is positive.</li>
     * </ol>
     *
     * @param i                 iteration number; used as the key in {@code errorPerIteration}
     * @param pairList          co-occurring word pairs; shuffled in place when shuffling is enabled
     * @param errorPerIteration accumulates the values returned by {@code iterateSample}, keyed by iteration
     * @param processed         pairs still to be consumed; workers exit once it is zero and the queue is empty
     * @param countUp           running total of consumed pairs (shared across iterations)
     */
    public void doIteration(final int i, List<Pair<String, String>> pairList, final Counter<Integer> errorPerIteration, final AtomicInteger processed, final AtomicInteger countUp) {
        log.info("Iteration " + i);
        if (this.shuffle) {
            Collections.shuffle(pairList, this.pairShuffler());
        }
        List<List<Pair<String, String>>> miniBatches = Lists.partition(pairList, this.batchSize);
        ActorSystem actor = ActorSystem.create();
        try {
            Parallelization.iterateInParallel((Collection) miniBatches, (Parallelization.RunnableWithParams) new Parallelization.RunnableWithParams<List<Pair<String, String>>>() {

                public void run(List<Pair<String, String>> currentItem, Object[] args) {
                    List<Pair<VocabWord, VocabWord>> send = new ArrayList<Pair<VocabWord, VocabWord>>(currentItem.size());
                    for (Pair<String, String> next : currentItem) {
                        VocabWord first = Glove.this.vocab().wordFor(next.getFirst());
                        VocabWord second = Glove.this.vocab().wordFor(next.getSecond());
                        send.add(new Pair<VocabWord, VocabWord>(first, second));
                    }
                    Glove.this.jobQueue.add(new Pair<Integer, List<Pair<VocabWord, VocabWord>>>(i, send));
                }
            }, actor);
        } finally {
            // Release the actor system even if batch preparation fails.
            actor.shutdown();
        }
        Parallelization.runInParallel(this.numWorkers, new Runnable() {

            @Override
            public void run() {
                while (processed.get() > 0 || !Glove.this.jobQueue.isEmpty()) {
                    Pair<Integer, List<Pair<VocabWord, VocabWord>>> work = Glove.this.jobQueue.poll();
                    if (work == null) {
                        continue;
                    }
                    for (Pair<VocabWord, VocabWord> pair : work.getSecond()) {
                        Glove.this.trainPair(work.getFirst(), pair, errorPerIteration);
                        int done = countUp.incrementAndGet();
                        if (done % 10000 == 0) {
                            log.info("Processed " + done + " co occurrences");
                        }
                        processed.decrementAndGet();
                    }
                }
            }
        }, true);
    }

    /**
     * Applies one update for {@code pair} and records the returned value under {@code iteration}.
     * Pairs containing a word missing from the vocabulary, or with a non-positive co-occurrence
     * count, are skipped. A missing word must not be dereferenced: a worker dying on it would
     * leave {@code processed} above zero forever.
     */
    private void trainPair(Integer iteration, Pair<VocabWord, VocabWord> pair, Counter<Integer> errorPerIteration) {
        VocabWord w1 = pair.getFirst();
        VocabWord w2 = pair.getSecond();
        if (w1 == null || w2 == null) {
            return;
        }
        double weight = this.getCount(w1.getWord(), w2.getWord());
        if (weight <= 0.0) {
            return;
        }
        errorPerIteration.incrementCount(iteration, this.lookupTable().iterateSample(w1, w2, weight));
    }

    /** Returns the shuffling RNG, creating it from {@link #seed} on first use. */
    private Random pairShuffler() {
        if (this.pairShuffler == null) {
            this.pairShuffler = new Random(this.seed);
        }
        return this.pairShuffler;
    }

    /**
     * Loads a model from a text vector file and a bias array.
     *
     * <p>Each non-empty line of {@code is} is {@code word v1 v2 ... vn}, separated by single spaces.
     * The first non-empty line fixes the vector length {@code n}. Lines whose value count differs
     * from {@code n} are skipped.
     *
     * @param is     UTF-8 text with one word vector per line; closed by this method
     * @param biases stream read with {@code Nd4j.read}; not closed by this method
     * @return the loaded model
     * @throws IOException if reading fails or {@code is} contains no non-empty line
     */
    public static Glove load(InputStream is, InputStream biases) throws IOException {
        Glove glove = new Glove();
        if (glove.vocab() == null) {
            glove.setVocab(new InMemoryLookupCache());
        }
        Map<String, float[]> wordVectors = new HashMap<String, float[]>();
        LineIterator iter = IOUtils.lineIterator(is, "UTF-8");
        try {
            int count = 0;
            while (iter.hasNext()) {
                String line = iter.nextLine().trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] split = line.split(" ");
                String word = split[0];
                if (glove.lookupTable() == null) {
                    glove.lookupTable = new GloveWeightLookupTable.Builder().cache(glove.vocab()).vectorLength(split.length - 1).build();
                }
                if (word.isEmpty()) {
                    continue;
                }
                float[] vector = Glove.read(split, glove.lookupTable().getVectorLength());
                if (vector.length < 1) {
                    // Malformed line (wrong number of values) or zero-length vectors.
                    continue;
                }
                VocabWord vocabWord = new VocabWord(1.0, word);
                vocabWord.setIndex(count);
                glove.vocab().addToken(vocabWord);
                glove.vocab().addWordToIndex(count, word);
                glove.vocab().putVocabWord(word);
                wordVectors.put(word, vector);
                ++count;
            }
        } finally {
            iter.close();
        }
        if (glove.lookupTable() == null) {
            throw new IOException("No word vectors found in input stream");
        }
        glove.lookupTable().setSyn0(Glove.weights(glove, wordVectors));
        glove.lookupTable().setBias(Nd4j.read(biases));
        return glove;
    }

    /**
     * Builds the {@code syn0} matrix (one row per entry of {@code data}), placing each vector at its
     * word's vocabulary index. Vectors of the wrong length, and words whose index is outside
     * {@code [0, data.size())}, are left as zero rows.
     */
    private static INDArray weights(Glove glove, Map<String, float[]> data) {
        int vectorLength = glove.lookupTable().getVectorLength();
        INDArray ret = Nd4j.create(data.size(), vectorLength);
        for (Map.Entry<String, float[]> entry : data.entrySet()) {
            int index = glove.vocab().indexOf(entry.getKey());
            if (entry.getValue().length != vectorLength || index < 0 || index >= data.size()) {
                continue;
            }
            ret.putRow(index, Nd4j.create(Nd4j.createBuffer(entry.getValue())));
        }
        return ret;
    }

    /**
     * Parses {@code split[1..]} as floats.
     *
     * @return the parsed vector, or an empty array when the token count is not exactly {@code length}
     *         (which {@link #load} treats as "skip this line")
     */
    private static float[] read(String[] split, int length) {
        if (split.length - 1 != length) {
            return new float[0];
        }
        float[] ret = new float[length];
        for (int i = 1; i < split.length; ++i) {
            ret[i - 1] = Float.parseFloat(split[i]);
        }
        return ret;
    }

    /** Returns the co-occurrence count recorded for the ordered pair ({@code w1}, {@code w2}). */
    public double getCount(String w1, String w2) {
        return this.coOccurrences.getCoOCurreneCounts().getCount(w1, w2);
    }

    public CoOccurrences getCoOccurrences() {
        return this.coOccurrences;
    }

    public void setCoOccurrences(CoOccurrences coOccurrences) {
        this.coOccurrences = coOccurrences;
    }

    /** Returns the weight lookup table, cast to its GloVe-specific type. */
    @Override
    public GloveWeightLookupTable lookupTable() {
        return (GloveWeightLookupTable) this.lookupTable;
    }

    public void setLookupTable(GloveWeightLookupTable lookupTable) {
        this.lookupTable = lookupTable;
    }

    /** Fluent builder for {@link Glove}; every setter returns this builder. */
    public static class Builder {
        private VocabCache vocabCache;
        private SentenceIterator sentenceIterator;
        private TextVectorizer textVectorizer;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private GloveWeightLookupTable weightLookupTable;
        private int layerSize = 300;
        private double learningRate = 0.05;
        private double xMax = 0.75;
        private int windowSize = 5;
        private CoOccurrences coOccurrences;
        private List<String> stopWords = StopWords.getStopWords();
        private boolean stem = false;
        private int batchSize = 100;
        private int minWordFrequency = 5;
        private double maxCount = 100.0;
        private int iterations = 5;
        private boolean symmetric = true;
        private boolean shuffle = true;
        private long seed = 123L;
        private int numWorkers = Runtime.getRuntime().availableProcessors();
        private org.nd4j.linalg.api.rng.Random gen = Nd4j.getRandom();

        /** Number of worker threads that consume training batches. */
        public Builder numWorkers(int numWorkers) {
            this.numWorkers = numWorkers;
            return this;
        }

        /** Seed for shuffling the co-occurrence pairs between passes. */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        /** Whether to shuffle the co-occurrence pairs before each pass. */
        public Builder shuffle(boolean shuffle) {
            this.shuffle = shuffle;
            return this;
        }

        public Builder rng(org.nd4j.linalg.api.rng.Random gen) {
            this.gen = gen;
            return this;
        }

        /** Passed to the co-occurrence builder when co-occurrences are computed by {@link Glove#fit()}. */
        public Builder symmetric(boolean symmetric) {
            this.symmetric = symmetric;
            return this;
        }

        /** Number of training passes. */
        public Builder iterations(int iterations) {
            this.iterations = iterations;
            return this;
        }

        public Builder maxCount(double maxCount) {
            this.maxCount = maxCount;
            return this;
        }

        public Builder minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        public Builder cache(VocabCache vocabCache) {
            this.vocabCache = vocabCache;
            return this;
        }

        public Builder iterate(SentenceIterator sentenceIterator) {
            this.sentenceIterator = sentenceIterator;
            return this;
        }

        public Builder vectorizer(TextVectorizer textVectorizer) {
            this.textVectorizer = textVectorizer;
            return this;
        }

        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        public Builder weights(GloveWeightLookupTable weightLookupTable) {
            this.weightLookupTable = weightLookupTable;
            return this;
        }

        public Builder layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder xMax(double xMax) {
            this.xMax = xMax;
            return this;
        }

        public Builder windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Builder coOccurrences(CoOccurrences coOccurrences) {
            this.coOccurrences = coOccurrences;
            return this;
        }

        public Builder stopWords(List<String> stopWords) {
            this.stopWords = stopWords;
            return this;
        }

        public Builder stem(boolean stem) {
            this.stem = stem;
            return this;
        }

        /** Number of co-occurrence pairs per mini-batch handed to the workers. */
        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /** Creates the configured {@link Glove}; call {@link Glove#fit()} to train it. */
        public Glove build() {
            return new Glove(this.vocabCache, this.sentenceIterator, this.textVectorizer, this.tokenizerFactory, this.weightLookupTable, this.layerSize, this.learningRate, this.xMax, this.windowSize, this.coOccurrences, this.stopWords, this.stem, this.batchSize, this.minWordFrequency, this.maxCount, this.iterations, this.symmetric, this.gen, this.shuffle, this.seed, this.numWorkers);
        }
    }
}

