package org.deeplearning4j.scaleout.perform.models.word2vec;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/scaleout/perform/models/word2vec/Word2VecWork.class */
/**
 * Serializable unit of word2vec work: a batch of sentences plus private copies of every
 * lookup-table row those sentences touch.
 *
 * <p>For each word (and each word resolved from a word's {@code getPoints()} indexes) two
 * copies of its syn0 row, syn1 slice and, when the table has one, syn1Neg slice are kept:
 * a working copy exposed through the getters (presumably updated by whoever processes this
 * work item) and an untouched snapshot. {@link #addDeltas()} reports working-minus-snapshot.
 */
public class Word2VecWork implements Serializable {
    /** Sentences to process; held by reference, not copied. */
    private List<List<VocabWord>> sentences;
    /** Working syn0 rows, keyed by word string. */
    private Map<String, Pair<VocabWord, INDArray>> vectors = new ConcurrentHashMap<>();
    /** Working syn1Neg slices, keyed by word; empty when the table has no syn1Neg. */
    private Map<String, Pair<VocabWord, INDArray>> negativeVectors = new ConcurrentHashMap<>();
    /** Every word captured by this work item, keyed by its vocabulary index. */
    private Map<Integer, VocabWord> indexes = new ConcurrentHashMap<>();
    /** Snapshot of each syn0 row taken at construction time. */
    private Map<String, INDArray> originalVectors = new ConcurrentHashMap<>();
    /** Snapshot of each syn1 slice taken at construction time. */
    private Map<String, INDArray> originalSyn1Vectors = new ConcurrentHashMap<>();
    /** Snapshot of each syn1Neg slice taken at construction time. */
    private Map<String, INDArray> originalNegative = new ConcurrentHashMap<>();
    /** Working syn1 slices, keyed by word string. */
    private Map<String, INDArray> syn1Vectors = new ConcurrentHashMap<>();

    /**
     * Builds a work item for the given sentences, copying every table row they need.
     *
     * @param inMemoryLookupTable table the syn0 / syn1 / syn1Neg rows are copied from
     * @param inMemoryLookupCache vocabulary used to turn each word's point indexes into words
     * @param list the sentences to process; retained by reference
     * @throws IllegalArgumentException if a sentence contains a null word, or a point index
     *     resolves to a null word
     */
    public Word2VecWork(InMemoryLookupTable inMemoryLookupTable, InMemoryLookupCache inMemoryLookupCache, List<List<VocabWord>> list) {
        this.sentences = list;
        for (List<VocabWord> sentence : list) {
            for (VocabWord vocabWord : sentence) {
                addWord(vocabWord, inMemoryLookupTable);
                if (vocabWord.getPoints() != null) {
                    // Also capture the rows of every word referenced by this word's points,
                    // so the work item is self-contained.
                    for (int i = 0; i < vocabWord.getCodeLength(); i++) {
                        int pointIndex = vocabWord.getPoints().get(i).intValue();
                        addWord(inMemoryLookupCache.wordFor(inMemoryLookupCache.wordAtIndex(pointIndex)), inMemoryLookupTable);
                    }
                }
            }
        }
    }

    /**
     * Records {@code vocabWord} and stores two independent copies (working + snapshot) of its
     * syn0 row, syn1 slice and, if present, syn1Neg slice.
     *
     * @throws IllegalArgumentException if {@code vocabWord} is null
     */
    private void addWord(VocabWord vocabWord, InMemoryLookupTable inMemoryLookupTable) {
        if (vocabWord == null) {
            throw new IllegalArgumentException("Word must not be null!");
        }
        String word = vocabWord.getWord();
        int index = vocabWord.getIndex();
        this.indexes.put(Integer.valueOf(index), vocabWord);

        // dup() twice: the working copy and the snapshot must not share storage.
        INDArray syn0Row = inMemoryLookupTable.getSyn0().getRow(index);
        this.vectors.put(word, new Pair<>(vocabWord, syn0Row.dup()));
        this.originalVectors.put(word, syn0Row.dup());

        INDArray syn1Slice = inMemoryLookupTable.getSyn1().slice(index);
        this.syn1Vectors.put(word, syn1Slice.dup());
        this.originalSyn1Vectors.put(word, syn1Slice.dup());

        INDArray syn1Neg = inMemoryLookupTable.getSyn1Neg();
        if (syn1Neg != null) {
            INDArray negativeSlice = syn1Neg.slice(index);
            this.originalNegative.put(word, negativeSlice.dup());
            this.negativeVectors.put(word, new Pair<>(vocabWord, negativeSlice.dup()));
        }
    }

    /**
     * Computes, for every word occurring in the sentences, the difference between its working
     * copies and the snapshots taken at construction time.
     *
     * <p>The working copies are not modified, so this method may be called repeatedly and
     * words may occur several times in the sentences.
     *
     * @return syn0, syn1 and syn1Neg deltas keyed by word; the syn1Neg map is empty when no
     *     syn1Neg rows were captured
     */
    public Word2VecResult addDeltas() {
        Map<String, INDArray> syn0Deltas = new HashMap<>();
        Map<String, INDArray> syn1Deltas = new HashMap<>();
        Map<String, INDArray> negativeDeltas = new HashMap<>();
        boolean hasNegative = !this.negativeVectors.isEmpty();
        for (List<VocabWord> sentence : this.sentences) {
            for (VocabWord vocabWord : sentence) {
                String word = vocabWord.getWord();
                syn0Deltas.put(word, this.vectors.get(word).getSecond().sub(this.originalVectors.get(word)));
                syn1Deltas.put(word, this.syn1Vectors.get(word).sub(this.originalSyn1Vectors.get(word)));
                if (hasNegative) {
                    // Out-of-place sub (previously subi): subtracting in place overwrote the
                    // working vector, so a word seen k times yielded current - k * original.
                    negativeDeltas.put(word, this.negativeVectors.get(word).getSecond().sub(this.originalNegative.get(word)));
                }
            }
        }
        return new Word2VecResult(syn0Deltas, syn1Deltas, negativeDeltas);
    }

    /** @return the sentences of this work item (live reference) */
    public List<List<VocabWord>> getSentences() {
        return this.sentences;
    }

    /** @param list replacement sentences, stored by reference */
    public void setSentences(List<List<VocabWord>> list) {
        this.sentences = list;
    }

    /** @return working syn1Neg slices keyed by word (live map) */
    public Map<String, Pair<VocabWord, INDArray>> getNegativeVectors() {
        return this.negativeVectors;
    }

    /** @param map replacement working syn1Neg map, stored by reference */
    public void setNegativeVectors(Map<String, Pair<VocabWord, INDArray>> map) {
        this.negativeVectors = map;
    }

    /** @return working syn0 rows keyed by word (live map) */
    public Map<String, Pair<VocabWord, INDArray>> getVectors() {
        return this.vectors;
    }

    /** @param map replacement working syn0 map, stored by reference */
    public void setVectors(Map<String, Pair<VocabWord, INDArray>> map) {
        this.vectors = map;
    }

    /** @return captured words keyed by vocabulary index (live map) */
    public Map<Integer, VocabWord> getIndexes() {
        return this.indexes;
    }

    /** @param map replacement index map, stored by reference */
    public void setIndexes(Map<Integer, VocabWord> map) {
        this.indexes = map;
    }

    /** @return syn0 snapshots keyed by word (live map) */
    public Map<String, INDArray> getOriginalVectors() {
        return this.originalVectors;
    }

    /** @param map replacement syn0 snapshot map, stored by reference */
    public void setOriginalVectors(Map<String, INDArray> map) {
        this.originalVectors = map;
    }

    /** @return syn1 snapshots keyed by word (live map) */
    public Map<String, INDArray> getOriginalSyn1Vectors() {
        return this.originalSyn1Vectors;
    }

    /** @param map replacement syn1 snapshot map, stored by reference */
    public void setOriginalSyn1Vectors(Map<String, INDArray> map) {
        this.originalSyn1Vectors = map;
    }

    /** @return syn1Neg snapshots keyed by word (live map) */
    public Map<String, INDArray> getOriginalNegative() {
        return this.originalNegative;
    }

    /** @param map replacement syn1Neg snapshot map, stored by reference */
    public void setOriginalNegative(Map<String, INDArray> map) {
        this.originalNegative = map;
    }

    /** @return working syn1 slices keyed by word (live map) */
    public Map<String, INDArray> getSyn1Vectors() {
        return this.syn1Vectors;
    }

    /** @param map replacement working syn1 map, stored by reference */
    public void setSyn1Vectors(Map<String, INDArray> map) {
        this.syn1Vectors = map;
    }
}
