/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.text.functions;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.Accumulator;
import org.apache.spark.AccumulatorParam;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.spark.text.accumulators.WordFreqAccumulator;
import org.deeplearning4j.spark.text.functions.GetSentenceCountFunction;
import org.deeplearning4j.spark.text.functions.ReduceSentenceCount;
import org.deeplearning4j.spark.text.functions.TokenizerFunction;
import org.deeplearning4j.spark.text.functions.UpdateWordFreqAccumulatorFunction;
import org.deeplearning4j.spark.text.functions.WordsListToVocabWordsFunction;
import org.deeplearning4j.text.stopwords.StopWords;

/**
 * Spark pipeline that tokenizes a corpus RDD, gathers word frequencies through an accumulator,
 * builds a {@link VocabCache} from those frequencies and converts each sentence into a list of
 * {@link VocabWord}s.
 *
 * <p>Expected call order: configure (via the two-argument constructor, or the no-arg constructor
 * followed by {@link #setRDDVarMap}), then {@link #buildVocabCache()}, then
 * {@link #buildVocabWordListRDD()}. The getters throw {@link IllegalStateException} when the
 * value they expose has not been produced yet.
 */
public class TextPipeline {
    private JavaRDD<String> corpusRDD;
    // Read from the "numWords" setting. Despite its name, filterMinWord uses it as the minimum
    // frequency a token needs to keep its own entry; rarer tokens are replaced by "UNK".
    private int numWords;
    private int nGrams;
    private String tokenizer;
    private String tokenizerPreprocessor;
    private List<String> stopWords = new ArrayList<String>();
    private JavaSparkContext sc;
    private Accumulator<Counter<String>> wordFreqAcc;
    private Broadcast<List<String>> stopWordBroadCast;
    private JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD;
    private VocabCache vocabCache = new InMemoryLookupCache();
    private Broadcast<VocabCache> vocabCacheBroadcast;
    private JavaRDD<List<VocabWord>> vocabWordListRDD;
    private JavaRDD<AtomicLong> sentenceCountRDD;
    private long totalWordCount;

    /**
     * Creates an unconfigured pipeline. Call {@link #setRDDVarMap} before building; the Spark
     * context, accumulator and stop-word broadcast are created lazily on first use.
     */
    public TextPipeline() {
    }

    /**
     * Creates a pipeline for the given corpus and immediately creates its Spark-side resources.
     *
     * @param corpusRDD               the corpus, one string per RDD element
     * @param broadcasTokenizerVarMap broadcast settings map (see {@link #setRDDVarMap})
     * @throws Exception declared for backward compatibility
     */
    public TextPipeline(JavaRDD<String> corpusRDD, Broadcast<Map<String, Object>> broadcasTokenizerVarMap) throws Exception {
        this.setRDDVarMap(corpusRDD, broadcasTokenizerVarMap);
        this.setup();
    }

    /**
     * Stores the corpus and reads the tokenizer settings from the broadcast map.
     *
     * <p>Keys read: {@code "numWords"} and {@code "nGrams"} (Integer), {@code "tokenizer"} and
     * {@code "tokenPreprocessor"} (String), {@code "removeStop"} (Boolean). A missing
     * {@code "removeStop"} entry is treated as {@code false}.
     *
     * @param corpusRDD               the corpus, one string per RDD element
     * @param broadcasTokenizerVarMap broadcast settings map
     */
    public void setRDDVarMap(JavaRDD<String> corpusRDD, Broadcast<Map<String, Object>> broadcasTokenizerVarMap) {
        Map tokenizerVarMap = (Map)broadcasTokenizerVarMap.getValue();
        this.corpusRDD = corpusRDD;
        this.numWords = (Integer)tokenizerVarMap.get("numWords");
        this.nGrams = (Integer)tokenizerVarMap.get("nGrams");
        this.tokenizer = (String)tokenizerVarMap.get("tokenizer");
        this.tokenizerPreprocessor = (String)tokenizerVarMap.get("tokenPreprocessor");
        // Boolean.TRUE.equals avoids an NPE from unboxing when the key is absent.
        if (Boolean.TRUE.equals(tokenizerVarMap.get("removeStop"))) {
            this.stopWords = StopWords.getStopWords();
        }
    }

    /** Creates the Spark context wrapper, the word-frequency accumulator and the stop-word broadcast. */
    private void setup() {
        this.sc = new JavaSparkContext(this.corpusRDD.context());
        this.wordFreqAcc = this.sc.accumulator(new Counter<String>(), (AccumulatorParam)new WordFreqAccumulator());
        this.stopWordBroadCast = this.sc.broadcast(this.stopWords);
    }

    /**
     * Runs {@link #setup()} once if it has not run yet. This lets the no-arg constructor plus
     * {@link #setRDDVarMap} path work.
     */
    private void ensureSetup() {
        if (this.sc != null) {
            return;
        }
        if (this.corpusRDD == null) {
            throw new IllegalStateException("corpusRDD not assigned. Define TextPipeline with corpusRDD assigned.");
        }
        this.setup();
    }

    /**
     * Tokenizes every corpus element with {@link TokenizerFunction}.
     *
     * @return a lazily evaluated RDD of token lists
     * @throws IllegalStateException if no corpus has been assigned
     */
    public JavaRDD<List<String>> tokenize() {
        if (this.corpusRDD == null) {
            throw new IllegalStateException("corpusRDD not assigned. Define TextPipeline with corpusRDD assigned.");
        }
        return this.corpusRDD.map((Function)new TokenizerFunction(this.tokenizer, this.tokenizerPreprocessor, this.nGrams));
    }

    /**
     * Maps each token list through {@link UpdateWordFreqAccumulatorFunction} and forces the
     * evaluation, so that {@code wordFreqAcc} holds the corpus word frequencies on return.
     *
     * <p>The mapped RDD is persisted <em>before</em> the materializing {@code count()}. Without
     * that, later actions on the returned RDD re-run the accumulator-updating map, and Spark
     * re-applies accumulator updates made inside transformations, so every frequency would be
     * counted again.
     *
     * @param tokenizedRDD token lists, e.g. from {@link #tokenize()}
     * @return the persisted (tokens, count) pairs produced by the accumulator function
     */
    public JavaRDD<Pair<List<String>, AtomicLong>> updateAndReturnAccumulatorVal(JavaRDD<List<String>> tokenizedRDD) {
        this.ensureSetup();
        UpdateWordFreqAccumulatorFunction accumulatorClassFunction = new UpdateWordFreqAccumulatorFunction(this.stopWordBroadCast, this.wordFreqAcc);
        JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = tokenizedRDD.map((Function)accumulatorClassFunction).cache();
        // Action that triggers the accumulator updates exactly once for the cached RDD.
        sentenceWordsCountRDD.count();
        return sentenceWordsCountRDD;
    }

    /** Returns "UNK" when {@code tokenCount} is below the {@code numWords} threshold, else the token. */
    private String filterMinWord(String stringToken, double tokenCount) {
        return tokenCount < (double)this.numWords ? "UNK" : stringToken;
    }

    /**
     * Adds {@code tokenCount} occurrences of {@code stringToken} to the vocab cache.
     *
     * <p>If the cache already has the token, that token's count is incremented. Otherwise a new
     * {@link VocabWord} is created. The word is registered as a vocab word only if
     * {@code containsWord} reports it absent.
     */
    private void addTokenToVocabCache(String stringToken, Double tokenCount) {
        VocabWord actualToken;
        if (this.vocabCache.hasToken(stringToken)) {
            actualToken = this.vocabCache.tokenFor(stringToken);
            actualToken.increment(tokenCount.intValue());
        } else {
            actualToken = new VocabWord(tokenCount.doubleValue(), stringToken);
        }
        boolean vocabContainsWord = this.vocabCache.containsWord(stringToken);
        if (!vocabContainsWord) {
            this.vocabCache.addToken(actualToken);
            // NOTE(review): numWords() is read before putVocabWord, which presumably makes the
            // index the 0-based position of the new word — confirm against the VocabCache impl.
            int idx = this.vocabCache.numWords();
            actualToken.setIndex(idx);
            this.vocabCache.putVocabWord(stringToken);
        }
    }

    /**
     * Adds every counted word to the vocab cache. Words below the {@code numWords} threshold are
     * merged into "UNK".
     *
     * @param wordFreq word frequencies collected by the accumulator
     * @throws IllegalStateException if {@code wordFreq} is empty
     */
    public void filterMinWordAddVocab(Counter<String> wordFreq) {
        if (wordFreq.size() == 0) {
            throw new IllegalStateException("IllegalStateException: wordFreqCounter has nothing. Check accumulator updating");
        }
        for (Map.Entry entry : wordFreq.entrySet()) {
            String stringToken = (String)entry.getKey();
            Double tokenCount = (Double)entry.getValue();
            stringToken = this.filterMinWord(stringToken, tokenCount);
            this.addTokenToVocabCache(stringToken, tokenCount);
        }
    }

    /**
     * Tokenizes the corpus and counts word frequencies. It then fills the vocab cache from those
     * counts and broadcasts the cache. The intermediate (tokens, count) RDD is kept cached for
     * {@link #buildVocabWordListRDD()}.
     */
    public void buildVocabCache() {
        JavaRDD<List<String>> tokenizedRDD = this.tokenize();
        this.sentenceWordsCountRDD = this.updateAndReturnAccumulatorVal(tokenizedRDD).cache();
        Counter<String> wordFreqCounter = this.wordFreqAcc.value();
        this.filterMinWordAddVocab(wordFreqCounter);
        this.vocabCacheBroadcast = this.sc.broadcast(this.vocabCache);
    }

    /**
     * Converts each sentence into a list of {@link VocabWord}s via the broadcast vocab cache.
     * It also sums the per-sentence counts into {@code totalWordCount} and then unpersists the
     * intermediate RDD.
     *
     * @throws IllegalStateException if {@link #buildVocabCache()} has not been run
     */
    public void buildVocabWordListRDD() {
        if (this.sentenceWordsCountRDD == null) {
            throw new IllegalStateException("SentenceWordCountRDD must be defined first. Run buildVocabCache first.");
        }
        this.vocabWordListRDD = this.sentenceWordsCountRDD.map((Function)new WordsListToVocabWordsFunction(this.vocabCacheBroadcast)).setName("vocabWordListRDD").cache();
        this.sentenceCountRDD = this.sentenceWordsCountRDD.map((Function)new GetSentenceCountFunction()).setName("sentenceCountRDD").cache();
        // Materialize vocabWordListRDD while sentenceWordsCountRDD is still cached.
        this.vocabWordListRDD.count();
        this.totalWordCount = ((AtomicLong)this.sentenceCountRDD.reduce((Function2)new ReduceSentenceCount())).get();
        this.sentenceWordsCountRDD.unpersist();
    }

    /**
     * @return the word-frequency accumulator
     * @throws IllegalStateException if it has not been created yet
     */
    public Accumulator<Counter<String>> getWordFreqAcc() {
        if (this.wordFreqAcc != null) {
            return this.wordFreqAcc;
        }
        throw new IllegalStateException("IllegalStateException: wordFreqAcc not set at TextPipline.");
    }

    /**
     * @return the broadcast vocab cache
     * @throws IllegalStateException if the vocab cache is empty
     */
    public Broadcast<VocabCache> getBroadCastVocabCache() throws IllegalStateException {
        if (this.vocabCache.numWords() > 0) {
            return this.vocabCacheBroadcast;
        }
        throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
    }

    /**
     * @return the driver-side vocab cache
     * @throws IllegalStateException if the vocab cache is empty
     */
    public VocabCache getVocabCache() throws IllegalStateException {
        if (this.vocabCache.numWords() > 0) {
            return this.vocabCache;
        }
        throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
    }

    /**
     * @return the (tokens, count) RDD built by {@link #buildVocabCache()}
     * @throws IllegalStateException if it has not been built
     */
    public JavaRDD<Pair<List<String>, AtomicLong>> getSentenceWordsCountRDD() {
        if (this.sentenceWordsCountRDD != null) {
            return this.sentenceWordsCountRDD;
        }
        throw new IllegalStateException("IllegalStateException: sentenceWordsCountRDD not set at TextPipline.");
    }

    /**
     * @return the per-sentence VocabWord lists
     * @throws IllegalStateException if {@link #buildVocabWordListRDD()} has not run
     */
    public JavaRDD<List<VocabWord>> getVocabWordListRDD() throws IllegalStateException {
        if (this.vocabWordListRDD != null) {
            return this.vocabWordListRDD;
        }
        throw new IllegalStateException("IllegalStateException: vocabWordListRDD not set at TextPipline.");
    }

    /**
     * @return the per-sentence counts
     * @throws IllegalStateException if {@link #buildVocabWordListRDD()} has not run
     */
    public JavaRDD<AtomicLong> getSentenceCountRDD() throws IllegalStateException {
        if (this.sentenceCountRDD != null) {
            return this.sentenceCountRDD;
        }
        throw new IllegalStateException("IllegalStateException: sentenceCountRDD not set at TextPipline.");
    }

    /**
     * @return the summed sentence counts
     * @throws IllegalStateException if the total is still 0 (not built, or an empty corpus)
     */
    public Long getTotalWordCount() {
        if (this.totalWordCount != 0L) {
            return this.totalWordCount;
        }
        throw new IllegalStateException("IllegalStateException: totalWordCount not set at TextPipline.");
    }
}

