/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.perform.models.glove;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.scaleout.perform.models.glove.GloveResult;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaGrad;

/**
 * A unit of GloVe training work: a batch of word co-occurrence pairs plus local copies of
 * the per-word state (vector, bias, AdaGrad histories) pulled from a
 * {@link GloveWeightLookupTable}.
 *
 * <p>For every word referenced by the pairs, the constructor copies that word's row of
 * {@code syn0} twice: once as a working vector (exposed via {@link #getVectors()}) and once
 * as an untouched baseline (exposed via {@link #getOriginalVectors()}). {@link #addDeltas()}
 * reports the difference between the two. NOTE(review): the working vectors are presumably
 * updated in place by the trainer through {@link #getVectors()}; that code is not shown here.
 *
 * <p>All per-word maps are keyed by the word string ({@link VocabWord#getWord()}), except
 * {@link #getIndexes()}, which is keyed by the word's table index. The maps default to
 * {@link ConcurrentHashMap}s and are exposed directly (not copied) by the getters/setters.
 */
public class GloveWork implements Serializable {
    /** Word string -> (word, working copy of its syn0 row). */
    private Map<String, Pair<VocabWord, INDArray>> vectors = new ConcurrentHashMap<String, Pair<VocabWord, INDArray>>();
    /** The pairs this work unit covers; held by reference, not copied. */
    private List<Pair<VocabWord, VocabWord>> coOccurrences;
    /** Table index -> word, for every word referenced by the pairs. */
    private Map<Integer, VocabWord> indexes = new ConcurrentHashMap<Integer, VocabWord>();
    /** Word string -> syn0 row as it was when this work unit was built (baseline for deltas). */
    private Map<String, INDArray> originalVectors = new ConcurrentHashMap<String, INDArray>();
    /** Word string -> current bias value. */
    private Map<String, Double> biases = new ConcurrentHashMap<String, Double>();
    /** Word string -> AdaGrad subset for the word's weight row. */
    private Map<String, AdaGrad> adaGrads = new ConcurrentHashMap<String, AdaGrad>();
    /** Word string -> AdaGrad subset for the word's bias. */
    private Map<String, AdaGrad> biasAdaGrads = new ConcurrentHashMap<String, AdaGrad>();

    /**
     * Builds a work unit for the given co-occurrence pairs, copying each referenced word's
     * current vector, bias and AdaGrad state out of {@code table}.
     *
     * @param table         lookup table to copy per-word state from (only dereferenced when
     *                      {@code coOccurrences} is non-empty)
     * @param coOccurrences (first, second) word pairs to train on; kept by reference
     * @throws IllegalArgumentException if {@code coOccurrences} is null or any pair contains
     *                                  a null word
     */
    public GloveWork(GloveWeightLookupTable table, List<Pair<VocabWord, VocabWord>> coOccurrences) {
        if (coOccurrences == null) {
            throw new IllegalArgumentException("Co-occurrences must not be null!");
        }
        this.coOccurrences = coOccurrences;
        for (Pair<VocabWord, VocabWord> coOccurrence : coOccurrences) {
            // addWord both validates the word and records it in `indexes`; indexing here
            // first would dereference a null word before addWord could reject it.
            this.addWord(coOccurrence.getFirst(), table);
            this.addWord(coOccurrence.getSecond(), table);
        }
    }

    /**
     * Copies the state of one word out of {@code table} into this work unit. Adding a word
     * that is already present re-copies (overwrites) its state.
     *
     * @throws IllegalArgumentException if {@code word} is null
     */
    private void addWord(VocabWord word, GloveWeightLookupTable table) {
        if (word == null) {
            throw new IllegalArgumentException("Word must not be null!");
        }
        int index = word.getIndex();
        String key = word.getWord();
        INDArray row = table.getSyn0().getRow(index);

        this.indexes.put(index, word);
        // Two independent copies: a working vector and a baseline used by addDeltas().
        this.vectors.put(key, new Pair<VocabWord, INDArray>(word, row.dup()));
        this.originalVectors.put(key, row.dup());
        this.biases.put(key, table.getBias().getDouble(index));
        this.adaGrads.put(key, table.getWeightAdaGrad().createSubset(index));
        this.biasAdaGrads.put(key, table.getBiasAdaGrad().createSubset(index));
    }

    /**
     * Returns the bias AdaGrad subset for {@code word}, or {@code null} if the word is not
     * part of this work unit.
     */
    public AdaGrad getBiasAdaGrad(String word) {
        return this.biasAdaGrads.get(word);
    }

    /**
     * Returns the weight AdaGrad subset for {@code word}, or {@code null} if the word is not
     * part of this work unit.
     */
    public AdaGrad getAdaGrad(String word) {
        return this.adaGrads.get(word);
    }

    /** Stores {@code bias} as the current bias of {@code word}, replacing any previous value. */
    public void updateBias(String word, double bias) {
        this.biases.put(word, bias);
    }

    /**
     * Computes, for every word referenced by the current co-occurrence pairs, how far its
     * working vector has moved from the baseline copied at construction time
     * (working vector minus original vector). Each distinct word appears once in the result.
     *
     * @return a {@link GloveResult} wrapping a map from word string to its delta
     * @throws NullPointerException if a pair references a word that has no entry in
     *                              {@link #getVectors()} (possible after
     *                              {@link #setCoOccurrences} or {@link #setVectors})
     */
    public GloveResult addDeltas() {
        HashMap<String, INDArray> syn0Change = new HashMap<String, INDArray>();
        for (Pair<VocabWord, VocabWord> pair : this.coOccurrences) {
            this.putDelta(syn0Change, pair.getFirst().getWord());
            this.putDelta(syn0Change, pair.getSecond().getWord());
        }
        return new GloveResult(syn0Change);
    }

    /**
     * Records the delta for {@code word} in {@code syn0Change} unless it is already there;
     * words recur across pairs and their delta does not change within one addDeltas() call.
     */
    private void putDelta(Map<String, INDArray> syn0Change, String word) {
        if (!syn0Change.containsKey(word)) {
            INDArray current = this.vectors.get(word).getSecond();
            syn0Change.put(word, current.sub(this.originalVectors.get(word)));
        }
    }

    /**
     * Returns the current bias of {@code word}.
     *
     * @throws NullPointerException if no bias is recorded for {@code word}
     */
    public double getBias(String word) {
        Double bias = this.biases.get(word);
        if (bias == null) {
            throw new NullPointerException("No bias recorded for word '" + word + "'");
        }
        return bias;
    }

    /** Returns the co-occurrence pairs (the live list, not a copy). */
    public List<Pair<VocabWord, VocabWord>> getCoOccurrences() {
        return this.coOccurrences;
    }

    /**
     * Replaces the co-occurrence pairs. Per-word state is NOT rebuilt, so pairs referencing
     * words absent from this work unit will make {@link #addDeltas()} fail.
     */
    public void setCoOccurrences(List<Pair<VocabWord, VocabWord>> coOccurrences) {
        this.coOccurrences = coOccurrences;
    }

    /** Returns the live map of word string to (word, working vector). */
    public Map<String, Pair<VocabWord, INDArray>> getVectors() {
        return this.vectors;
    }

    /** Replaces the map of word string to (word, working vector). */
    public void setVectors(Map<String, Pair<VocabWord, INDArray>> vectors) {
        this.vectors = vectors;
    }

    /** Returns the live map of table index to word. */
    public Map<Integer, VocabWord> getIndexes() {
        return this.indexes;
    }

    /** Replaces the map of table index to word. */
    public void setIndexes(Map<Integer, VocabWord> indexes) {
        this.indexes = indexes;
    }

    /** Returns the live map of word string to its baseline (construction-time) vector. */
    public Map<String, INDArray> getOriginalVectors() {
        return this.originalVectors;
    }

    /** Replaces the map of word string to baseline vector. */
    public void setOriginalVectors(Map<String, INDArray> originalVectors) {
        this.originalVectors = originalVectors;
    }

    /** Returns the live map of word string to current bias. */
    public Map<String, Double> getBiases() {
        return this.biases;
    }

    /** Replaces the map of word string to current bias. */
    public void setBiases(Map<String, Double> biases) {
        this.biases = biases;
    }
}

