/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.glove.cooccurrences;

import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.factory.Nd4j;

public class CoOccurrenceCalculator
implements Function<Pair<List<String>, AtomicLong>, CounterMap<String, String>> {
    private boolean symmetric = false;
    private Broadcast<VocabCache> vocab;
    private int windowSize = 5;

    public CoOccurrenceCalculator(boolean symmetric, Broadcast<VocabCache> vocab, int windowSize) {
        this.symmetric = symmetric;
        this.vocab = vocab;
        this.windowSize = windowSize;
    }

    public CounterMap<String, String> call(Pair<List<String>, AtomicLong> pair) throws Exception {
        List sentence = (List)pair.getFirst();
        CounterMap coOCurreneCounts = new CounterMap();
        VocabCache vocab = (VocabCache)this.vocab.value();
        for (int i = 0; i < sentence.size(); ++i) {
            int wordIdx = vocab.indexOf((String)sentence.get(i));
            String w1 = vocab.wordFor((String)sentence.get(i)).getWord();
            if (wordIdx < 0 || w1.equals("UNK")) continue;
            int windowStop = Math.min(i + this.windowSize + 1, sentence.size());
            for (int j = i; j < windowStop; ++j) {
                int otherWord = vocab.indexOf((String)sentence.get(j));
                String w2 = vocab.wordFor((String)sentence.get(j)).getWord();
                if (vocab.indexOf((String)sentence.get(j)) < 0 || w2.equals("UNK") || otherWord == wordIdx) continue;
                if (wordIdx < otherWord) {
                    coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j), 1.0 / ((double)(j - i) + Nd4j.EPS_THRESHOLD));
                    if (!this.symmetric) continue;
                    coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i), 1.0 / ((double)(j - i) + Nd4j.EPS_THRESHOLD));
                    continue;
                }
                coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i), 1.0 / ((double)(j - i) + Nd4j.EPS_THRESHOLD));
                if (!this.symmetric) continue;
                coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j), 1.0 / ((double)(j - i) + Nd4j.EPS_THRESHOLD));
            }
        }
        return coOCurreneCounts;
    }
}

