package org.deeplearning4j.models.embeddings.reader.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.util.SetUtils;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.class */
/**
 * {@link BasicModelUtils} variant that answers nearest-neighbour queries with a {@link VPTree}
 * built lazily over every word vector in the vocabulary.
 *
 * @param <T> vocabulary element type
 */
public class TreeModelUtils<T extends SequenceElement> extends BasicModelUtils<T> {
    /** Index over all vocabulary vectors; built on first query, reset to {@code null} by {@link #init}. */
    protected VPTree vpTree;

    /**
     * Initializes the superclass with the given lookup table and drops any previously built tree,
     * so the tree is rebuilt against the new table on the next query.
     *
     * @param weightLookupTable table to read word vectors from; must not be {@code null}
     * @throws NullPointerException if {@code weightLookupTable} is {@code null}
     */
    @Override
    public void init(@NonNull WeightLookupTable<T> weightLookupTable) {
        if (weightLookupTable == null) {
            throw new NullPointerException("lookupTable is marked @NonNull but is null");
        }
        super.init(weightLookupTable);
        this.vpTree = null;
    }

    /**
     * Builds {@link #vpTree} from every word in the vocabulary if it has not been built yet.
     * Each {@link DataPoint} stores the word's vocabulary index so that search hits can be mapped
     * back to words with {@code wordAtIndex}.
     */
    protected synchronized void checkTree() {
        if (vpTree != null) {
            return;
        }
        List<DataPoint> points = new ArrayList<>();
        for (String word : vocabCache.words()) {
            points.add(new DataPoint(vocabCache.indexOf(word), lookupTable.vector(word)));
        }
        vpTree = new VPTree(points);
    }

    /**
     * Returns up to {@code top} words nearest to {@code word}, excluding {@code word} itself.
     *
     * @param word query word
     * @param top maximum number of neighbours to return
     * @return nearest words, or an empty list if {@code word} is not in the vocabulary
     */
    @Override
    public Collection<String> wordsNearest(String word, int top) {
        if (!vocabCache.hasToken(word)) {
            return new ArrayList<>();
        }
        // Ask for one extra neighbour to compensate for removing the query word itself.
        List<String> nearest =
                new ArrayList<>(wordsNearest(Arrays.asList(word), new ArrayList<String>(), top + 1));
        nearest.remove(word); // no-op if the query word was not among the hits
        // If the query word was not returned, the extra neighbour must not leak to the caller.
        if (top >= 0 && nearest.size() > top) {
            return new ArrayList<>(nearest.subList(0, top));
        }
        return nearest;
    }

    /**
     * Returns up to {@code top} words nearest to the mean of the {@code positive} word vectors and
     * the negated {@code negative} word vectors.
     *
     * @param positive words whose vectors are added
     * @param negative words whose vectors are subtracted (negated before averaging)
     * @param top maximum number of neighbours to return
     * @return nearest words; an empty list if any input word is missing from the vocabulary or if
     *     both collections are empty
     */
    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        if (!allInVocab(positive) || !allInVocab(negative)) {
            return new ArrayList<>();
        }
        int rows = positive.size() + negative.size();
        if (rows == 0) {
            // Nothing to average; avoid building a zero-row array and taking its mean.
            return new ArrayList<>();
        }
        INDArray vectors = Nd4j.create(rows, lookupTable.layerSize());
        int row = 0;
        for (String word : positive) {
            vectors.putRow(row++, lookupTable.vector(word));
        }
        for (String word : negative) {
            vectors.putRow(row++, lookupTable.vector(word).mul(-1));
        }
        INDArray query = vectors.isMatrix() ? vectors.mean(0) : vectors;
        return wordsNearest(query, top);
    }

    /**
     * Returns the vocabulary words whose vectors the VP-tree search reports as nearest to
     * {@code words}, building the tree first if necessary.
     *
     * @param words query vector; passed through {@code adjustRank} before searching
     * @param top maximum number of neighbours requested from the tree
     * @return the hit words, in the order the tree search returned them
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        checkTree();
        INDArray query = adjustRank(words);
        List<DataPoint> hits = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        vpTree.search(query, top, hits, distances);
        List<String> result = new ArrayList<>(hits.size());
        for (DataPoint hit : hits) {
            result.add(vocabCache.wordAtIndex(hit.getIndex()));
        }
        // Return the tree's answer; the brute-force superclass lookup is deliberately not used,
        // otherwise building and searching the tree would be wasted work.
        return result;
    }

    /** Returns {@code true} if every word in {@code words} is present in the vocabulary. */
    private boolean allInVocab(Collection<String> words) {
        for (String word : words) {
            if (!vocabCache.containsWord(word)) {
                return false;
            }
        }
        return true;
    }
}
