/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.RandomUtils;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CBOW<T extends SequenceElement>
implements ElementsLearningAlgorithm<T> {
    private VocabCache<T> vocabCache;
    private WeightLookupTable<T> lookupTable;
    private VectorsConfiguration configuration;
    private static final Logger logger = LoggerFactory.getLogger(CBOW.class);
    protected static double MAX_EXP = 6.0;
    protected int window;
    protected boolean useAdaGrad;
    protected double negative;
    protected double sampling;
    protected int[] variableWindows;
    protected double[] expTable;
    protected INDArray syn0;
    protected INDArray syn1;
    protected INDArray syn1Neg;
    protected INDArray table;

    @Override
    public String getCodeName() {
        return "CBOW";
    }

    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = lookupTable;
        this.configuration = configuration;
        this.window = configuration.getWindow();
        this.useAdaGrad = configuration.isUseAdaGrad();
        this.negative = configuration.getNegative();
        this.sampling = configuration.getSampling();
        this.syn0 = ((InMemoryLookupTable)lookupTable).getSyn0();
        this.syn1 = ((InMemoryLookupTable)lookupTable).getSyn1();
        this.syn1Neg = ((InMemoryLookupTable)lookupTable).getSyn1Neg();
        this.expTable = ((InMemoryLookupTable)lookupTable).getExpTable();
        this.table = ((InMemoryLookupTable)lookupTable).getTable();
        this.variableWindows = configuration.getVariableWindows();
    }

    @Override
    public void pretrain(SequenceIterator<T> iterator) {
    }

    @Override
    public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate) {
        Sequence<T> tempSequence = sequence;
        if (this.sampling > 0.0) {
            tempSequence = this.applySubsampling(sequence, nextRandom);
        }
        int currentWindow = this.window;
        if (this.variableWindows != null && this.variableWindows.length != 0) {
            currentWindow = this.variableWindows[RandomUtils.nextInt(this.variableWindows.length)];
        }
        for (int i = 0; i < tempSequence.getElements().size(); ++i) {
            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11L));
            this.cbow(i, tempSequence.getElements(), (int)nextRandom.get() % currentWindow, nextRandom, learningRate, currentWindow);
        }
        return 0.0;
    }

    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    public INDArray iterateSample(T currentWord, INDArray neu1, AtomicLong nextRandom, double alpha, boolean isInference) {
        INDArray neu1e;
        block7: {
            neu1e = Nd4j.zeros((int)this.lookupTable.layerSize());
            for (int p = 0; p < ((SequenceElement)currentWord).getCodeLength(); ++p) {
                int idx;
                double f = 0.0;
                int code = ((SequenceElement)currentWord).getCodes().get(p);
                int point = ((SequenceElement)currentWord).getPoints().get(p);
                INDArray syn1row = this.syn1.getRow(point);
                double dot = Nd4j.getBlasWrapper().dot(neu1, this.syn1.getRow(point));
                if (dot < -MAX_EXP || dot >= MAX_EXP || (idx = (int)((dot + MAX_EXP) * ((double)this.expTable.length / MAX_EXP / 2.0))) >= this.expTable.length) continue;
                f = this.expTable[idx];
                double g = this.useAdaGrad ? ((SequenceElement)currentWord).getGradient(p, (double)(1 - code) - f, alpha) : ((double)(1 - code) - f) * alpha;
                Nd4j.getBlasWrapper().level1().axpy(syn1row.length(), g, syn1row, neu1e);
                if (!isInference) {
                    Nd4j.getBlasWrapper().level1().axpy(syn1row.length(), g, neu1, syn1row);
                    continue;
                }
                Nd4j.getBlasWrapper().level1().axpy(syn1row.length(), g, neu1, syn1row.dup());
            }
            if (!(this.negative > 0.0) || isInference) break block7;
            int target = ((SequenceElement)currentWord).getIndex();
            int d = 0;
            while ((double)d < this.negative + 1.0) {
                block10: {
                    double g;
                    block12: {
                        double f;
                        int label;
                        block13: {
                            block11: {
                                block9: {
                                    block8: {
                                        if (d != 0) break block8;
                                        label = 1;
                                        break block9;
                                    }
                                    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11L));
                                    int idx = Math.abs((int)(nextRandom.get() >> 16) % this.table.length());
                                    target = this.table.getInt(new int[]{idx});
                                    if (target <= 0) {
                                        target = (int)nextRandom.get() % (this.vocabCache.numWords() - 1) + 1;
                                    }
                                    if (target == ((SequenceElement)currentWord).getIndex()) break block10;
                                    label = 0;
                                }
                                if (target >= this.syn1Neg.rows() || target < 0) break block10;
                                f = Nd4j.getBlasWrapper().dot(neu1, this.syn1Neg.slice(target));
                                if (!(f > MAX_EXP)) break block11;
                                g = this.useAdaGrad ? this.lookupTable.getGradient(target, label - 1) : (double)(label - 1) * alpha;
                                break block12;
                            }
                            if (!(f < -MAX_EXP)) break block13;
                            g = (double)label * (this.useAdaGrad ? this.lookupTable.getGradient(target, alpha) : alpha);
                            break block12;
                        }
                        int idx = (int)((f + MAX_EXP) * ((double)this.expTable.length / MAX_EXP / 2.0));
                        if (idx >= this.expTable.length) break block10;
                        g = this.useAdaGrad ? this.lookupTable.getGradient(target, (double)label - this.expTable[idx]) : ((double)label - this.expTable[idx]) * alpha;
                    }
                    Nd4j.getBlasWrapper().level1().axpy(this.lookupTable.layerSize(), g, this.syn1Neg.slice(target), neu1e);
                    Nd4j.getBlasWrapper().level1().axpy(this.lookupTable.layerSize(), g, neu1, this.syn1Neg.slice(target));
                }
                ++d;
            }
        }
        return neu1e;
    }

    public void cbow(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow) {
        int end = this.window * 2 + 1 - b;
        int cw = 0;
        INDArray neu1 = Nd4j.zeros((int)this.lookupTable.layerSize());
        SequenceElement currentWord = (SequenceElement)sentence.get(i);
        for (int a = b; a < end; ++a) {
            int c;
            if (a == currentWindow || (c = i - currentWindow + a) < 0 || c >= sentence.size()) continue;
            SequenceElement lastWord = (SequenceElement)sentence.get(c);
            neu1.addiRowVector(this.syn0.getRow(lastWord.getIndex()));
            ++cw;
        }
        if (cw == 0) {
            return;
        }
        neu1.divi((Number)cw);
        INDArray neu1e = this.iterateSample(currentWord, neu1, nextRandom, alpha, false);
        for (int a = b; a < end; ++a) {
            int c;
            if (a == this.window || (c = i - this.window + a) < 0 || c >= sentence.size()) continue;
            SequenceElement lastWord = (SequenceElement)sentence.get(c);
            INDArray syn0row = this.syn0.getRow(lastWord.getIndex());
            Nd4j.getBlasWrapper().level1().axpy(this.lookupTable.layerSize(), 1.0, neu1e, syn0row);
        }
    }

    public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom) {
        if (sequence == null) {
            throw new NullPointerException("sequence");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom");
        }
        Sequence<T> result = new Sequence<T>();
        if (this.sampling > 0.0) {
            result.setSequenceId(sequence.getSequenceId());
            if (sequence.getSequenceLabels() != null) {
                result.setSequenceLabels(sequence.getSequenceLabels());
            }
            if (sequence.getSequenceLabel() != null) {
                result.setSequenceLabel(sequence.getSequenceLabel());
            }
            for (SequenceElement element : sequence.getElements()) {
                double numWords = this.vocabCache.totalWordOccurrences();
                double ran = (Math.sqrt(element.getElementFrequency() / (this.sampling * numWords)) + 1.0) * (this.sampling * numWords) / element.getElementFrequency();
                nextRandom.set(nextRandom.get() * 25214903917L + 11L);
                if (ran < (double)(nextRandom.get() & 0xFFFFL) / 65536.0) continue;
                result.addElement(element);
            }
            return result;
        }
        return sequence;
    }
}

