/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import scala.Tuple2;

/**
 * Spark {@link FlatMapFunction} that runs the first training pass of skip-gram word2vec with
 * hierarchical softmax over one partition of sentences.
 *
 * <p>Each input element pairs a sentence (a list of {@link VocabWord}s, possibly containing
 * {@code null} entries) with a cumulative word count. That count drives linear learning-rate
 * decay from {@code alpha} down to {@code minAlpha}.
 *
 * <p>Input (syn0) vectors are created lazily per word index, using a per-word seed, and are
 * updated in place. Output (syn1) vectors are created lazily per Huffman point, start at zero,
 * and are also updated in place. The function emits the syn0 vectors it trained, keyed by word
 * index.
 *
 * <p>State is held in unsynchronized {@link HashMap}s, so one instance should only process one
 * partition at a time.
 */
public class FirstIterationFunction
implements FlatMapFunction<Iterator<Tuple2<List<VocabWord>, Long>>, Map.Entry<Integer, INDArray>> {
    /** Not read by this class. */
    private int ithIteration = 1;
    private int vectorLength;
    private boolean useAdaGrad;
    private int negative;
    private int window;
    private double alpha;
    private double minAlpha;
    private long totalWordCount;
    private long seed;
    private int maxExp;
    /** Precomputed sigmoid table indexed over [-maxExp, maxExp). */
    private double[] expTable;
    /** Input (syn0) vectors, keyed by word index; these are the emitted results. */
    private Map<Integer, INDArray> indexSyn0VecMap;
    /** Output (syn1) vectors, keyed by Huffman tree point. */
    private Map<Integer, INDArray> pointSyn1VecMap;
    /** State of the word2vec linear congruential generator used for window shrinking. */
    private AtomicLong nextRandom = new AtomicLong(5L);

    /**
     * Reads the hyper-parameters from the broadcast configuration map.
     *
     * @param word2vecVarMapBroadcast broadcast map with keys {@code vectorLength},
     *     {@code useAdaGrad}, {@code negative}, {@code window}, {@code alpha}, {@code minAlpha},
     *     {@code totalWordCount}, {@code seed} and {@code maxExp}, holding values of the boxed
     *     types cast below
     * @param expTableBroadcast broadcast precomputed sigmoid lookup table
     */
    public FirstIterationFunction(Broadcast<Map<String, Object>> word2vecVarMapBroadcast, Broadcast<double[]> expTableBroadcast) {
        Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
        this.expTable = expTableBroadcast.getValue();
        this.vectorLength = (Integer) word2vecVarMap.get("vectorLength");
        this.useAdaGrad = (Boolean) word2vecVarMap.get("useAdaGrad");
        this.negative = (Integer) word2vecVarMap.get("negative");
        this.window = (Integer) word2vecVarMap.get("window");
        this.alpha = (Double) word2vecVarMap.get("alpha");
        this.minAlpha = (Double) word2vecVarMap.get("minAlpha");
        this.totalWordCount = (Long) word2vecVarMap.get("totalWordCount");
        this.seed = (Long) word2vecVarMap.get("seed");
        this.maxExp = (Integer) word2vecVarMap.get("maxExp");
        this.indexSyn0VecMap = new HashMap<Integer, INDArray>();
        this.pointSyn1VecMap = new HashMap<Integer, INDArray>();
    }

    /**
     * Trains on every sentence of the partition and returns the resulting syn0 vectors.
     *
     * @param pairIter (sentence, cumulative word count) pairs for this partition
     * @return live view of the syn0 map entries, keyed by word index
     */
    public Iterable<Map.Entry<Integer, INDArray>> call(Iterator<Tuple2<List<VocabWord>, Long>> pairIter) {
        while (pairIter.hasNext()) {
            Tuple2<List<VocabWord>, Long> pair = pairIter.next();
            List<VocabWord> vocabWordsList = pair._1();
            long sentenceCumSumCount = pair._2();
            // Fraction of the corpus already consumed. The guard keeps a zero total count from
            // producing NaN (0/0), which would poison every gradient that follows.
            double progress = this.totalWordCount > 0
                    ? (double) sentenceCumSumCount / (double) this.totalWordCount
                    : 0.0;
            double currentSentenceAlpha = Math.max(this.minAlpha, this.alpha - (this.alpha - this.minAlpha) * progress);
            this.trainSentence(vocabWordsList, currentSentenceAlpha);
        }
        return this.indexSyn0VecMap.entrySet();
    }

    /**
     * Runs skip-gram over every non-null word of a sentence.
     *
     * @param vocabWordsList the sentence; {@code null} or empty lists are ignored
     * @param currentSentenceAlpha learning rate for this sentence
     */
    public void trainSentence(List<VocabWord> vocabWordsList, double currentSentenceAlpha) {
        if (vocabWordsList == null || vocabWordsList.isEmpty()) {
            return;
        }
        for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ++ithWordInSentence) {
            // Advance the LCG once per position, even for null words, so the shrink sequence
            // does not depend on which words are present.
            this.nextRandom.set(this.nextRandom.get() * 25214903917L + 11L);
            int b = this.randomWindowShrink();
            if (vocabWordsList.get(ithWordInSentence) == null) {
                continue;
            }
            this.skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha);
        }
    }

    /**
     * Returns the random window shrink {@code b} in {@code [0, window)}.
     *
     * <p>The LCG state is treated as unsigned, as in the reference C implementation. A signed
     * {@code %} on the (possibly negative) state would give a negative shrink and widen the
     * window instead.
     */
    private int randomWindowShrink() {
        if (this.window <= 0) {
            return 0;
        }
        return (int) Long.remainderUnsigned(this.nextRandom.get(), this.window);
    }

    /**
     * Trains the word at {@code ithWordInSentence} against each context word inside its
     * (randomly shrunk) window.
     *
     * @param ithWordInSentence position of the current word
     * @param vocabWordsList the sentence
     * @param b window shrink; positions {@code b .. 2*window-b} around the word are visited
     * @param currentSentenceAlpha learning rate for this sentence
     */
    public void skipGram(int ithWordInSentence, List<VocabWord> vocabWordsList, int b, double currentSentenceAlpha) {
        VocabWord currentWord = vocabWordsList.get(ithWordInSentence);
        if (currentWord == null) {
            return;
        }
        int end = this.window * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == this.window) {
                continue; // centre of the window: the current word itself
            }
            int c = ithWordInSentence - this.window + a;
            if (c < 0 || c >= vocabWordsList.size()) {
                continue;
            }
            this.iterateSample(currentWord, vocabWordsList.get(c), currentSentenceAlpha);
        }
    }

    /**
     * Performs one hierarchical-softmax update for a (current word, context word) pair.
     *
     * <p>As in reference skip-gram, the context word's syn0 vector is trained to predict the
     * current word by walking the current word's Huffman codes and points. The syn0 vector is
     * created on first use and then reused, so updates accumulate across samples.
     *
     * @param currentWord word whose Huffman codes and points define the targets
     * @param w2 context word whose syn0 vector is updated; ignored if {@code null}, unindexed,
     *     or the same word as {@code currentWord}
     * @param currentSentenceAlpha learning rate for this sentence
     */
    public void iterateSample(VocabWord currentWord, VocabWord w2, double currentSentenceAlpha) {
        if (currentWord == null || w2 == null || w2.getIndex() < 0 || currentWord.getIndex() == w2.getIndex()) {
            return;
        }
        int contextWordIndex = w2.getIndex();
        INDArray syn0 = this.indexSyn0VecMap.get(contextWordIndex);
        if (syn0 == null) {
            // Seed by word index so a word's initial vector does not depend on processing order.
            syn0 = this.getRandomSyn0Vec(this.vectorLength, this.seed + contextWordIndex);
            this.indexSyn0VecMap.put(contextWordIndex, syn0);
        }
        // Accumulated error for syn0, applied once after all points are processed.
        INDArray neu1e = Nd4j.create(this.vectorLength);
        for (int i = 0; i < currentWord.getCodeLength(); ++i) {
            int code = (Integer) currentWord.getCodes().get(i);
            int point = (Integer) currentWord.getPoints().get(i);
            INDArray syn1 = this.pointSyn1VecMap.get(point);
            if (syn1 == null) {
                syn1 = Nd4j.zeros(1, this.vectorLength);
                this.pointSyn1VecMap.put(point, syn1);
            }
            double dot = Nd4j.getBlasWrapper().level1().dot(this.vectorLength, 1.0, syn0, syn1);
            if (dot < -this.maxExp || dot >= this.maxExp) {
                continue;
            }
            int idx = (int) ((dot + this.maxExp) * ((double) this.expTable.length / this.maxExp / 2.0));
            // Rounding when dot is just below maxExp can land exactly on the table length.
            if (idx >= this.expTable.length) {
                continue;
            }
            double f = this.expTable[idx];
            double g = ((1 - code) - f) * (this.useAdaGrad ? currentWord.getGradient(i, currentSentenceAlpha) : currentSentenceAlpha);
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, syn1, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, syn0, syn1);
        }
        // syn0 is the instance stored in the map, so this in-place update persists.
        Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, 1.0, neu1e, syn0);
    }

    /**
     * Returns a {@code 1 x vectorLength} vector drawn uniformly from
     * {@code [-0.5/vectorLength, 0.5/vectorLength)} using the configured seed.
     *
     * @param vectorLength vector width
     * @return the new random vector
     */
    public INDArray getRandomSyn0Vec(int vectorLength) {
        return this.getRandomSyn0Vec(vectorLength, this.seed);
    }

    /**
     * Returns a {@code 1 x vectorLength} vector drawn uniformly from
     * {@code [-0.5/vectorLength, 0.5/vectorLength)} using an explicit seed.
     *
     * @param vectorLength vector width
     * @param wordSeed RNG seed; equal seeds give equal vectors
     * @return the new random vector
     */
    public INDArray getRandomSyn0Vec(int vectorLength, long wordSeed) {
        return Nd4j.rand(wordSeed, new int[]{1, vectorLength}).subi(0.5).divi(vectorLength);
    }
}

