/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.wordvectors;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.SetUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class WordVectorsImpl
implements WordVectors {
    /** Minimum word frequency setting; defaults to 5 and is exposed through {@link #getMinWordFrequency()}. */
    protected int minWordFrequency = 5;
    /** Table used to look up word vectors; declared transient. */
    protected transient WeightLookupTable lookupTable;
    /** Vocabulary used to map between words and indices; declared transient. */
    protected transient VocabCache vocab;
    /** Layer size setting; defaults to 100 and is exposed through {@link #getLayerSize()}. */
    protected int layerSize = 100;
    /** Token used as the lookup key for words that are not in the vocabulary. */
    public static final String UNK = "UNK";
    /** Stop word list, initialised from {@code StopWords.getStopWords()}. */
    protected List<String> stopWords = StopWords.getStopWords();

    /**
     * Tells whether the vocabulary knows the given word.
     *
     * @param word the word to look up
     * @return {@code true} when the vocabulary reports a non-negative index for {@code word}
     */
    @Override
    public boolean hasWord(String word) {
        int index = vocab().indexOf(word);
        return index >= 0;
    }

    /**
     * Finds the words closest to the sum of the positive word vectors minus the
     * negative word vectors.
     *
     * <p>When the lookup table is an {@link InMemoryLookupTable}, every vocabulary row is
     * scored against the summed vector and the ranking produced by
     * {@code Nd4j.sortWithIndices} is walked in order. Query words, {@code null} entries,
     * {@code "UNK"} and {@code "STOP"} are skipped. The walk continues until {@code top}
     * words were collected or the ranking is exhausted; the previous implementation
     * stopped as soon as its look-ahead window reached the end of the ranking and could
     * therefore drop valid candidates.
     *
     * <p>For any other lookup table the cosine similarity to every vocabulary word is
     * accumulated in a {@link Counter} and its top {@code top} keys are returned.
     *
     * @param positive words whose vectors are added
     * @param negative words whose vectors are subtracted
     * @param top maximum number of words to return
     * @return the nearest words; at most {@code top} entries on the in-memory path
     */
    @Override
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        // Go through the accessors consistently so that overriding subclasses are honoured.
        WeightLookupTable table = this.lookupTable();
        VocabCache cache = this.vocab();
        INDArray words = Nd4j.create((int) table.layerSize());
        Set<String> union = SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative));
        for (String s : positive) {
            words.addi(table.vector(s));
        }
        for (String s : negative) {
            words.addi(table.vector(s).mul(-1));
        }
        if (table instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) table).getSyn0();
            INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
            INDArray distances = syn0.mulRowVector(weights).sum(1);
            INDArray sort = Nd4j.sortWithIndices(distances, 0, false)[0];
            int limit = Math.min(top, sort.length());
            List<String> ret = new ArrayList<String>();
            for (int i = 0; i < sort.length() && ret.size() < limit; ++i) {
                String candidate = cache.wordAtIndex(sort.getInt(i));
                boolean skip = candidate == null
                        || union.contains(candidate)
                        || candidate.equals(UNK)
                        || candidate.equals("STOP");
                if (!skip) {
                    ret.add(candidate);
                }
            }
            return ret;
        }
        Counter<String> distances = new Counter<String>();
        for (String s : cache.words()) {
            INDArray otherVec = this.getWordVectorMatrix(s);
            double sim = Transforms.cosineSim(words, otherVec);
            distances.incrementCount(s, sim);
        }
        distances.keepTopNKeys(top);
        return distances.keySet();
    }

    /**
     * Finds the words closest to an already computed query vector.
     *
     * <p>When the lookup table is an {@link InMemoryLookupTable}, every vocabulary row is
     * scored against {@code words} and the ranking produced by
     * {@code Nd4j.sortWithIndices} is walked in order, skipping {@code null} entries,
     * {@code "UNK"} and {@code "STOP"}. The walk continues until {@code top} words were
     * collected or the ranking is exhausted; the previous implementation stopped as soon
     * as its look-ahead window reached the end of the ranking and could therefore drop
     * valid candidates.
     *
     * <p>For any other lookup table the cosine similarity to every vocabulary word is
     * accumulated in a {@link Counter} and its top {@code top} keys are returned.
     *
     * @param words the query vector
     * @param top maximum number of words to return
     * @return the nearest words; at most {@code top} entries on the in-memory path
     */
    @Override
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        WeightLookupTable table = this.lookupTable();
        VocabCache cache = this.vocab();
        if (table instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) table).getSyn0();
            INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
            INDArray distances = syn0.mulRowVector(weights).sum(1);
            INDArray sort = Nd4j.sortWithIndices(distances, 0, false)[0];
            int limit = Math.min(top, sort.length());
            List<String> ret = new ArrayList<String>();
            for (int i = 0; i < sort.length() && ret.size() < limit; ++i) {
                String candidate = cache.wordAtIndex(sort.getInt(i));
                boolean skip = candidate == null || candidate.equals(UNK) || candidate.equals("STOP");
                if (!skip) {
                    ret.add(candidate);
                }
            }
            return ret;
        }
        Counter<String> distances = new Counter<String>();
        for (String s : cache.words()) {
            INDArray otherVec = this.getWordVectorMatrix(s);
            double sim = Transforms.cosineSim(words, otherVec);
            distances.incrementCount(s, sim);
        }
        distances.keepTopNKeys(top);
        return distances.keySet();
    }

    /**
     * Finds the words closest to an already computed query vector, scoring each
     * vocabulary row by the mean (rather than the sum) of its weighted components.
     *
     * <p>When the lookup table is an {@link InMemoryLookupTable}, the ranking produced by
     * {@code Nd4j.sortWithIndices} is walked in order, skipping {@code null} entries,
     * {@code "UNK"} and {@code "STOP"}. The walk continues until {@code top} words were
     * collected or the ranking is exhausted; the previous implementation stopped as soon
     * as its look-ahead window reached the end of the ranking and could therefore drop
     * valid candidates.
     *
     * <p>For any other lookup table the cosine similarity to every vocabulary word is
     * accumulated in a {@link Counter} and its top {@code top} keys are returned.
     *
     * @param words the query vector
     * @param top maximum number of words to return
     * @return the nearest words; at most {@code top} entries on the in-memory path
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        WeightLookupTable table = this.lookupTable();
        VocabCache cache = this.vocab();
        if (table instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) table).getSyn0();
            INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
            INDArray distances = syn0.mulRowVector(weights).mean(1);
            INDArray sort = Nd4j.sortWithIndices(distances, 0, false)[0];
            int limit = Math.min(top, sort.length());
            List<String> ret = new ArrayList<String>();
            for (int i = 0; i < sort.length() && ret.size() < limit; ++i) {
                // Same (row 0, column i) addressing as before.
                String candidate = cache.wordAtIndex(sort.getInt(0, i));
                boolean skip = candidate == null || candidate.equals(UNK) || candidate.equals("STOP");
                if (!skip) {
                    ret.add(candidate);
                }
            }
            return ret;
        }
        Counter<String> distances = new Counter<String>();
        for (String s : cache.words()) {
            INDArray otherVec = this.getWordVectorMatrix(s);
            double sim = Transforms.cosineSim(words, otherVec);
            distances.incrementCount(s, sim);
        }
        distances.keepTopNKeys(top);
        return distances.keySet();
    }

    /**
     * Finds the {@code n} words closest to the given word.
     *
     * <p>Fixes over the previous implementation:
     * <ul>
     *   <li>a missing vector for {@code word} is detected before it is normalized or
     *       multiplied, and yields an empty list on both code paths;</li>
     *   <li>a word without a vocabulary entry no longer causes a
     *       {@code NullPointerException} when its index is read;</li>
     *   <li>the ranking is walked until {@code n} words were collected or it is
     *       exhausted, instead of inspecting exactly {@code n + 1} positions, which
     *       could index past the end of the ranking and could return fewer or more
     *       than {@code n} words.</li>
     * </ul>
     *
     * @param word the query word; it is never part of the result
     * @param n maximum number of words to return
     * @return the nearest words, or an empty list when {@code word} has no vector
     */
    @Override
    public Collection<String> wordsNearestSum(String word, int n) {
        INDArray raw = this.getWordVectorMatrix(word);
        if (raw == null) {
            return new ArrayList<String>();
        }
        INDArray vec = Transforms.unitVec(raw);
        WeightLookupTable table = this.lookupTable();
        VocabCache cache = this.vocab();
        if (table instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) table).getSyn0();
            INDArray weights = syn0.norm2(0).rdivi(1).muli(vec);
            INDArray distances = syn0.mulRowVector(weights).sum(1);
            INDArray sort = Nd4j.sortWithIndices(distances, 0, false)[0];
            VocabWord queried = cache.wordFor(word);
            // -1 never matches a ranking entry, so nothing is excluded by index
            // when the vocabulary has no entry for the query word.
            int selfIndex = queried == null ? -1 : queried.getIndex();
            int limit = Math.min(n, sort.length());
            List<String> ret = new ArrayList<String>();
            for (int i = 0; i < sort.length() && ret.size() < limit; ++i) {
                int index = sort.getInt(i);
                if (index == selfIndex) {
                    continue;
                }
                String candidate = cache.wordAtIndex(index);
                if (candidate == null || candidate.equals(UNK) || candidate.equals("STOP")) {
                    continue;
                }
                ret.add(candidate);
            }
            return ret;
        }
        Counter<String> distances = new Counter<String>();
        for (String s : cache.words()) {
            if (s.equals(word)) {
                continue;
            }
            INDArray otherVec = this.getWordVectorMatrix(s);
            double sim = Transforms.cosineSim(vec, otherVec);
            distances.incrementCount(s, sim);
        }
        distances.keepTopNKeys(n);
        return distances.keySet();
    }

    /**
     * Computes, per section, the percentage of analogy questions answered correctly.
     *
     * <p>A line starting with {@code ":"} opens a new section; every following line is
     * a question made of four space separated tokens. The fourth token is the expected
     * answer and is compared with the single nearest word returned by
     * {@link #wordsNearest(Collection, Collection, int)}.
     *
     * <p>Fixes over the previous implementation:
     * <ul>
     *   <li>hits were counted under the key {@code "right"} but read back under
     *       {@code "correct"}, so the number of correct answers was always zero;</li>
     *   <li>the percentage was computed as {@code correct / (correct / wrong)} instead
     *       of {@code correct / (correct + wrong)};</li>
     *   <li>a section's result was stored under the header of the <em>following</em>
     *       section and the last section was never stored;</li>
     *   <li>an empty nearest-word result is now counted as a wrong answer instead of
     *       throwing, and lines with fewer than four tokens are skipped.</li>
     * </ul>
     *
     * <p>Questions that appear before the first section header cannot be attributed to
     * a section and are ignored. A section without any question scores {@code 0.0}.
     *
     * @param questions section headers and question lines, in file order
     * @return map from section header line to the accuracy in percent (0 to 100)
     */
    @Override
    public Map<String, Double> accuracy(List<String> questions) {
        Map<String, Double> accuracy = new HashMap<String, Double>();
        String section = null;
        int correct = 0;
        int wrong = 0;
        for (String s : questions) {
            if (s.startsWith(":")) {
                if (section != null) {
                    accuracy.put(section, WordVectorsImpl.percentCorrect(correct, wrong));
                }
                section = s;
                correct = 0;
                wrong = 0;
                continue;
            }
            if (section == null) {
                continue;
            }
            String[] split = s.split(" ");
            if (split.length < 4) {
                continue;
            }
            // NOTE(review): the first token is used as the positive word and the second
            // and third as negative words, exactly as before; confirm that this matches
            // the intended analogy arithmetic for the question format in use.
            List<String> positive = Arrays.asList(split[0]);
            List<String> negative = Arrays.asList(split[1], split[2]);
            String expected = split[3];
            Collection<String> nearest = this.wordsNearest(positive, negative, 1);
            if (!nearest.isEmpty() && expected.equals(nearest.iterator().next())) {
                ++correct;
            } else {
                ++wrong;
            }
        }
        if (section != null) {
            accuracy.put(section, WordVectorsImpl.percentCorrect(correct, wrong));
        }
        return accuracy;
    }

    /** Returns the share of correct answers in percent, or 0.0 when nothing was answered. */
    private static double percentCorrect(int correct, int wrong) {
        int total = correct + wrong;
        return total == 0 ? 0.0 : 100.0 * correct / total;
    }

    /**
     * Looks up the vocabulary index of a word.
     *
     * @param word the word to look up
     * @return whatever index the vocabulary reports for {@code word}
     */
    @Override
    public int indexOf(String word) {
        VocabCache cache = this.vocab();
        return cache.indexOf(word);
    }

    /**
     * Collects the vocabulary words whose string similarity to {@code word} reaches a
     * threshold.
     *
     * <p>The vocabulary is obtained through {@link #vocab()} like everywhere else in this
     * class, rather than by reading the field directly, so overriding subclasses are
     * honoured.
     *
     * @param word the reference word
     * @param accuracy minimum value of {@code MathUtils.stringSimilarity} for a word to be kept
     * @return the matching vocabulary words, in vocabulary iteration order
     */
    @Override
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        List<String> ret = new ArrayList<String>();
        for (String s : this.vocab().words()) {
            if (MathUtils.stringSimilarity(word, s) >= accuracy) {
                ret.add(s);
            }
        }
        return ret;
    }

    /**
     * Returns a copy of the vector of a word as a {@code double[]}.
     *
     * <p>Words the vocabulary does not know (negative index) are looked up under
     * {@link #UNK} instead. The lookup table is obtained through {@link #lookupTable()},
     * consistent with the use of {@link #vocab()} in the same method, instead of reading
     * the field directly.
     *
     * @param word the word to look up
     * @return the vector data, taken from a duplicate of the stored vector
     */
    @Override
    public double[] getWordVector(String word) {
        String key = this.vocab().indexOf(word) < 0 ? UNK : word;
        return this.lookupTable().vector(key).dup().data().asDouble();
    }

    /**
     * Returns the vector of a word divided by its L2 norm.
     *
     * <p>Words the vocabulary does not know (negative index) are looked up under
     * {@link #UNK}. The previous implementation returned the {@code UNK} vector as stored
     * in the table, i.e. not normalized and not a copy; it is now normalized like every
     * other vector, which also means the caller always receives a new array.
     *
     * @param word the word to look up
     * @return the normalized vector, or {@code null} when the table has no vector for
     *         the resolved key
     */
    @Override
    public INDArray getWordVectorMatrixNormalized(String word) {
        String key = this.vocab().indexOf(word) < 0 ? UNK : word;
        INDArray r = this.lookupTable().vector(key);
        if (r == null) {
            return null;
        }
        return r.div(Nd4j.getBlasWrapper().nrm2(r));
    }

    /**
     * Fetches the vector the lookup table holds for a word.
     *
     * @param word the word to look up
     * @return the result of the lookup table's {@code vector} call for {@code word}
     */
    @Override
    public INDArray getWordVectorMatrix(String word) {
        WeightLookupTable table = this.lookupTable();
        return table.vector(word);
    }

    /**
     * Finds the words closest to the mean of the positive word vectors and the negated
     * negative word vectors.
     *
     * <p>If any query word is missing from the vocabulary, or no query word is given, or
     * {@code top} is not positive, an empty list is returned.
     *
     * <p>When the lookup table is an {@link InMemoryLookupTable}, the similarity of the
     * normalized mean to every row of a scaled <em>copy</em> of {@code syn0} is computed.
     * The previous implementation scaled {@code syn0} with the in-place
     * {@code diviRowVector}, which rewrote the lookup table's weights on every query.
     * Query words, {@code null} entries, {@code "UNK"} and {@code "STOP"} are skipped.
     *
     * <p>For any other lookup table the cosine similarity to every vocabulary word is
     * accumulated in a {@link Counter} and its top {@code top} keys are returned.
     *
     * @param positive words whose vectors contribute positively
     * @param negative words whose vectors contribute negatively
     * @param top maximum number of words to return
     * @return the nearest words; at most {@code top} entries on the in-memory path
     */
    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        Set<String> union = SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative));
        VocabCache cache = this.vocab();
        for (String p : union) {
            if (!cache.containsWord(p)) {
                return new ArrayList<String>();
            }
        }
        if (top <= 0 || (positive.isEmpty() && negative.isEmpty())) {
            return new ArrayList<String>();
        }
        WeightLookupTable table = this.lookupTable();
        INDArray words = Nd4j.create((int) (positive.size() + negative.size()), (int) table.layerSize());
        int row = 0;
        for (String s : positive) {
            words.putRow(row++, table.vector(s));
        }
        for (String s : negative) {
            words.putRow(row++, table.vector(s).mul(-1));
        }
        INDArray mean = words.isMatrix() ? words.mean(0) : words;
        if (table instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) table).getSyn0();
            // Non in-place variant: the table's own weights must not change during a query.
            INDArray scaled = syn0.divRowVector(syn0.norm2(0));
            INDArray similarity = Transforms.unitVec(mean).mmul(scaled.transpose());
            // Ask for extra candidates because the query words themselves are filtered out below.
            List<Double> highToLowSimList = WordVectorsImpl.getTopN(similarity, top + union.size());
            List<String> ret = new ArrayList<String>();
            for (int i = 0; i < highToLowSimList.size() && ret.size() < top; ++i) {
                String word = cache.wordAtIndex(highToLowSimList.get(i).intValue());
                boolean skip = word == null
                        || word.equals(UNK)
                        || word.equals("STOP")
                        || union.contains(word);
                if (!skip) {
                    ret.add(word);
                }
            }
            return ret;
        }
        Counter<String> distances = new Counter<String>();
        for (String s : cache.words()) {
            INDArray otherVec = this.getWordVectorMatrix(s);
            double sim = Transforms.cosineSim(mean, otherVec);
            distances.incrementCount(s, sim);
        }
        distances.keepTopNKeys(top);
        return distances.keySet();
    }

    /**
     * Returns the positions of the {@code N} largest entries of {@code vec}, ordered from
     * the largest entry to the smallest.
     *
     * <p>A bounded min-heap of {@code (value, position)} pairs is kept: once it holds
     * {@code N} pairs, a new pair only replaces the current minimum when its value is
     * strictly larger. Positions are carried as {@code Double} because they share an
     * array with the value.
     *
     * <p>Fixes over the previous implementation: the position is explicitly converted to
     * {@code double} (an {@code int} cannot be stored in a {@code Double[]}), the heap
     * capacity is derived from {@code N} instead of {@code vec.rows()}, and a
     * non-positive {@code N} or an empty vector yields an empty list instead of
     * comparing against a {@code null} heap head.
     *
     * @param vec the values to rank
     * @param N maximum number of positions to return
     * @return up to {@code N} positions, highest value first
     */
    private static List<Double> getTopN(INDArray vec, int N) {
        List<Double> lowToHighSimLst = new ArrayList<Double>();
        if (N <= 0 || vec.length() == 0) {
            return lowToHighSimLst;
        }
        ArrayComparator comparator = new ArrayComparator();
        int capacity = Math.max(1, Math.min(N, vec.length()));
        PriorityQueue<Double[]> queue = new PriorityQueue<Double[]>(capacity, comparator);
        for (int j = 0; j < vec.length(); ++j) {
            Double[] pair = new Double[]{vec.getDouble(j), (double) j};
            if (queue.size() < N) {
                queue.add(pair);
                continue;
            }
            Double[] head = queue.peek();
            if (comparator.compare(pair, head) > 0) {
                queue.poll();
                queue.add(pair);
            }
        }
        // The heap releases its smallest value first, so the collected order is low to high.
        while (!queue.isEmpty()) {
            lowToHighSimLst.add(queue.poll()[1]);
        }
        return Lists.reverse(lowToHighSimLst);
    }

    /**
     * Finds the words nearest to a single word by delegating to
     * {@link #wordsNearest(Collection, Collection, int)} with that word as the only
     * positive word and no negative words.
     *
     * @param word the query word
     * @param n maximum number of words to return
     * @return the result of the delegated call
     */
    @Override
    public Collection<String> wordsNearest(String word, int n) {
        List<String> positive = Arrays.asList(word);
        List<String> negative = new ArrayList<String>();
        return this.wordsNearest(positive, negative, n);
    }

    /**
     * Computes the similarity of two words as the dot product of their normalized vectors.
     *
     * <p>Equal words short-circuit to {@code 1.0} without any lookup. The vectors are
     * checked for {@code null} before they are normalized; the previous implementation
     * normalized first and checked afterwards, so the check could not protect the
     * normalization call.
     *
     * @param word the first word
     * @param word2 the second word
     * @return {@code 1.0} for equal words, {@code -1.0} when either word has no vector,
     *         otherwise the dot product of the two unit vectors
     */
    @Override
    public double similarity(String word, String word2) {
        if (word.equals(word2)) {
            return 1.0;
        }
        INDArray first = this.getWordVectorMatrix(word);
        INDArray second = this.getWordVectorMatrix(word2);
        if (first == null || second == null) {
            return -1.0;
        }
        INDArray vector = Transforms.unitVec(first);
        INDArray vector2 = Transforms.unitVec(second);
        return Nd4j.getBlasWrapper().dot(vector, vector2);
    }

    /**
     * Accessor for the vocabulary used by this instance.
     *
     * @return the current value of the {@code vocab} field
     */
    @Override
    public VocabCache vocab() {
        return vocab;
    }

    /**
     * Accessor for the lookup table used by this instance.
     *
     * @return the current value of the {@code lookupTable} field
     */
    @Override
    public WeightLookupTable lookupTable() {
        return lookupTable;
    }

    /**
     * Replaces the lookup table used by this instance.
     *
     * @param lookupTable the lookup table to store in the {@code lookupTable} field
     */
    public void setLookupTable(WeightLookupTable lookupTable) {
        this.lookupTable = lookupTable;
    }

    /**
     * Replaces the vocabulary used by this instance.
     *
     * @param vocab the vocabulary to store in the {@code vocab} field
     */
    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    /**
     * Bean-style getter for the minimum word frequency setting.
     *
     * @return the current value of the {@code minWordFrequency} field
     */
    public int getMinWordFrequency() {
        return minWordFrequency;
    }

    /**
     * Bean-style getter for the lookup table; reads the field directly.
     *
     * @return the current value of the {@code lookupTable} field
     */
    public WeightLookupTable getLookupTable() {
        return lookupTable;
    }

    /**
     * Bean-style getter for the vocabulary; reads the field directly.
     *
     * @return the current value of the {@code vocab} field
     */
    public VocabCache getVocab() {
        return vocab;
    }

    /**
     * Bean-style getter for the layer size setting.
     *
     * @return the current value of the {@code layerSize} field
     */
    public int getLayerSize() {
        return layerSize;
    }

    /**
     * Orders {@code Double[]} arrays by their first element only; any further elements
     * do not take part in the comparison.
     */
    private static class ArrayComparator
    implements Comparator<Double[]> {
        private ArrayComparator() {
        }

        /**
         * Compares the first elements of both arrays numerically.
         *
         * @param o1 the left array
         * @param o2 the right array
         * @return a negative, zero or positive value as {@code o1[0]} is less than,
         *         equal to or greater than {@code o2[0]}
         */
        @Override
        public int compare(Double[] o1, Double[] o2) {
            Double left = o1[0];
            Double right = o2[0];
            return left.compareTo(right);
        }
    }
}

