/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.loader;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.zip.GZIPInputStream;
import lombok.NonNull;
import org.apache.commons.compress.compressors.gzip.GzipUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.Word2VecConfiguration;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WordVectorSerializer {
    // NOTE(review): the three constants below are not referenced by name anywhere in the code
    // visible in this (decompiled) file -- presumably the compiler inlined them; confirm before removing.
    // Presumably the default for the "lineBreaks" option of the binary loader -- TODO confirm.
    private static final boolean DEFAULT_LINEBREAKS = false;
    // Presumably indicates that model files start with a header line -- TODO confirm.
    private static final boolean HAS_HEADER = true;
    // Presumably a buffer/token size in bytes -- TODO confirm against callers.
    private static final int MAX_SIZE = 50;
    // Class-wide SLF4J logger.
    private static final Logger log = LoggerFactory.getLogger(WordVectorSerializer.class);

    /**
     * Loads a model file, delegating to the three-argument overload with
     * {@code lineBreaks} set to {@code false}.
     *
     * @param modelFile the model file to read
     * @param binary    {@code true} to read the binary layout, {@code false} for the text layout
     * @return the loaded word vectors
     * @throws IOException if the delegate fails to read the file
     */
    public static WordVectors loadGoogleModel(File modelFile, boolean binary) throws IOException {
        final boolean lineBreaks = false;
        return WordVectorSerializer.loadGoogleModel(modelFile, binary, lineBreaks);
    }

    /**
     * Loads a model file in either binary or text layout.
     *
     * @param modelFile  the model file to read
     * @param binary     {@code true} to use the binary reader; {@code false} to use the text loader
     * @param lineBreaks forwarded to the binary reader only; ignored for text files
     * @return the loaded word vectors
     * @throws IOException if reading the file fails
     */
    public static WordVectors loadGoogleModel(File modelFile, boolean binary, boolean lineBreaks) throws IOException {
        if (binary) {
            return WordVectorSerializer.readBinaryModel(modelFile, lineBreaks);
        }
        Pair<InMemoryLookupTable, VocabCache> loaded = WordVectorSerializer.loadTxt(modelFile);
        return WordVectorSerializer.fromPair(loaded);
    }

    /**
     * Reads a text model whose first line is {@code "<words> <layerSize>"} and whose remaining
     * lines are {@code "<word> <v1> ... <vN>"}, separated by single spaces. Files whose name looks
     * gzip-compressed are transparently decompressed.
     *
     * <p>Each vector is normalised with {@code Transforms.unitVec} before being stored.
     *
     * @param modelFile the model file to read
     * @return a {@link Word2Vec} whose vocabulary and lookup table are populated from the file
     * @throws IOException           if the file cannot be read, is empty, has a malformed header,
     *                               contains a row whose length does not match the header, or
     *                               contains more rows than the header declares
     * @throws NumberFormatException if a header field or vector component is not numeric
     */
    private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
        Word2Vec ret = new Word2Vec();
        boolean compressed = GzipUtils.isCompressedFilename(modelFile.getName());
        // The raw stream is its own resource so it is closed even if the GZIP wrapper fails to construct.
        try (InputStream raw = new FileInputStream(modelFile);
             BufferedReader reader = new BufferedReader(new InputStreamReader(compressed ? new GZIPInputStream(raw) : raw))) {
            String line = reader.readLine();
            if (line == null) {
                throw new IOException("Model file is empty: " + modelFile);
            }
            String[] initial = line.split(" ");
            if (initial.length < 2) {
                throw new IOException("Malformed header (expected \"<words> <layerSize>\"): " + line);
            }
            int words = Integer.parseInt(initial[0]);
            int layerSize = Integer.parseInt(initial[1]);
            INDArray syn0 = Nd4j.create(words, layerSize);
            InMemoryLookupCache cache = new InMemoryLookupCache(false);
            int currLine = 0;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    // Blank lines carry no vector; skip them instead of failing on an empty row.
                    continue;
                }
                String[] split = line.split(" ");
                // Explicit checks instead of `assert`, which is disabled unless the JVM runs with -ea.
                if (split.length != layerSize + 1) {
                    throw new IOException("Row " + currLine + " has " + (split.length - 1)
                            + " components, expected " + layerSize);
                }
                if (currLine >= words) {
                    throw new IOException("File contains more rows than the " + words + " declared in its header");
                }
                String word = split[0];
                float[] vector = new float[layerSize];
                for (int i = 1; i < split.length; ++i) {
                    vector[i - 1] = Float.parseFloat(split[i]);
                }
                syn0.putRow(currLine, Transforms.unitVec(Nd4j.create(vector)));
                cache.addWordToIndex(cache.numWords(), word);
                cache.addToken(new VocabWord(1.0, word));
                cache.putVocabWord(word);
                ++currLine;
            }
            InMemoryLookupTable lookupTable = (InMemoryLookupTable)new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
            lookupTable.setSyn0(syn0);
            ret.setVocab(cache);
            ret.setLookupTable(lookupTable);
        }
        return ret;
    }

    /**
     * Reads a binary model: two text tokens (word count and vector length) followed, for every
     * word, by a text token and {@code size} floats read through {@code readFloat}. Files whose
     * name looks gzip-compressed are decompressed on the fly.
     *
     * <p>Each vector is normalised with {@code Transforms.unitVec} before being stored.
     *
     * @param modelFile  the model file to read
     * @param linebreaks when {@code true}, one extra byte is consumed after every vector
     * @return a {@link Word2Vec} whose vocabulary and lookup table are populated from the file
     * @throws NumberFormatException if the word count or vector length is not an integer
     * @throws IOException           if reading the file fails
     */
    private static Word2Vec readBinaryModel(File modelFile, boolean linebreaks) throws NumberFormatException, IOException {
        try (BufferedInputStream buffered = new BufferedInputStream(
                     GzipUtils.isCompressedFilename(modelFile.getName())
                             ? new GZIPInputStream(new FileInputStream(modelFile))
                             : new FileInputStream(modelFile));
             DataInputStream in = new DataInputStream(buffered)) {
            final int wordCount = Integer.parseInt(WordVectorSerializer.readString(in));
            final int vectorSize = Integer.parseInt(WordVectorSerializer.readString(in));

            INDArray weights = Nd4j.create(wordCount, vectorSize);
            InMemoryLookupCache vocab = new InMemoryLookupCache(false);
            InMemoryLookupTable table = (InMemoryLookupTable)new InMemoryLookupTable.Builder().cache(vocab).vectorLength(vectorSize).build();

            for (int row = 0; row < wordCount; ++row) {
                String token = WordVectorSerializer.readString(in);
                log.trace("Loading {} with word {}", token, row);

                float[] components = new float[vectorSize];
                for (int col = 0; col < vectorSize; ++col) {
                    components[col] = WordVectorSerializer.readFloat(in);
                }
                weights.putRow(row, Transforms.unitVec(Nd4j.create(components)));

                vocab.addWordToIndex(vocab.numWords(), token);
                vocab.addToken(new VocabWord(1.0, token));
                vocab.putVocabWord(token);

                if (linebreaks) {
                    // Consume the separator byte that follows each vector.
                    in.readByte();
                }
            }

            Word2Vec model = new Word2Vec();
            table.setSyn0(weights);
            model.setVocab(vocab);
            model.setLookupTable(table);
            return model;
        }
    }

    /**
     * Reads exactly four bytes from the stream and decodes them as a little-endian IEEE-754 float.
     *
     * <p>{@link InputStream#read(byte[], int, int)} may return fewer bytes than requested (for
     * example from compressed or network-backed streams), so the read is repeated until all four
     * bytes are available.
     *
     * @param is the stream to read from
     * @return the decoded float
     * @throws EOFException if the stream ends before four bytes could be read
     * @throws IOException  if the underlying read fails
     */
    public static float readFloat(InputStream is) throws IOException {
        byte[] bytes = new byte[4];
        int filled = 0;
        while (filled < bytes.length) {
            int n = is.read(bytes, filled, bytes.length - filled);
            if (n < 0) {
                throw new EOFException("Stream ended after " + filled + " of " + bytes.length + " float bytes");
            }
            filled += n;
        }
        return getFloat(bytes);
    }

    /**
     * Decodes the first four bytes of {@code b} as a little-endian IEEE-754 float
     * ({@code b[0]} is the least significant byte).
     *
     * @param b array holding at least four bytes
     * @return the float whose bit pattern is formed by those bytes
     */
    public static float getFloat(byte[] b) {
        int bits = (b[0] & 0xFF)
                | (b[1] & 0xFF) << 8
                | (b[2] & 0xFF) << 16
                | (b[3] & 0xFF) << 24;
        return Float.intBitsToFloat(bits);
    }

    /**
     * Reads bytes up to (and consuming) the next space ({@code 0x20}) or newline ({@code 0x0A})
     * and returns them as a string. The delimiter is not part of the result.
     *
     * <p>All bytes of the token are collected first and decoded in a single step, so a multi-byte
     * character is never split across two decode calls. Decoding uses UTF-8 rather than the
     * platform default charset so results do not depend on the JVM's locale settings.
     * NOTE(review): assumes model files store words as UTF-8 -- confirm for legacy models.
     *
     * @param dis the stream to read from
     * @return the token preceding the delimiter; empty if the first byte is a delimiter
     * @throws IOException if the stream ends before a delimiter is found or the read fails
     */
    public static String readString(DataInputStream dis) throws IOException {
        ByteArrayOutputStream token = new ByteArrayOutputStream(50);
        byte b = dis.readByte();
        while (b != ' ' && b != '\n') {
            token.write(b);
            b = dis.readByte();
        }
        return new String(token.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Writes one line per indexed word to {@code path}, overwriting any existing file. Each line is
     * the word (spaces replaced by underscores), a space, and the vector components separated by
     * single spaces. Rows whose index has no word in {@code cache} are skipped.
     *
     * @param lookupTable table supplying the vectors; its {@code syn0} row count bounds the iteration
     * @param cache       vocabulary used to map row indices to words
     * @param path        destination file path
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeWordVectors(InMemoryLookupTable lookupTable, InMemoryLookupCache cache, String path) throws IOException {
        // try-with-resources guarantees the file handle is released even if a lookup or write fails.
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
            for (int i = 0; i < lookupTable.getSyn0().rows(); ++i) {
                String word = cache.wordAtIndex(i);
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                sb.append(word.replace(' ', '_'));
                sb.append(" ");
                INDArray wordVector = lookupTable.vector(word);
                for (int j = 0; j < wordVector.length(); ++j) {
                    if (j > 0) {
                        sb.append(" ");
                    }
                    sb.append(wordVector.getDouble(j));
                }
                sb.append("\n");
                write.write(sb.toString());
            }
            write.flush();
        }
    }

    /**
     * Builds a Jackson {@link ObjectMapper} that ignores unknown properties when deserialising,
     * tolerates empty beans, sorts properties alphabetically and indents its output.
     *
     * @return a freshly configured mapper
     */
    private static ObjectMapper getModelMapper() {
        return new ObjectMapper()
                .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                .configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false)
                .configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true)
                .enable(SerializationFeature.INDENT_OUTPUT);
    }

    /**
     * Saves the full state of a {@link Word2Vec} model to a text file.
     *
     * <p>File layout, one item per line:
     * <ol>
     *   <li>the configuration as JSON (with its vocab size set to the cache's word count);</li>
     *   <li>the exp table, space separated;</li>
     *   <li>the negative-sampling table, space separated, or an empty line when
     *       {@code conf.getNegative()} is not positive;</li>
     *   <li>one JSON document per vocabulary word carrying its count, Huffman node, syn0, syn1 and,
     *       where enabled, syn1Neg and the AdaGrad historical gradient.</li>
     * </ol>
     *
     * <p>Only {@link InMemoryLookupTable} and {@link InMemoryLookupCache} are supported. Both are
     * validated before the destination is opened, so an unsupported model never truncates an
     * existing file.
     *
     * @param vec  the model to save; must not be {@code null}
     * @param path destination file path; must not be {@code null}
     * @throws NullPointerException  if {@code vec} or {@code path} is {@code null}
     * @throws IllegalStateException if the model uses an unsupported lookup table or vocab cache
     * @throws RuntimeException      if the file cannot be opened or an I/O error occurs while writing
     */
    public static void writeFullModel(@NonNull Word2Vec vec, @NonNull String path) {
        if (vec == null) {
            throw new NullPointerException("vec");
        }
        if (path == null) {
            throw new NullPointerException("path");
        }
        WeightLookupTable lookupTable = vec.getLookupTable();
        VocabCache vocabCache = vec.getVocab();
        if (!(lookupTable instanceof InMemoryLookupTable)) {
            throw new IllegalStateException("At this moment only InMemoryLookupTable is supported.");
        }
        if (!(vocabCache instanceof InMemoryLookupCache)) {
            throw new IllegalStateException("At this moment only InMemoryLookupCache is supported.");
        }
        InMemoryLookupTable table = (InMemoryLookupTable)lookupTable;
        InMemoryLookupCache cache = (InMemoryLookupCache)vocabCache;

        Word2VecConfiguration conf = vec.getConfiguration();
        conf.setVocabSize(vocabCache.numWords());
        // NOTE(review): the holder's result is never used here; the construction is kept only in
        // case building it has side effects on the cache -- confirm and remove if it does not.
        new VocabularyHolder.Builder().externalCache(vocabCache).build();

        boolean negativeSampling = conf.getNegative() > 0.0;
        String confJson = conf.toJson();

        // try-with-resources closes the file even when serialising a word fails part-way through.
        try (PrintWriter printWriter = new PrintWriter(path)) {
            printWriter.println(confJson);
            log.info("Word2Vec conf. JSON: " + confJson);

            StringBuilder builder = new StringBuilder();
            for (int x = 0; x < table.getExpTable().length; ++x) {
                builder.append(table.getExpTable()[x]).append(" ");
            }
            printWriter.println(builder.toString().trim());

            if (negativeSampling) {
                builder = new StringBuilder();
                for (int x = 0; x < table.getTable().columns(); ++x) {
                    builder.append(table.getTable().getDouble(x)).append(" ");
                }
                printWriter.println(builder.toString().trim());
            } else {
                printWriter.println("");
            }

            ArrayList<VocabWord> words = new ArrayList<VocabWord>(cache.getVocabs().values());
            for (VocabWord word : words) {
                int row = vocabCache.indexOf(word.getWord());
                VocabularyWord vw = new VocabularyWord(word.getWord());
                vw.setCount(vocabCache.wordFrequency(word.getWord()));
                vw.setHuffmanNode(VocabularyHolder.buildNode(word.getCodes(), word.getPoints(), word.getCodeLength(), word.getIndex()));

                // Every row is copied using its own column count, so the array size and the loop
                // bound can never disagree.
                vw.setSyn0(rowToDoubles(table.getSyn0().getRow(row)));
                vw.setSyn1(rowToDoubles(table.getSyn1().getRow(row)));
                if (negativeSampling) {
                    vw.setSyn1Neg(rowToDoubles(table.getSyn1Neg().getRow(row)));
                }
                if (conf.isUseAdaGrad() && table.isUseAdaGrad()) {
                    INDArray gradient = word.getHistoricalGradient();
                    if (gradient == null) {
                        gradient = Nd4j.zeros(word.getCodes().size());
                    }
                    vw.setHistoricalGradient(rowToDoubles(gradient));
                }
                printWriter.println(vw.toJson());
            }

            // PrintWriter never throws on write failures; checkError() flushes and reports them.
            if (printWriter.checkError()) {
                throw new RuntimeException("I/O error while writing model to " + path);
            }
        }
        catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /** Copies the {@code columns()} leading values of a row vector into a new {@code double[]}. */
    private static double[] rowToDoubles(INDArray row) {
        double[] values = new double[row.columns()];
        for (int i = 0; i < values.length; ++i) {
            values[i] = row.getDouble(i);
        }
        return values;
    }

    /**
     * Restores a {@link Word2Vec} model from a file with the following line layout: configuration
     * JSON, two table lines (read and discarded), then one JSON document per vocabulary word.
     *
     * <p>The lookup table is built from the configuration, its weights are reset, and the syn0,
     * syn1 and (when {@code getNegative()} is positive) syn1Neg rows are then overwritten with the
     * values stored for each word.
     *
     * @param path path of the model file; must not be {@code null}
     * @return the restored model
     * @throws NullPointerException if {@code path} is {@code null}
     */
    public static Word2Vec loadFullModel(@NonNull String path) {
        if (path == null) {
            throw new NullPointerException("path");
        }
        LineSentenceIterator lines = new LineSentenceIterator(new File(path));

        String confJson = lines.nextSentence();
        log.info("Word2Vec conf. JSON: " + confJson);
        Word2VecConfiguration configuration = Word2VecConfiguration.fromJson(confJson);

        // The two table lines are consumed only to advance the iterator; their content is unused.
        lines.nextSentence();
        lines.nextSentence();

        VocabularyHolder holder = new VocabularyHolder.Builder()
                .minWordFrequency(configuration.getMinWordFrequency())
                .hugeModelExpected(configuration.isHugeModelExpected())
                .scavengerActivationThreshold(configuration.getScavengerActivationThreshold())
                .scavengerRetentionDelay(configuration.getScavengerRetentionDelay())
                .build();
        while (lines.hasNext()) {
            VocabularyWord entry = VocabularyWord.fromJson(lines.nextSentence());
            entry.setSpecial(true);
            holder.addWord(entry);
        }

        InMemoryLookupCache vocabCache = new InMemoryLookupCache(false);
        holder.transferBackToVocabCache(vocabCache, false);

        InMemoryLookupTable lookupTable = (InMemoryLookupTable)new InMemoryLookupTable.Builder()
                .negative(configuration.getNegative())
                .useAdaGrad(configuration.isUseAdaGrad())
                .lr(configuration.getLearningRate())
                .cache(vocabCache)
                .vectorLength(configuration.getLayersSize())
                .build();
        lookupTable.resetWeights(true);

        for (VocabularyWord entry : holder.getVocabulary()) {
            int row = vocabCache.indexOf(entry.getWord());
            lookupTable.getSyn0().getRow(row).assign(Nd4j.create(entry.getSyn0()));
            lookupTable.getSyn1().getRow(row).assign(Nd4j.create(entry.getSyn1()));
            if (configuration.getNegative() > 0.0) {
                lookupTable.getSyn1Neg().getRow(row).assign(Nd4j.create(entry.getSyn1Neg()));
            }
        }

        return new Word2Vec.Builder(configuration)
                .vocabCache(vocabCache)
                .lookupTable(lookupTable)
                .resetModel(false)
                .build();
    }

    /**
     * Writes every word of {@code vec}'s vocabulary to {@code path}, overwriting any existing file.
     * Each line is the word (spaces replaced by underscores), a space, and the vector components
     * separated by single spaces. {@code null} words are skipped. The number of words written is
     * logged at info level.
     *
     * @param vec  source of the vocabulary and vectors
     * @param path destination file path
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeWordVectors(WordVectors vec, String path) throws IOException {
        // try-with-resources guarantees the file handle is released even if a lookup or write fails.
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
            int words = 0;
            for (String word : vec.vocab().words()) {
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                sb.append(word.replace(' ', '_'));
                sb.append(" ");
                INDArray wordVector = vec.getWordVectorMatrix(word);
                for (int j = 0; j < wordVector.length(); ++j) {
                    if (j > 0) {
                        sb.append(" ");
                    }
                    sb.append(wordVector.getDouble(j));
                }
                sb.append("\n");
                write.write(sb.toString());
                ++words;
            }
            log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
            write.flush();
        }
    }

    /**
     * Wraps a lookup table and a vocabulary in a new {@link WordVectorsImpl}.
     *
     * @param table the lookup table to expose
     * @param vocab the vocabulary to expose
     * @return a word-vector view backed by the given table and vocabulary
     */
    public static WordVectors fromTableAndVocab(WeightLookupTable table, VocabCache vocab) {
        final WordVectorsImpl wrapped = new WordVectorsImpl();
        wrapped.setVocab(vocab);
        wrapped.setLookupTable(table);
        return wrapped;
    }

    /**
     * Wraps a (lookup table, vocabulary) pair in a new {@link WordVectorsImpl}.
     *
     * @param pair the table as first element and the vocabulary as second element
     * @return a word-vector view backed by the pair's contents
     */
    public static WordVectors fromPair(Pair<InMemoryLookupTable, VocabCache> pair) {
        WeightLookupTable table = (WeightLookupTable)pair.getFirst();
        VocabCache vocab = (VocabCache)pair.getSecond();
        WordVectorsImpl result = new WordVectorsImpl();
        result.setLookupTable(table);
        result.setVocab(vocab);
        return result;
    }

    /**
     * Loads a text vectors file and wraps the result as {@link WordVectors}.
     *
     * @param vectorsFile the text file to read
     * @return the loaded word vectors
     * @throws FileNotFoundException if the file cannot be opened
     */
    public static WordVectors loadTxtVectors(File vectorsFile) throws FileNotFoundException {
        return WordVectorSerializer.fromPair(WordVectorSerializer.loadTxt(vectorsFile));
    }

    /**
     * Loads word vectors from a text file with one {@code "<word> <v1> ... <vN>"} entry per line,
     * fields separated by single spaces.
     *
     * <p>An optional first header line is skipped. A line counts as a header when it contains no
     * space, or when it consists of exactly two integers (the {@code "<words> <size>"} form). Any
     * other first line is treated as a regular entry, so header-less files keep their first word.
     * A one-dimensional entry whose word is itself an integer and whose single component is an
     * integer cannot be told apart from a header and is skipped when it is the first line.
     *
     * @param vectorsFile the text file to read
     * @return the populated lookup table paired with its vocabulary
     * @throws FileNotFoundException    if the file cannot be opened
     * @throws IllegalArgumentException if the file contains no vector entries
     * @throws NumberFormatException    if a vector component is not numeric
     */
    public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile) throws FileNotFoundException {
        BufferedReader reader = new BufferedReader(new FileReader(vectorsFile));
        LineIterator iter = IOUtils.lineIterator((Reader)reader);
        try {
            InMemoryLookupCache cache = new InMemoryLookupCache();

            // Peek at the first line: drop it if it is a header, otherwise process it as data.
            String pending = iter.hasNext() ? iter.nextLine() : null;
            if (pending != null && isHeaderLine(pending)) {
                pending = null;
            }

            ArrayList<INDArray> arrays = new ArrayList<INDArray>();
            while (pending != null || iter.hasNext()) {
                String line;
                if (pending != null) {
                    line = pending;
                    pending = null;
                } else {
                    line = iter.nextLine();
                }
                String[] split = line.split(" ");
                String word = split[0];
                VocabWord word1 = new VocabWord(1.0, word);
                cache.addToken(word1);
                cache.addWordToIndex(cache.numWords(), word);
                word1.setIndex(cache.numWords());
                cache.putVocabWord(word);
                INDArray row = Nd4j.create((DataBuffer)Nd4j.createBuffer((int)(split.length - 1)));
                for (int i = 1; i < split.length; ++i) {
                    row.putScalar(i - 1, Float.parseFloat(split[i]));
                }
                arrays.add(row);
            }
            if (arrays.isEmpty()) {
                throw new IllegalArgumentException("No word vectors found in " + vectorsFile);
            }

            int vectorLength = arrays.get(0).columns();
            INDArray syn = Nd4j.create(new int[]{arrays.size(), vectorLength});
            for (int i = 0; i < arrays.size(); ++i) {
                syn.putRow(i, arrays.get(i));
            }
            InMemoryLookupTable lookupTable = (InMemoryLookupTable)new InMemoryLookupTable.Builder().vectorLength(vectorLength).useAdaGrad(false).cache(cache).build();
            Nd4j.clearNans(syn);
            lookupTable.setSyn0(syn);
            return new Pair<InMemoryLookupTable, VocabCache>(lookupTable, cache);
        } finally {
            // Closes the underlying reader on both the success and the failure path.
            iter.close();
        }
    }

    /** Returns whether a first line is a header: no space at all, or exactly two integer fields. */
    private static boolean isHeaderLine(String line) {
        if (!line.contains(" ")) {
            return true;
        }
        String[] fields = line.split(" ");
        if (fields.length != 2) {
            return false;
        }
        try {
            Integer.parseInt(fields[0]);
            Integer.parseInt(fields[1]);
            return true;
        } catch (NumberFormatException notAHeader) {
            return false;
        }
    }

    /**
     * Writes t-SNE coordinates for every word of a {@link Glove} model as CSV. Each line holds the
     * comma-separated values of the word's row in {@code tsne}, a comma, the word, a trailing space
     * and a newline. {@code null} words are skipped. The number of words written is logged at info
     * level.
     *
     * @param vec  model supplying the vocabulary; its vocab is cast to {@link InMemoryLookupCache}
     * @param tsne matrix indexed by each word's vocabulary index
     * @param csv  destination file (overwritten)
     * @throws Exception if the file cannot be opened or written
     */
    public static void writeTsneFormat(Glove vec, INDArray tsne, File csv) throws Exception {
        // try-with-resources guarantees the file handle is released even if a lookup or write fails.
        try (BufferedWriter write = new BufferedWriter(new FileWriter(csv))) {
            int words = 0;
            InMemoryLookupCache l = (InMemoryLookupCache)vec.vocab();
            for (String word : vec.vocab().words()) {
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
                for (int j = 0; j < wordVector.length(); ++j) {
                    if (j > 0) {
                        sb.append(",");
                    }
                    sb.append(wordVector.getDouble(j));
                }
                sb.append(",");
                sb.append(word);
                sb.append(" ");
                sb.append("\n");
                write.write(sb.toString());
                // Count each written row so the summary log reports the real total.
                ++words;
            }
            log.info("Wrote " + words + " with size of " + vec.lookupTable().getVectorLength());
            write.flush();
        }
    }

    /**
     * Writes t-SNE coordinates for every word of a {@link Word2Vec} model as CSV. Each line holds
     * the comma-separated values of the word's row in {@code tsne}, a comma, the word, a trailing
     * space and a newline. {@code null} words are skipped. The number of words written is logged at
     * info level.
     *
     * @param vec  model supplying the vocabulary; its vocab is cast to {@link InMemoryLookupCache}
     * @param tsne matrix indexed by each word's vocabulary index
     * @param csv  destination file (overwritten)
     * @throws Exception if the file cannot be opened or written
     */
    public static void writeTsneFormat(Word2Vec vec, INDArray tsne, File csv) throws Exception {
        // try-with-resources guarantees the file handle is released even if a lookup or write fails.
        try (BufferedWriter write = new BufferedWriter(new FileWriter(csv))) {
            int words = 0;
            InMemoryLookupCache l = (InMemoryLookupCache)vec.vocab();
            for (String word : vec.vocab().words()) {
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
                for (int j = 0; j < wordVector.length(); ++j) {
                    if (j > 0) {
                        sb.append(",");
                    }
                    sb.append(wordVector.getDouble(j));
                }
                sb.append(",");
                sb.append(word);
                sb.append(" ");
                sb.append("\n");
                write.write(sb.toString());
                // Count each written row so the summary log reports the real total.
                ++words;
            }
            log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
            write.flush();
        }
    }
}

