package org.deeplearning4j.models.glove;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;

/* loaded from: input_file:org/deeplearning4j/models/glove/GloveWeightLookupTable.class */
public class GloveWeightLookupTable extends InMemoryLookupTable {
    // AdaGrad state used by update() to scale gradients applied to rows of syn0.
    private AdaGrad weightAdaGrad;
    // AdaGrad state used by update() to scale gradients applied to entries of bias.
    private AdaGrad biasAdaGrad;
    // Per-word bias terms; resetWeights(boolean) allocates one entry per row of syn0.
    private INDArray bias;
    // Used by iterateSample both as the threshold compared against the co-occurrence
    // value and as the exponent of the weighting term. The Builder default is 0.75.
    private double xMax;
    // Divisor applied to the co-occurrence value (capped at 1.0) inside the weighting
    // term of iterateSample. The Builder default is 100.0.
    private double maxCount;

    /* loaded from: input_file:org/deeplearning4j/models/glove/GloveWeightLookupTable$Builder.class */
    /**
     * Fluent builder for {@link GloveWeightLookupTable}.
     *
     * <p>Adds the two GloVe-specific settings ({@code xMax}, default 0.75, and
     * {@code maxCount}, default 100.0) to the options inherited from
     * {@link InMemoryLookupTable.Builder}. The inherited setters are overridden only
     * to narrow their return type so that calls can be chained without casting.
     */
    public static class Builder extends InMemoryLookupTable.Builder {
        private double xMax = 0.75d;
        private double maxCount = 100.0d;

        /**
         * Sets the {@code xMax} value handed to the table's constructor.
         *
         * @param d the new value
         * @return this builder
         */
        public Builder xMax(double d) {
            xMax = d;
            return this;
        }

        /**
         * Sets the {@code maxCount} value handed to the table's constructor.
         *
         * @param d the new value
         * @return this builder
         */
        public Builder maxCount(double d) {
            maxCount = d;
            return this;
        }

        /** Delegates to the parent builder and returns this builder for chaining. */
        @Override
        public Builder cache(VocabCache vocabCache) {
            super.cache(vocabCache);
            return this;
        }

        /** Delegates to the parent builder and returns this builder for chaining. */
        @Override
        public Builder vectorLength(int i) {
            super.vectorLength(i);
            return this;
        }

        /** Delegates to the parent builder and returns this builder for chaining. */
        @Override
        public Builder useAdaGrad(boolean z) {
            super.useAdaGrad(z);
            return this;
        }

        /** Delegates to the parent builder and returns this builder for chaining. */
        @Override
        public Builder lr(double d) {
            super.lr(d);
            return this;
        }

        /** Delegates to the parent builder and returns this builder for chaining. */
        @Override
        public Builder negative(double d) {
            super.negative(d);
            return this;
        }

        /** Delegates to the parent builder and returns this builder for chaining. */
        @Override
        public Builder gen(Random random) {
            super.gen(random);
            return this;
        }

        /** Delegates to the parent builder and returns this builder for chaining. */
        @Override
        public Builder seed(long j) {
            super.seed(j);
            return this;
        }

        /**
         * Creates the lookup table from the values collected so far.
         *
         * @return a new {@link GloveWeightLookupTable}
         */
        @Override
        public GloveWeightLookupTable build() {
            return new GloveWeightLookupTable(
                    vocabCache,
                    vectorLength,
                    useAdaGrad,
                    lr,
                    gen,
                    negative,
                    xMax,
                    maxCount);
        }
    }

    /**
     * Creates a GloVe lookup table. The first six arguments are forwarded unchanged to
     * the {@link InMemoryLookupTable} constructor; the descriptions below follow the
     * values that {@link Builder#build()} passes in each position.
     *
     * @param vocabCache vocabulary backing this table
     * @param i          vector length
     * @param z          the builder's {@code useAdaGrad} flag
     * @param d          learning rate
     * @param random     random number generator
     * @param d2         the builder's {@code negative} value
     * @param d3         value stored as {@code xMax}
     * @param d4         value stored as {@code maxCount}
     */
    public GloveWeightLookupTable(VocabCache vocabCache, int i, boolean z, double d, Random random, double d2, double d3, double d4) {
        super(vocabCache, i, z, d, random, d2);
        // The former assignments of the 0.75 / 100.0 defaults were dead stores: both
        // fields were overwritten by the constructor arguments immediately afterwards.
        this.xMax = d3;
        this.maxCount = d4;
    }

    /**
     * Initialises the weight matrix, the bias vector and both AdaGrad states.
     *
     * <p>Each of the four pieces is (re)created when it is still {@code null} or when
     * {@code z} is {@code true}; otherwise it is left untouched. A missing random
     * number generator is replaced by {@link Nd4j#getRandom()} first.
     *
     * @param z {@code true} to recreate everything, {@code false} to fill in only
     *          what has not been created yet
     */
    @Override
    public void resetWeights(boolean z) {
        if (rng == null) {
            rng = Nd4j.getRandom();
        }

        if (z || syn0 == null) {
            // One row per vocabulary word plus one extra row; values from Nd4j.rand are
            // shifted by -0.5 and divided by the vector length.
            int[] weightShape = {vocab.numWords() + 1, vectorLength};
            INDArray shifted = Nd4j.rand(weightShape, rng).subi(Double.valueOf(0.5d));
            syn0 = shifted.divi(Double.valueOf(vectorLength));

            INDArray unknown = Nd4j.rand(1, vectorLength, rng).subi(Double.valueOf(0.5d));
            putVector("UNK", unknown.divi(Integer.valueOf(vectorLength)));
        }

        if (z || weightAdaGrad == null) {
            int[] adaGradShape = {vocab.numWords() + 1, vectorLength};
            weightAdaGrad = new AdaGrad(adaGradShape);
            weightAdaGrad.setMasterStepSize(lr.get());
        }

        if (z || bias == null) {
            // One bias entry per row of the weight matrix.
            bias = Nd4j.create(syn0.rows());
        }

        if (z || biasAdaGrad == null) {
            biasAdaGrad = new AdaGrad(bias.shape());
            biasAdaGrad.setMasterStepSize(lr.get());
        }
    }

    /**
     * Recreates all weights, biases and AdaGrad state; equivalent to
     * {@code resetWeights(true)}.
     */
    @Override
    public void resetWeights() {
        this.resetWeights(true);
    }

    /**
     * Performs one training step for a pair of words and their co-occurrence value.
     *
     * <p>The prediction is the dot product of the two word vectors plus both bias
     * terms. From it an error term is derived, which is then applied to the vector and
     * bias of each word via {@code update} and finally returned.
     *
     * @param vocabWord  first word of the pair
     * @param vocabWord2 second word of the pair
     * @param d          co-occurrence value of the pair
     * @return the error term that was applied; {@code Nd4j.EPS_THRESHOLD} when the
     *         computed value is NaN
     * @throws IllegalArgumentException if either word's index is not a valid row of
     *                                  the weight matrix
     */
    public double iterateSample(VocabWord vocabWord, VocabWord vocabWord2, double d) {
        int firstIndex = vocabWord.getIndex();
        int secondIndex = vocabWord2.getIndex();
        int rows = this.syn0.rows();

        // Validate before touching syn0: the original sliced first, so an invalid index
        // failed inside slice() and these checks could never report it.
        if (firstIndex < 0 || firstIndex >= rows) {
            throw new IllegalArgumentException("Illegal index for word " + vocabWord.getWord());
        }
        if (secondIndex < 0 || secondIndex >= rows) {
            throw new IllegalArgumentException("Illegal index for word " + vocabWord2.getWord());
        }

        INDArray slice = this.syn0.slice(firstIndex);
        INDArray slice2 = this.syn0.slice(secondIndex);

        double dot = Nd4j.getBlasWrapper().dot(slice, slice2) + this.bias.getDouble(firstIndex) + this.bias.getDouble(secondIndex);
        // NOTE(review): the branch compares the co-occurrence value against xMax, which
        // is also used as the exponent below - confirm this is the intended weighting.
        double pow = d > this.xMax ? dot : Math.pow(Math.min(1.0d, d / this.maxCount), this.xMax) * (dot - Math.log(d));
        if (Double.isNaN(pow)) {
            pow = Nd4j.EPS_THRESHOLD;
        }

        // The second update reads the first word's slice after it was modified in place.
        update(vocabWord, slice, slice2, pow);
        update(vocabWord2, slice2, slice, pow);
        return pow;
    }

    /**
     * Applies one AdaGrad-scaled step to a word's vector and bias.
     *
     * @param vocabWord word whose index selects the AdaGrad column and the bias entry
     * @param iNDArray  the word's vector; modified in place
     * @param iNDArray2 the other word's vector; only read
     * @param d         error term multiplied into the gradient
     */
    private void update(VocabWord vocabWord, INDArray iNDArray, INDArray iNDArray2, double d) {
        final int row = vocabWord.getIndex();

        // Weight step: the other vector scaled by the error, passed through AdaGrad.
        INDArray rawGradient = iNDArray2.mul(Double.valueOf(d));
        INDArray weightStep = weightAdaGrad.getGradient(rawGradient, row, syn0.shape());
        iNDArray.subi(weightStep);

        // Bias step: the error itself, passed through the bias AdaGrad.
        double currentBias = bias.getDouble(row);
        double biasStep = biasAdaGrad.getGradient(d, row, bias.shape());
        bias.putScalar(row, currentBias - biasStep);
    }

    /**
     * @return the AdaGrad state used for the weight matrix; {@code null} until
     *         {@code resetWeights} has created it
     */
    public AdaGrad getWeightAdaGrad() {
        return weightAdaGrad;
    }

    /**
     * @return the AdaGrad state used for the bias vector; {@code null} until
     *         {@code resetWeights} has created it
     */
    public AdaGrad getBiasAdaGrad() {
        return biasAdaGrad;
    }

    /**
     * Loads word vectors from a text stream with one word per line, formatted as
     * {@code word v1 v2 ... vn} and separated by single spaces.
     *
     * <p>The vector length is taken from the first non-blank line. Blank lines are
     * skipped. After reading, the vectors are arranged into the weight matrix by
     * vocabulary index and {@code resetWeights(false)} creates the remaining state.
     * The stream is always closed, including when parsing fails.
     *
     * @param inputStream UTF-8 encoded text to read; closed by this method
     * @param vocabCache  vocabulary used to map words to row indices
     * @return the populated lookup table
     * @throws IOException              if reading fails or the stream contains no
     *                                  non-blank line
     * @throws IllegalArgumentException if a line cannot be parsed as float values
     */
    public static GloveWeightLookupTable load(InputStream inputStream, VocabCache vocabCache) throws IOException {
        LineIterator lineIterator = IOUtils.lineIterator(inputStream, "UTF-8");
        try {
            GloveWeightLookupTable table = null;
            Map<String, float[]> vectors = new HashMap<String, float[]>();
            while (lineIterator.hasNext()) {
                String line = lineIterator.nextLine().trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] tokens = line.split(" ");
                String word = tokens[0];
                if (table == null) {
                    // The first data line fixes the dimensionality for the whole file.
                    table = new Builder().cache(vocabCache).vectorLength(tokens.length - 1).build();
                }
                if (!word.isEmpty()) {
                    float[] values = read(tokens, table.getVectorLength());
                    if (values.length >= 1) {
                        vectors.put(word, values);
                    }
                }
            }
            if (table == null) {
                // Previously this case surfaced as a NullPointerException further down.
                throw new IOException("No word vectors found in input stream");
            }
            table.setSyn0(weights(table, vectors, vocabCache));
            table.resetWeights(false);
            return table;
        } finally {
            // Release the underlying reader on both the success and the failure path.
            lineIterator.close();
        }
    }

    /**
     * Builds the weight matrix from parsed word vectors.
     *
     * <p>The matrix has one row per entry of {@code map}. A vector is copied into the
     * row given by the word's vocabulary index when its length equals the table's
     * vector length and the index lies in {@code [0, map.size())}; all other entries
     * are skipped and their rows keep the values {@code Nd4j.create} produced.
     *
     * @param gloveWeightLookupTable table supplying the expected vector length
     * @param map                    word to vector values
     * @param vocabCache             vocabulary supplying each word's row index
     * @return the assembled matrix
     */
    private static INDArray weights(GloveWeightLookupTable gloveWeightLookupTable, Map<String, float[]> map, VocabCache vocabCache) {
        int vectorLength = gloveWeightLookupTable.getVectorLength();
        int rows = map.size();
        INDArray create = Nd4j.create(rows, vectorLength);
        // Iterate entries directly instead of keySet() + get(), and look the index up
        // once per word rather than three times.
        for (Map.Entry<String, float[]> entry : map.entrySet()) {
            INDArray vector = Nd4j.create(Nd4j.createBuffer(entry.getValue()));
            if (vector.length() != vectorLength) {
                continue;
            }
            int index = vocabCache.indexOf(entry.getKey());
            if (index >= 0 && index < rows) {
                create.putRow(index, vector);
            }
        }
        return create;
    }

    /**
     * Parses the numeric components of one tokenised line.
     *
     * <p>{@code strArr[0]} is the word and is skipped; every following token is parsed
     * with {@link Float#parseFloat(String)}. When the line has fewer values than
     * {@code i}, the remaining entries stay {@code 0.0f}.
     *
     * @param strArr tokens of the line, word first
     * @param i      expected vector length; also the length of the returned array
     * @return an array of length {@code i} holding the parsed values
     * @throws IllegalArgumentException if the line holds more than {@code i} values
     * @throws NumberFormatException    if a token is not a valid float
     */
    private static float[] read(String[] strArr, int i) {
        int valueCount = strArr.length - 1;
        if (valueCount > i) {
            // Previously this overflowed the array with an ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Expected at most " + i + " vector components for word '" + strArr[0] + "' but found " + valueCount);
        }
        float[] fArr = new float[i];
        for (int i2 = 1; i2 < strArr.length; i2++) {
            fArr[i2 - 1] = Float.parseFloat(strArr[i2]);
        }
        return fArr;
    }

    /**
     * Not supported by this table; always throws.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void iterateSample(VocabWord vocabWord, VocabWord vocabWord2, AtomicLong atomicLong, double d) {
        // A message makes the failure self-explanatory in stack traces.
        throw new UnsupportedOperationException("Not supported by GloveWeightLookupTable; use iterateSample(VocabWord, VocabWord, double)");
    }

    /**
     * @return the current {@code xMax} value
     */
    public double getxMax() {
        return xMax;
    }

    /**
     * Replaces the {@code xMax} value.
     *
     * @param d the new value
     */
    public void setxMax(double d) {
        xMax = d;
    }

    /**
     * @return the current {@code maxCount} value
     */
    public double getMaxCount() {
        return maxCount;
    }

    /**
     * Replaces the {@code maxCount} value.
     *
     * @param d the new value
     */
    public void setMaxCount(double d) {
        maxCount = d;
    }

    /**
     * @return the bias array itself (not a copy); {@code null} until it has been
     *         created by {@code resetWeights} or assigned via {@code setBias}
     */
    public INDArray getBias() {
        return bias;
    }

    /**
     * Replaces the bias array; the given instance is stored as-is, without copying.
     *
     * @param iNDArray the new bias array
     */
    public void setBias(INDArray iNDArray) {
        bias = iNDArray;
    }
}
