package org.deeplearning4j.scaleout.perform.models.word2vec;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.perform.WorkerPerformer;
import org.deeplearning4j.scaleout.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/scaleout/perform/models/word2vec/Word2VecPerformer.class */
/**
 * {@link WorkerPerformer} that runs skip-gram word2vec updates over the sentences carried by a
 * {@link Word2VecWork} (or a {@link Collection} of them) and publishes the result of
 * {@code Word2VecWork.addDeltas()} as the job result.
 *
 * <p>Hyper-parameters are read in {@link #setup(Configuration)} from the keys declared below;
 * {@link #configure(InMemoryLookupTable, InvertedIndex, Configuration)} writes a subset of them.
 * The running word count is shared through the {@link StateTracker} counter
 * {@link #NUM_WORDS_SO_FAR} and drives a linear learning-rate decay from {@code alpha} down to
 * {@code minAlpha}.
 *
 * <p>NOTE(review): the random state is advanced with a non-atomic get/set pair, so a single
 * instance should presumably not be driven from several threads at once.
 */
public class Word2VecPerformer implements WorkerPerformer {
    public static final String NAME_SPACE = "org.deeplearning4j.scaleout.perform.models.word2vec";
    public static final String VECTOR_LENGTH = "org.deeplearning4j.scaleout.perform.models.word2vec.length";
    public static final String ADAGRAD = "org.deeplearning4j.scaleout.perform.models.word2vec.adagrad";
    public static final String NEGATIVE = "org.deeplearning4j.scaleout.perform.models.word2vec.negative";
    public static final String NUM_WORDS = "org.deeplearning4j.scaleout.perform.models.word2vec.numwords";
    public static final String TABLE = "org.deeplearning4j.scaleout.perform.models.word2vec.table";
    public static final String WINDOW = "org.deeplearning4j.scaleout.perform.models.word2vec.window";
    public static final String ALPHA = "org.deeplearning4j.scaleout.perform.models.word2vec.alpha";
    public static final String MIN_ALPHA = "org.deeplearning4j.scaleout.perform.models.word2vec.minalpha";
    public static final String TOTAL_WORDS = "org.deeplearning4j.scaleout.perform.models.word2vec.totalwords";
    public static final String NUM_WORDS_SO_FAR = "org.deeplearning4j.scaleout.perform.models.word2vec.wordssofar";
    public static final String ITERATIONS = "org.deeplearning4j.scaleout.perform.models.word2vec.iterations";

    /** The sigmoid lookup table covers inputs in [-MAX_EXP, MAX_EXP]. */
    static double MAX_EXP = 6.0d;
    private static final Logger log = LoggerFactory.getLogger(Word2VecPerformer.class);

    /** A progress line is logged once the word count moves this far from the last logged value. */
    private static final double LOG_INTERVAL_WORDS = 10000.0d;
    /** Linear congruential generator constants used to advance {@link #nextRandom}. */
    private static final long LCG_MULTIPLIER = 25214903917L;
    private static final long LCG_INCREMENT = 11L;
    /** Character set that maps every byte to exactly one char and back, so binary data survives. */
    private static final java.nio.charset.Charset BINARY_SAFE = StandardCharsets.ISO_8859_1;

    private int vectorLength = 50;
    /** Precomputed sigmoid values; populated by {@link #initExpTable()} during setup. */
    double[] expTable = new double[1000];
    private boolean useAdaGrad = false;
    private double negative = 5.0d;
    private int numWords = 1;
    /** Negative-sampling table decoded from {@link #TABLE}; stays null if it could not be read. */
    private INDArray table;
    private int window = 5;
    private final AtomicLong nextRandom = new AtomicLong(5L);
    private double alpha = 0.025d;
    private double minAlpha = 0.01d;
    private int totalWords = 1;
    private int iterations = 5;
    private StateTracker stateTracker;
    private int lastChecked = 0;

    /**
     * Creates a performer that uses the given tracker; {@link #setup} will not create another one
     * unless this argument is null.
     */
    public Word2VecPerformer(StateTracker stateTracker) {
        this.stateTracker = stateTracker;
    }

    /** Creates a performer whose state tracker is created from the connection string in {@link #setup}. */
    public Word2VecPerformer() {
    }

    /**
     * Trains on the job's work and stores the deltas as the job result.
     *
     * <p>A single {@link Word2VecWork} yields a one-element list built from its {@code addDeltas()};
     * a {@link Collection} of work items yields one {@code addDeltas()} entry per item. Any other
     * work type is ignored. The number of words trained is added to {@link #NUM_WORDS_SO_FAR}.
     *
     * @param job the job whose work is trained and whose result is set
     */
    public void perform(Job job) {
        Object work = job.getWork();
        if (work instanceof Word2VecWork) {
            Word2VecWork word2VecWork = (Word2VecWork) work;
            double wordsSoFar = this.stateTracker.count(NUM_WORDS_SO_FAR);
            int wordsTrained = trainWork(word2VecWork, decayedAlpha(wordsSoFar));
            logProgress(wordsTrained + wordsSoFar);
            job.setResult((Serializable) Arrays.asList(word2VecWork.addDeltas()));
            this.stateTracker.increment(NUM_WORDS_SO_FAR, wordsTrained);
        } else if (work instanceof Collection) {
            @SuppressWarnings("unchecked")
            Collection<Word2VecWork> works = (Collection<Word2VecWork>) work;
            double wordsSoFar = this.stateTracker.count(NUM_WORDS_SO_FAR);
            double learningRate = decayedAlpha(wordsSoFar);
            int wordsTrained = 0;
            ArrayList<Object> deltas = new ArrayList<Object>();
            for (Word2VecWork item : works) {
                if (item == null) {
                    continue;
                }
                wordsTrained += trainWork(item, learningRate);
                // One delta per work item, after all of its sentences were trained, matching the
                // single-work branch above (previously this was collected once per sentence).
                deltas.add(item.addDeltas());
            }
            logProgress(wordsTrained + wordsSoFar);
            job.setResult(deltas);
            this.stateTracker.increment(NUM_WORDS_SO_FAR, wordsTrained);
        }
    }

    /** Trains every sentence of {@code work} {@link #iterations} times; returns the words seen (counted once). */
    private int trainWork(Word2VecWork work, double learningRate) {
        List<List<VocabWord>> sentences = work.getSentences();
        if (sentences == null) {
            return 0;
        }
        int words = 0;
        for (List<VocabWord> sentence : sentences) {
            for (int iteration = 0; iteration < this.iterations; iteration++) {
                trainSentence(sentence, work, learningRate);
            }
            if (sentence != null) {
                words += sentence.size();
            }
        }
        return words;
    }

    /** Linearly decays {@link #alpha} with corpus progress, never going below {@link #minAlpha}. */
    private double decayedAlpha(double wordsSoFar) {
        // Guard against a zero/negative configured total, which would otherwise yield NaN.
        double total = Math.max(1, this.totalWords);
        return Math.max(this.minAlpha, this.alpha * (1.0d - wordsSoFar / total));
    }

    /** Logs the running word count whenever it has moved at least {@link #LOG_INTERVAL_WORDS}. */
    private void logProgress(double wordsSoFar) {
        if (Math.abs(wordsSoFar - this.lastChecked) >= LOG_INTERVAL_WORDS) {
            this.lastChecked = (int) wordsSoFar;
            log.info("Words so far " + wordsSoFar + " out of " + this.totalWords);
        }
    }

    /** No-op: this performer does not accept external updates. */
    public void update(Object... objArr) {
    }

    /**
     * Reads hyper-parameters from {@code configuration}, builds the sigmoid table, creates a
     * {@link HazelCastStateTracker} if none was supplied, and decodes the negative-sampling table
     * when {@code negative > 0}. Failures to create the tracker or read the table are logged.
     *
     * @param configuration source of the keys declared on this class
     */
    public void setup(Configuration configuration) {
        this.vectorLength = configuration.getInt(VECTOR_LENGTH, 50);
        this.useAdaGrad = configuration.getBoolean(ADAGRAD, false);
        this.negative = configuration.getFloat(NEGATIVE, 5.0f);
        this.numWords = configuration.getInt(NUM_WORDS, 1);
        this.window = configuration.getInt(WINDOW, 5);
        this.alpha = configuration.getFloat(ALPHA, 0.025f);
        this.minAlpha = configuration.getFloat(MIN_ALPHA, 0.01f);
        // configure() only publishes NUM_WORDS, so fall back to it when TOTAL_WORDS is absent.
        this.totalWords = configuration.getInt(TOTAL_WORDS, this.numWords);
        this.iterations = configuration.getInt(ITERATIONS, 5);
        initExpTable();

        String connectionString = configuration.get("org.deeplearning4j.scaleout.statetracker.connectionstring");
        log.info("Creating state tracker with connection string " + connectionString);
        if (this.stateTracker == null) {
            try {
                this.stateTracker = new HazelCastStateTracker(connectionString);
            } catch (Exception e) {
                log.error("Unable to create state tracker for connection string " + connectionString, e);
            }
        }

        if (this.negative > 0.0d) {
            INDArray decoded = decodeTable(configuration.get(TABLE));
            if (decoded != null) {
                this.table = decoded;
            }
        }
    }

    /**
     * Publishes the lookup table's hyper-parameters and the worker/aggregator class names into
     * {@code configuration}, resets the lookup table's weights, and (when negative sampling is
     * enabled) serializes its sampling table under {@link #TABLE}.
     *
     * @param inMemoryLookupTable source of vector length, AdaGrad flag, negative count, learning rate and table
     * @param invertedIndex       source of the total word count stored under {@link #NUM_WORDS}
     * @param configuration       configuration to populate
     */
    public static void configure(InMemoryLookupTable inMemoryLookupTable, InvertedIndex invertedIndex, Configuration configuration) {
        configuration.setInt(VECTOR_LENGTH, inMemoryLookupTable.getVectorLength());
        configuration.setBoolean(ADAGRAD, inMemoryLookupTable.isUseAdaGrad());
        configuration.setFloat(NEGATIVE, (float) inMemoryLookupTable.getNegative());
        configuration.setFloat(ALPHA, (float) inMemoryLookupTable.getLr().get());
        configuration.setLong(NUM_WORDS, invertedIndex.totalWords());
        configuration.set("org.deeplearning4j.scaleout.aggregator", Word2VecJobAggregator.class.getName());
        configuration.set("org.deeplearning4j.scaleout.perform.workerperformer", Word2VecPerformerFactory.class.getName());
        inMemoryLookupTable.resetWeights();
        if (inMemoryLookupTable.getNegative() > 0.0d) {
            String encoded = encodeTable(inMemoryLookupTable.getTable());
            // Do not publish a truncated table if serialization failed part-way.
            if (encoded != null) {
                configuration.set(TABLE, encoded);
            }
        }
    }

    /** Serializes {@code table} with ND4J and wraps the bytes losslessly in a String; null on failure. */
    private static String encodeTable(INDArray table) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try {
            DataOutputStream out = new DataOutputStream(bytes);
            Nd4j.write(table, out);
            out.flush();
        } catch (IOException e) {
            log.error("Unable to serialize the negative sampling table", e);
            return null;
        }
        // ISO-8859-1 maps bytes 0-255 one-to-one to chars; the platform default charset does not.
        return new String(bytes.toByteArray(), BINARY_SAFE);
    }

    /** Inverse of {@link #encodeTable}; returns null (and logs) when absent or unreadable. */
    private static INDArray decodeTable(String encoded) {
        if (encoded == null) {
            log.warn("No negative sampling table found under " + TABLE);
            return null;
        }
        try {
            return Nd4j.read(new DataInputStream(new ByteArrayInputStream(encoded.getBytes(BINARY_SAFE))));
        } catch (IOException e) {
            log.error("Unable to read the negative sampling table", e);
            return null;
        }
    }

    /**
     * Runs skip-gram updates centred on each word of {@code list}, skipping words ending in
     * {@code "STOP"}. A pseudo-random offset in {@code [0, window)} shrinks each word's window.
     *
     * @param list         sentence to train; null or empty is a no-op
     * @param word2VecWork vectors being trained
     * @param learningRate learning rate for this pass
     */
    public void trainSentence(List<VocabWord> list, Word2VecWork word2VecWork, double learningRate) {
        if (list == null || list.isEmpty()) {
            return;
        }
        for (int i = 0; i < list.size(); i++) {
            VocabWord word = list.get(i);
            if (word == null || word.getWord() == null || word.getWord().endsWith("STOP")) {
                continue;
            }
            long random = advanceRandom();
            // Floor-mod: a plain % on the (possibly negative) int would give a negative offset.
            int shrink = this.window > 0 ? nonNegativeMod((int) random, this.window) : 0;
            skipGram(i, list, shrink, word2VecWork, learningRate);
        }
    }

    /**
     * Trains word {@code i} of {@code list} against every other word within
     * {@code window - shrink} positions of it.
     *
     * @param i            index of the centre word
     * @param list         sentence containing the word
     * @param shrink       how much to shrink the window on each side
     * @param word2VecWork vectors being trained
     * @param learningRate learning rate for this pass
     */
    public void skipGram(int i, List<VocabWord> list, int shrink, Word2VecWork word2VecWork, double learningRate) {
        VocabWord centre = list.get(i);
        if (centre == null || list.isEmpty()) {
            return;
        }
        int end = (this.window * 2) + 1 - shrink;
        for (int offset = shrink; offset < end; offset++) {
            if (offset == this.window) {
                continue; // the centre word itself
            }
            int contextIndex = i - this.window + offset;
            if (contextIndex >= 0 && contextIndex < list.size()) {
                iterateSample(word2VecWork, centre, list.get(contextIndex), learningRate);
            }
        }
    }

    /**
     * Performs one (centre, context) update: hierarchical softmax over {@code vocabWord}'s Huffman
     * path, then (if {@code negative > 0}) negative sampling, and finally adds the accumulated
     * error to the context word's vector.
     *
     * @param word2VecWork vectors being trained
     * @param vocabWord    word whose codes/points (and negative vectors) are used
     * @param vocabWord2   word whose input vector is updated; skipped if null or index &lt; 0
     * @param learningRate learning rate for this update
     */
    public void iterateSample(Word2VecWork word2VecWork, VocabWord vocabWord, VocabWord vocabWord2, double learningRate) {
        if (vocabWord2 == null || vocabWord2.getIndex() < 0) {
            return;
        }
        if (word2VecWork.getVectors().get(vocabWord2.getWord()) == null) {
            log.warn("No vector found for word " + vocabWord2.getWord());
            return;
        }
        if (word2VecWork.getVectors().get(vocabWord.getWord()) == null) {
            log.warn("No vector found for word " + vocabWord.getWord());
            return;
        }
        INDArray l1 = (INDArray) word2VecWork.getVectors().get(vocabWord2.getWord()).getSecond();
        INDArray neu1e = Nd4j.create(this.vectorLength);

        hierarchicalSoftmax(word2VecWork, vocabWord, l1, neu1e, learningRate);
        if (this.negative > 0.0d) {
            negativeSampling(word2VecWork, vocabWord, l1, neu1e, learningRate);
        }
        axpy(1.0d, neu1e, l1);
    }

    /** Hierarchical-softmax step: accumulates error into {@code neu1e} and updates each syn1 node vector. */
    private void hierarchicalSoftmax(Word2VecWork work, VocabWord word, INDArray l1, INDArray neu1e, double learningRate) {
        for (int j = 0; j < word.getCodeLength(); j++) {
            int code = word.getCodes().get(j).intValue();
            int point = word.getPoints().get(j).intValue();
            String nodeWord = wordAt(work, point);
            if (nodeWord == null) {
                continue;
            }
            INDArray syn1 = work.getSyn1Vectors().get(nodeWord);
            if (syn1 == null) {
                log.warn("Syn1 vectors for " + nodeWord + " was null");
                continue;
            }
            double f = Nd4j.getBlasWrapper().dot(l1, syn1);
            // Written as a negated range check so that a NaN dot product is skipped as well.
            if (!(f >= -MAX_EXP && f < MAX_EXP)) {
                continue;
            }
            int idx = expIndex(f);
            if (idx >= this.expTable.length) {
                continue;
            }
            double g = (1 - code - this.expTable[idx]) * (this.useAdaGrad ? word.getGradient(j, learningRate) : learningRate);
            axpy(g, syn1, neu1e);
            axpy(g, l1, syn1);
        }
    }

    /**
     * Negative-sampling step: one positive update for {@code word} followed by up to
     * {@code negative} sampled targets. Each target uses its own negative vector; targets whose
     * vector is not present in {@code work}, or that equal {@code word}, are skipped.
     */
    private void negativeSampling(Word2VecWork work, VocabWord word, INDArray l1, INDArray neu1e, double learningRate) {
        for (int d = 0; d < this.negative + 1.0d; d++) {
            int target;
            int label;
            if (d == 0) {
                target = word.getIndex();
                label = 1;
            } else {
                if (this.table == null) {
                    break; // no sampling table was loaded in setup()
                }
                target = sampleNegativeTarget();
                if (target == word.getIndex()) {
                    continue; // never use the positive word as its own negative sample
                }
                label = 0;
            }
            INDArray syn1Neg = negativeVector(work, target);
            if (syn1Neg == null) {
                if (d == 0) {
                    log.warn("No negative vector found for word index " + target);
                }
                continue;
            }

            double f = Nd4j.getBlasWrapper().dot(l1, syn1Neg);
            double g;
            if (f > MAX_EXP) {
                g = this.useAdaGrad ? word.getGradient(target, label - 1) : (label - 1) * learningRate;
            } else if (f < -MAX_EXP) {
                g = label * (this.useAdaGrad ? word.getGradient(target, learningRate) : learningRate);
            } else {
                // Clamp: f == MAX_EXP (or rounding just below it) can map to expTable.length.
                int idx = Math.min(expIndex(f), this.expTable.length - 1);
                double err = label - this.expTable[idx];
                g = this.useAdaGrad ? word.getGradient(target, err) : err * learningRate;
            }
            // Same update pattern as hierarchicalSoftmax: neu1e += g*syn1Neg, syn1Neg += g*l1.
            axpy(g, syn1Neg, neu1e);
            axpy(g, l1, syn1Neg);
        }
    }

    /** Draws a negative-sampling target index from {@link #table}, advancing the random state. */
    private int sampleNegativeTarget() {
        long random = advanceRandom();
        int target = this.table.getInt(new int[]{nonNegativeMod((int) (random >> 16), this.table.length())});
        if (target == 0 && this.numWords > 1) {
            target = nonNegativeMod((int) random, this.numWords - 1) + 1;
        }
        return target;
    }

    /** Returns the negative vector for the word at {@code index}, or null if {@code work} lacks it. */
    private static INDArray negativeVector(Word2VecWork work, int index) {
        String word = wordAt(work, index);
        if (word == null || work.getNegativeVectors().get(word) == null) {
            return null;
        }
        return (INDArray) work.getNegativeVectors().get(word).getSecond();
    }

    /** Returns the word stored under {@code index} in {@code work}'s index map, or null if absent. */
    private static String wordAt(Word2VecWork work, int index) {
        Integer key = Integer.valueOf(index);
        return work.getIndexes().get(key) == null ? null : work.getIndexes().get(key).getWord();
    }

    /** Maps a dot product in [-MAX_EXP, MAX_EXP] to a position in {@link #expTable}. */
    private int expIndex(double f) {
        return (int) ((f + MAX_EXP) * ((this.expTable.length / MAX_EXP) / 2.0d));
    }

    /** {@code y += a * x}, using the float overload unless {@code y} is backed by doubles. */
    private static void axpy(double a, INDArray x, INDArray y) {
        if (y.data().dataType() == DataBuffer.Type.DOUBLE) {
            Nd4j.getBlasWrapper().axpy(a, x, y);
        } else {
            Nd4j.getBlasWrapper().axpy((float) a, x, y);
        }
    }

    /** Advances {@link #nextRandom} by one LCG step and returns the new value. */
    private long advanceRandom() {
        long next = this.nextRandom.get() * LCG_MULTIPLIER + LCG_INCREMENT;
        this.nextRandom.set(next);
        return next;
    }

    /** {@code value mod modulus} in {@code [0, modulus)} for a positive {@code modulus}. */
    private static int nonNegativeMod(int value, int modulus) {
        int r = value % modulus;
        return r < 0 ? r + modulus : r;
    }

    /** Fills {@link #expTable} with sigmoid(x) for x evenly spaced over [-MAX_EXP, MAX_EXP). */
    private void initExpTable() {
        for (int i = 0; i < this.expTable.length; i++) {
            // Floating-point division: the previous integer division made every entry sigmoid(-MAX_EXP).
            double exp = FastMath.exp((((double) i / this.expTable.length) * 2.0d - 1.0d) * MAX_EXP);
            this.expTable[i] = exp / (exp + 1.0d);
        }
    }
}
