/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction;
import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction;
import org.deeplearning4j.spark.text.functions.CountCumSum;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Spark-backed word2vec trainer.
 *
 * <p>Hyper-parameters are configured through the fluent setters and shipped to the Spark
 * workers as two broadcast maps (see {@link #getTokenizerVarMap()} and
 * {@link #getWord2vecVarMap()}). After {@link #train(JavaRDD)} returns, the inherited
 * {@code vocab} and {@code lookupTable} fields hold the built vocabulary and an
 * {@link InMemoryLookupTable} whose syn0 contains the learned word vectors.
 *
 * <p>NOTE(review): this class is {@link Serializable} without an explicit serialVersionUID, and
 * private instance field names are part of its serialized form, so they are intentionally left
 * unchanged (including the non-conventional {@code MAX_EXP}).
 */
public class Word2Vec extends WordVectorsImpl implements Serializable {
    /** syn1 matrix supplied via the constructor; stored but not read anywhere in this class. */
    private INDArray trainedSyn1;
    private static final Logger log = LoggerFactory.getLogger(Word2Vec.class);
    /** Half-width of the sigmoid lookup domain sampled by {@link #initExpTable()}. */
    private int MAX_EXP = 6;
    /** Sigmoid lookup table built once per instance by {@link #initExpTable()}. */
    private double[] expTable;
    /** Dimensionality of each learned vector (number of columns of syn0). */
    private int vectorLength = 100;
    // Forwarded to workers under "useAdaGrad".
    private boolean useAdaGrad = false;
    // Forwarded under "negative"; presumably the negative-sampling count (0 = disabled) — confirm.
    private int negative = 0;
    // Forwarded to the tokenizer map under "numWords"; meaning is defined by TextPipeline — confirm.
    private int numWords = 1;
    // Forwarded under "window"; presumably the context window size — confirm.
    private int window = 5;
    // Forwarded under "alpha" / "minAlpha"; presumably initial and floor learning rates — confirm.
    private double alpha = 0.025;
    private double minAlpha = 1.0E-4;
    // Forwarded under "iterations".
    private int iterations = 1;
    // Forwarded to the tokenizer map under "nGrams".
    private int nGrams = 1;
    /** Fully-qualified class name of the tokenizer factory used by the text pipeline. */
    private String tokenizer = "org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory";
    /** Fully-qualified class name of the token preprocessor used by the text pipeline. */
    private String tokenPreprocessor = "org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor";
    // Forwarded to the tokenizer map under "removeStop".
    private boolean removeStop = false;
    // Forwarded to workers under "seed".
    private long seed = 42L;

    /**
     * Creates a trainer that carries a pre-trained syn1 matrix.
     *
     * @param trainedSyn1 syn1 matrix to store (not used by {@link #train(JavaRDD)})
     */
    public Word2Vec(INDArray trainedSyn1) {
        this.trainedSyn1 = trainedSyn1;
        this.expTable = this.initExpTable();
    }

    /** Creates a trainer with default hyper-parameters. */
    public Word2Vec() {
        this.expTable = this.initExpTable();
    }

    /**
     * Builds a 1000-entry lookup table of the logistic sigmoid.
     *
     * <p>Entry {@code i} holds {@code e^x / (e^x + 1)} for
     * {@code x = (2 * i / 1000 - 1) * MAX_EXP}, i.e. the table samples {@code [-MAX_EXP, MAX_EXP)}
     * at uniform spacing.
     *
     * <p>NOTE(review): this public, overridable method is invoked from the constructors, so an
     * override runs before any subclass fields are initialized.
     *
     * @return a newly allocated table; {@code this.expTable} is not modified here
     */
    public double[] initExpTable() {
        double[] table = new double[1000];
        for (int i = 0; i < table.length; ++i) {
            double x = ((double) i / (double) table.length * 2.0 - 1.0) * (double) this.MAX_EXP;
            double tmp = FastMath.exp(x);
            table[i] = tmp / (tmp + 1.0);
        }
        return table;
    }

    /**
     * Returns the tokenizer settings that are broadcast to the text pipeline.
     *
     * <p>A plain {@link HashMap} is used on purpose. Double-brace initialization would create an
     * anonymous subclass that captures this {@code Word2Vec}, and broadcasting the map would then
     * serialize the whole trainer to every worker.
     *
     * @return a new mutable map with keys {@code numWords}, {@code nGrams}, {@code tokenizer},
     *     {@code tokenPreprocessor} and {@code removeStop}
     */
    public Map<String, Object> getTokenizerVarMap() {
        Map<String, Object> vars = new HashMap<String, Object>();
        vars.put("numWords", this.numWords);
        vars.put("nGrams", this.nGrams);
        vars.put("tokenizer", this.tokenizer);
        vars.put("tokenPreprocessor", this.tokenPreprocessor);
        vars.put("removeStop", this.removeStop);
        return vars;
    }

    /**
     * Returns the word2vec hyper-parameters that are broadcast to the training workers.
     *
     * <p>A plain {@link HashMap} is used for the same reason as in {@link #getTokenizerVarMap()}:
     * avoiding an anonymous subclass that would drag this instance into the broadcast.
     *
     * @return a new mutable map with keys {@code vectorLength}, {@code useAdaGrad},
     *     {@code negative}, {@code window}, {@code alpha}, {@code minAlpha}, {@code iterations},
     *     {@code seed} and {@code maxExp}
     */
    public Map<String, Object> getWord2vecVarMap() {
        Map<String, Object> vars = new HashMap<String, Object>();
        vars.put("vectorLength", this.vectorLength);
        vars.put("useAdaGrad", this.useAdaGrad);
        vars.put("negative", this.negative);
        vars.put("window", this.window);
        vars.put("alpha", this.alpha);
        vars.put("minAlpha", this.minAlpha);
        vars.put("iterations", this.iterations);
        vars.put("seed", this.seed);
        vars.put("maxExp", this.MAX_EXP);
        return vars;
    }

    /**
     * Trains word vectors on the given corpus and stores the result in this instance.
     *
     * <p>Steps:
     * <ol>
     *   <li>tokenize the corpus and build the vocabulary via {@link TextPipeline};</li>
     *   <li>build a Huffman tree over the vocabulary;</li>
     *   <li>compute cumulative per-sentence word counts;</li>
     *   <li>run {@link FirstIterationFunction} on each partition;</li>
     *   <li>collect the per-word updates on the driver and sum them into a
     *       {@code numWords x vectorLength} syn0 matrix.</li>
     * </ol>
     *
     * <p>On return, {@code this.vocab} is the built {@link VocabCache} and
     * {@code this.lookupTable} is an {@link InMemoryLookupTable} wrapping syn0.
     *
     * @param corpusRDD corpus handed to {@link TextPipeline} for tokenization (presumably one
     *     sentence or document per element — confirm against TextPipeline)
     * @throws Exception as declared; nothing in this method throws directly, so presumably
     *     propagated from the pipeline helpers — confirm
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void train(JavaRDD<String> corpusRDD) throws Exception {
        // Raw RDD types below mirror the untyped use of FirstIterationFunction and
        // MapToPairFunction; their generic signatures are not declared in this file.
        log.info("Start training ...");
        JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
        Map<String, Object> tokenizerVarMap = this.getTokenizerVarMap();
        Map<String, Object> word2vecVarMap = this.getWord2vecVarMap();

        log.info("Tokenization and building VocabCache ...");
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        // Must be added before word2vecVarMap is broadcast below.
        word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
        JavaRDD<AtomicLong> sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
        Broadcast<VocabCache> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
        VocabCache vocabCache = (VocabCache) vocabCacheBroadcast.getValue();

        log.info("Building Huffman Tree ...");
        // The Huffman instance is not kept, so build() is called for its side effects.
        // NOTE(review): presumably it annotates the VocabWord objects on the driver *after* the
        // vocab cache was broadcast — confirm workers see those annotations.
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();

        log.info("Calculating cumulative sum of sentence counts ...");
        JavaRDD<Long> sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();

        log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
        JavaPairRDD vocabWordListSentenceCumSumRDD = vocabWordListRDD
                .zip(sentenceCumSumCountRDD)
                .setName("vocabWordListSentenceCumSumRDD")
                .cache();

        log.info("Broadcasting word2vec variables to workers ...");
        Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
        Broadcast<double[]> expTableBroadcast = sc.broadcast(this.expTable);

        log.info("Training word2vec sentences ...");
        FirstIterationFunction firstIterFunc =
                new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast);
        List syn0UpdateEntries;
        try {
            JavaRDD indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD
                    .mapPartitions((FlatMapFunction) firstIterFunc)
                    .map((Function) new MapToPairFunction());
            syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
        } finally {
            // The zipped RDD is only needed for this single job; release its cached blocks.
            vocabWordListSentenceCumSumRDD.unpersist();
        }

        // Sum every (word index, vector update) pair into the matching syn0 row.
        INDArray syn0 = Nd4j.create((int) vocabCache.numWords(), this.vectorLength);
        for (Object entry : syn0UpdateEntries) {
            Pair update = (Pair) entry;
            int wordIndex = (Integer) update.getFirst();
            syn0.getRow(wordIndex).addi((INDArray) update.getSecond());
        }

        this.vocab = vocabCache;
        InMemoryLookupTable inMemoryLookupTable = new InMemoryLookupTable();
        inMemoryLookupTable.setVocab(vocabCache);
        inMemoryLookupTable.setVectorLength(this.vectorLength);
        inMemoryLookupTable.setSyn0(syn0);
        this.lookupTable = inMemoryLookupTable;
    }

    /** @return the dimensionality of the learned vectors */
    public int getVectorLength() {
        return this.vectorLength;
    }

    /**
     * @param vectorLength dimensionality of the learned vectors
     * @return this instance, for chaining
     */
    public Word2Vec setVectorLength(int vectorLength) {
        this.vectorLength = vectorLength;
        return this;
    }

    /** @return the value forwarded to workers under {@code "useAdaGrad"} */
    public boolean isUseAdaGrad() {
        return this.useAdaGrad;
    }

    /**
     * @param useAdaGrad value forwarded to workers under {@code "useAdaGrad"}
     * @return this instance, for chaining
     */
    public Word2Vec setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
        return this;
    }

    /** @return the value forwarded to workers under {@code "negative"} */
    public int getNegative() {
        return this.negative;
    }

    /**
     * @param negative value forwarded to workers under {@code "negative"}
     * @return this instance, for chaining
     */
    public Word2Vec setNegative(int negative) {
        this.negative = negative;
        return this;
    }

    /** @return the value forwarded to the tokenizer pipeline under {@code "numWords"} */
    public int getNumWords() {
        return this.numWords;
    }

    /**
     * @param numWords value forwarded to the tokenizer pipeline under {@code "numWords"}
     * @return this instance, for chaining
     */
    public Word2Vec setNumWords(int numWords) {
        this.numWords = numWords;
        return this;
    }

    /** @return the value forwarded to workers under {@code "window"} */
    public int getWindow() {
        return this.window;
    }

    /**
     * @param window value forwarded to workers under {@code "window"}
     * @return this instance, for chaining
     */
    public Word2Vec setWindow(int window) {
        this.window = window;
        return this;
    }

    /** @return the value forwarded to workers under {@code "alpha"} */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * @param alpha value forwarded to workers under {@code "alpha"}
     * @return this instance, for chaining
     */
    public Word2Vec setAlpha(double alpha) {
        this.alpha = alpha;
        return this;
    }

    /** @return the value forwarded to workers under {@code "minAlpha"} */
    public double getMinAlpha() {
        return this.minAlpha;
    }

    /**
     * @param minAlpha value forwarded to workers under {@code "minAlpha"}
     * @return this instance, for chaining
     */
    public Word2Vec setMinAlpha(double minAlpha) {
        this.minAlpha = minAlpha;
        return this;
    }

    /** @return the value forwarded to workers under {@code "iterations"} */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * @param iterations value forwarded to workers under {@code "iterations"}
     * @return this instance, for chaining
     */
    public Word2Vec setIterations(int iterations) {
        this.iterations = iterations;
        return this;
    }

    /** @return the value forwarded to the tokenizer pipeline under {@code "nGrams"} */
    public int getnGrams() {
        return this.nGrams;
    }

    /**
     * @param nGrams value forwarded to the tokenizer pipeline under {@code "nGrams"}
     * @return this instance, for chaining
     */
    public Word2Vec setnGrams(int nGrams) {
        this.nGrams = nGrams;
        return this;
    }

    /** @return the fully-qualified tokenizer factory class name */
    public String getTokenizer() {
        return this.tokenizer;
    }

    /**
     * @param tokenizer fully-qualified tokenizer factory class name
     * @return this instance, for chaining
     */
    public Word2Vec setTokenizer(String tokenizer) {
        this.tokenizer = tokenizer;
        return this;
    }

    /** @return the fully-qualified token preprocessor class name */
    public String getTokenPreprocessor() {
        return this.tokenPreprocessor;
    }

    /**
     * @param tokenPreprocessor fully-qualified token preprocessor class name
     * @return this instance, for chaining
     */
    public Word2Vec setTokenPreprocessor(String tokenPreprocessor) {
        this.tokenPreprocessor = tokenPreprocessor;
        return this;
    }

    /** @return the value forwarded to the tokenizer pipeline under {@code "removeStop"} */
    public boolean isRemoveStop() {
        return this.removeStop;
    }

    /**
     * @param removeStop value forwarded to the tokenizer pipeline under {@code "removeStop"}
     * @return this instance, for chaining
     */
    public Word2Vec setRemoveStop(boolean removeStop) {
        this.removeStop = removeStop;
        return this;
    }

    /** @return the value forwarded to workers under {@code "seed"} */
    public long getSeed() {
        return this.seed;
    }

    /**
     * @param seed value forwarded to workers under {@code "seed"}
     * @return this instance, for chaining
     */
    public Word2Vec setSeed(long seed) {
        this.seed = seed;
        return this;
    }

    /**
     * Returns the sigmoid lookup table built at construction time.
     *
     * @return the internal array itself (not a copy); caller mutations affect this instance
     */
    public double[] getExpTable() {
        return this.expTable;
    }
}

