/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.glove;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.glove.GloveChange;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Spark driver that trains GloVe word vectors from an RDD of sentences.
 *
 * <p>{@link #train(JavaRDD)} builds a vocabulary with {@link TextPipeline}, folds per-sentence
 * co-occurrence counts into a single {@link CounterMap} on the driver, and then runs a fixed
 * number of passes in which every co-occurring word pair produces a {@link GloveChange} that is
 * collected and applied to the driver-side {@link GloveWeightLookupTable}.
 *
 * <p>The symmetric flag, window size and iteration count all come from the constructor.
 */
public class Glove implements Serializable {
    private Broadcast<VocabCache> vocabCacheBroadcast;
    // Stored by the four-argument constructor; nothing in this class reads it.
    private String tokenizerFactoryClazz = DefaultTokenizerFactory.class.getName();
    private boolean symmetric = true;
    private int windowSize = 15;
    private int iterations = 300;
    private static final Logger log = LoggerFactory.getLogger(Glove.class);

    /**
     * @param tokenizerFactoryClazz class name of the tokenizer factory (stored only)
     * @param symmetric             passed to the co-occurrence calculator
     * @param windowSize            passed to the co-occurrence calculator
     * @param iterations            number of training passes over the co-occurrence pairs
     */
    public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) {
        this.tokenizerFactoryClazz = tokenizerFactoryClazz;
        this.symmetric = symmetric;
        this.windowSize = windowSize;
        this.iterations = iterations;
    }

    /**
     * @param symmetric  passed to the co-occurrence calculator
     * @param windowSize passed to the co-occurrence calculator
     * @param iterations number of training passes over the co-occurrence pairs
     */
    public Glove(boolean symmetric, int windowSize, int iterations) {
        this.symmetric = symmetric;
        this.windowSize = windowSize;
        this.iterations = iterations;
    }

    /**
     * Applies one AdaGrad step to a word's vector and to its bias entry, in place.
     *
     * @param weightAdaGrad AdaGrad state used for the vector step
     * @param biasAdaGrad   AdaGrad state used for the bias step
     * @param syn0          weight matrix; only its shape is read here
     * @param bias          bias array, modified in place at the word's index
     * @param w1            word whose parameters are updated
     * @param wordVector    row for {@code w1}, modified in place
     * @param contextVector row of the co-occurring word
     * @param gradient      scalar error term shared by both steps
     * @return the step subtracted from the vector and the step subtracted from the bias
     */
    private Pair<INDArray, Double> update(AdaGrad weightAdaGrad, AdaGrad biasAdaGrad, INDArray syn0, INDArray bias, VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
        int index = w1.getIndex();

        INDArray weightStep = weightAdaGrad.getGradient(contextVector.mul(gradient), index, syn0.shape());
        wordVector.subi(weightStep);

        // The bias moves by the AdaGrad step, exactly like the vector above. The previous code
        // subtracted (bias - step) from the bias, which left the bias equal to the step itself.
        double currentBias = bias.getDouble(index);
        double biasStep = biasAdaGrad.getGradient(gradient, index, bias.shape());
        bias.putScalar(index, currentBias - biasStep);

        return new Pair<INDArray, Double>(weightStep, biasStep);
    }

    /**
     * Builds the vocabulary and co-occurrence counts for {@code rdd} and trains a GloVe lookup
     * table for the number of iterations given to the constructor.
     *
     * @param rdd sentences to train on
     * @return the vocabulary built from {@code rdd} and the trained lookup table
     * @throws Exception if reading the tokenizer settings or any Spark stage fails
     */
    // Raw Spark/CounterMap types are kept where the generic signatures of the project classes
    // involved (CoOccurrenceCalculator, CoOccurrenceCounts, CounterMap) are not visible here.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Pair<VocabCache, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
        JavaSparkContext sc = new JavaSparkContext(rdd.context());
        SparkConf conf = sc.getConf();

        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerSettings(conf));
        TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        VocabCache vocabCache = pipeline.getVocabCache();
        JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
        this.vocabCacheBroadcast = sc.broadcast(vocabCache);

        GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
                .cache(vocabCache)
                .lr(conf.getDouble("org.deeplearning4j.scaleout.perform.models.glove.alpha", 0.01))
                .maxCount(conf.getDouble("org.deeplearning4j.scaleout.perform.models.glove.maxcount", 100.0))
                .vectorLength(conf.getInt("org.deeplearning4j.scaleout.perform.models.glove.length", 300))
                .xMax(conf.getDouble("org.deeplearning4j.scaleout.perform.models.glove.xmax", 0.75))
                .build();
        gloveWeightLookupTable.resetWeights();
        // Start both AdaGrad histories at one.
        gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones((int)gloveWeightLookupTable.getSyn0().rows());
        gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones((int[])gloveWeightLookupTable.getSyn0().shape());
        log.info("Created lookup table of size {}", Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));

        CounterMap coOccurrenceCounts = (CounterMap)sentenceWordsCountRDD
                .map((Function)new CoOccurrenceCalculator(this.symmetric, this.vocabCacheBroadcast, this.windowSize))
                .fold((Object)new CounterMap(), (Function2)new CoOccurrenceCounts());
        List<Triple> counts = cappedCoOccurrences(coOccurrenceCounts, gloveWeightLookupTable.getMaxCount());
        log.info("Calculated co occurrences");

        // (word, context, count) -> word => (context, count)
        JavaRDD parallel = sc.parallelize(counts);
        JavaPairRDD pairs = parallel.mapToPair((PairFunction)new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>(){

            @Override
            public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> triple) throws Exception {
                return new Tuple2<String, Tuple2<String, Double>>(triple.getFirst(), new Tuple2<String, Double>(triple.getSecond(), triple.getThird()));
            }
        });
        // Replace both strings with their vocabulary entries from the broadcast cache.
        JavaPairRDD pairsVocab = pairs.mapToPair((PairFunction)new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>(){

            @Override
            public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> entry) throws Exception {
                VocabCache cache = Glove.this.vocabCacheBroadcast.getValue();
                VocabWord w1 = cache.wordFor(entry._1());
                VocabWord w2 = cache.wordFor(entry._2()._1());
                return new Tuple2<VocabWord, Tuple2<VocabWord, Double>>(w1, new Tuple2<VocabWord, Double>(w2, entry._2()._2()));
            }
        });

        Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange> step = gradientStep(gloveWeightLookupTable);
        // Honour the iteration count passed to the constructor, as is already done for
        // symmetric and windowSize; it was previously shadowed by a word2vec setting.
        for (int i = 0; i < this.iterations; ++i) {
            JavaRDD change = pairsVocab.map(step);
            List<GloveChange> gloveChanges = change.collect();
            double error = 0.0;
            for (GloveChange gloveChange : gloveChanges) {
                gloveChange.apply(gloveWeightLookupTable);
                error += gloveChange.getError();
            }
            // Re-distribute the pairs in a new random order for the next pass.
            List shuffled = pairsVocab.collect();
            Collections.shuffle(shuffled);
            pairsVocab = sc.parallelizePairs(shuffled);
            log.info("Error at iteration {} was {}", i, error);
        }
        return new Pair<VocabCache, GloveWeightLookupTable>(vocabCache, gloveWeightLookupTable);
    }

    /**
     * Reads the tokenizer-related settings from {@code conf} into a plain map suitable for
     * broadcasting (no anonymous subclass, so nothing but the map itself is serialized).
     */
    private static Map<String, Object> tokenizerSettings(SparkConf conf) throws Exception {
        Map<String, Object> settings = new HashMap<String, Object>();
        settings.put("numWords", (Integer)Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.numwords", conf, Integer.class));
        settings.put("nGrams", (Integer)Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.ngrams", conf, Integer.class));
        settings.put("tokenizer", (String)Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.tokenizer", conf, String.class));
        settings.put("tokenPreprocessor", (String)Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.preprocessor", conf, String.class));
        settings.put("removeStop", (Boolean)Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.removestopwords", conf, Boolean.class));
        return settings;
    }

    /**
     * Caps every count in {@code coOccurrenceCounts} at {@code maxCount} (updating the map in
     * place) and returns one (first, second, count) triple per pair.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static List<Triple> cappedCoOccurrences(CounterMap coOccurrenceCounts, double maxCount) {
        List<Triple> counts = new ArrayList<Triple>();
        Iterator pairIterator = coOccurrenceCounts.getPairIterator();
        while (pairIterator.hasNext()) {
            Pair next = (Pair)pairIterator.next();
            Object first = next.getFirst();
            Object second = next.getSecond();
            if (coOccurrenceCounts.getCount(first, second) > maxCount) {
                coOccurrenceCounts.setCount(first, second, maxCount);
            }
            counts.add(new Triple(first, second, coOccurrenceCounts.getCount(first, second)));
        }
        return counts;
    }

    /**
     * Creates the per-pair training function: it computes the error term for one co-occurring
     * pair, updates both words in {@code table} via {@link #update}, and packages the steps and
     * the resulting AdaGrad history values as a {@link GloveChange}.
     */
    private Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange> gradientStep(final GloveWeightLookupTable table) {
        return new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>(){

            @Override
            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> coOccurrence) throws Exception {
                VocabWord w1 = coOccurrence._1();
                VocabWord w2 = coOccurrence._2()._1();
                INDArray w1Vector = table.getSyn0().slice(w1.getIndex());
                INDArray w2Vector = table.getSyn0().slice(w2.getIndex());
                INDArray bias = table.getBias();
                double score = coOccurrence._2()._2();
                double xMax = table.getxMax();
                double maxCount = table.getMaxCount();
                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                double weight = FastMath.pow(Math.min(1.0, score / maxCount), xMax);

                // NOTE(review): the count is compared against xMax (also used as the exponent
                // above), and the bias terms are only added in the weighted branch. Kept as
                // found - confirm against the intended GloVe cost function.
                double fDiff;
                if (score > xMax) {
                    fDiff = prediction;
                } else {
                    prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                    fDiff = weight * (prediction - Math.log(score));
                }
                if (Double.isNaN(fDiff)) {
                    fDiff = Nd4j.EPS_THRESHOLD;
                }

                // w1 is updated first, so the w2 step sees the already-updated w1 vector.
                Pair<INDArray, Double> w1Update = Glove.this.update(table.getWeightAdaGrad(), table.getBiasAdaGrad(), table.getSyn0(), table.getBias(), w1, w1Vector, w2Vector, fDiff);
                Pair<INDArray, Double> w2Update = Glove.this.update(table.getWeightAdaGrad(), table.getBiasAdaGrad(), table.getSyn0(), table.getBias(), w2, w2Vector, w1Vector, fDiff);

                INDArray weightHistory = table.getWeightAdaGrad().getHistoricalGradient();
                INDArray biasHistory = table.getBiasAdaGrad().getHistoricalGradient();
                // NOTE(review): the bias history is passed w2-then-w1 while everything else is
                // w1-then-w2. Order kept as found - verify against the GloveChange constructor.
                return new GloveChange(w1, w2,
                        w1Update.getFirst(), w2Update.getFirst(),
                        w1Update.getSecond(), w2Update.getSecond(),
                        fDiff,
                        weightHistory.slice(w1.getIndex()), weightHistory.slice(w2.getIndex()),
                        biasHistory.getDouble(w2.getIndex()), biasHistory.getDouble(w1.getIndex()));
            }
        };
    }
}

