/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.loader;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for reading and writing word vectors.
 *
 * <p>Supported inputs are the word2vec binary and text layouts
 * ({@link #loadGoogleModel(String, boolean)}) and a plain
 * {@code word v1 v2 ...} text layout ({@link #loadTxt(File)}). Supported
 * outputs are the plain text layout and a CSV layout for t-SNE coordinates.
 *
 * <p>Text files are opened with {@link FileReader}/{@link FileWriter} and
 * therefore use the platform default charset; tokens inside binary models are
 * decoded as UTF-8.
 */
public class WordVectorSerializer {
    /** Initial capacity, in bytes, of the buffer that collects one token of a binary model. */
    private static final int MAX_SIZE = 50;
    private static final Logger log = LoggerFactory.getLogger(WordVectorSerializer.class);

    /**
     * Loads a word2vec model file into a {@link Word2Vec} instance.
     *
     * <p>Each vector is first spilled to its own file inside a freshly created
     * directory (named {@code "." + random UUID}, relative to the working
     * directory) and then read back into the lookup table. The directory and
     * any files left in it are removed before this method returns, whether it
     * succeeds or fails.
     *
     * @param path   location of the model file
     * @param binary {@code true} to parse the binary layout, {@code false} for the text layout
     * @return a model whose vocab and lookup table hold the loaded words and vectors
     * @throws IOException           if the file cannot be read or ends prematurely
     * @throws IllegalStateException if the temporary directory cannot be created
     */
    public static Word2Vec loadGoogleModel(String path, boolean binary) throws IOException {
        File rootDir = new File("." + UUID.randomUUID().toString());
        if (!rootDir.mkdirs()) {
            throw new IllegalStateException("Unable to create directory for word vectors");
        }
        try {
            return binary ? loadBinaryModel(path, rootDir) : loadTextModel(path, rootDir);
        } finally {
            // Runs on failure too, so aborted loads do not leave spill files behind.
            deleteSpillDir(rootDir);
        }
    }

    /**
     * Parses the binary layout: two header tokens (word count, vector size),
     * then per word a token, {@code size} little-endian floats and one
     * trailing byte that is skipped. Vectors are scaled to unit length.
     */
    private static Word2Vec loadBinaryModel(String path, File spillDir) throws IOException {
        List<File> vectorPaths = new ArrayList<File>();
        InMemoryLookupCache cache = new InMemoryLookupCache();
        WeightLookupTable lookupTable;
        int size;
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
            int words = Integer.parseInt(readString(dis));
            size = Integer.parseInt(readString(dis));
            lookupTable = new InMemoryLookupTable.Builder().cache(cache).vectorLength(size).build();
            for (int i = 0; i < words; ++i) {
                String word = readString(dis);
                if (word.isEmpty()) {
                    continue;
                }
                float[] vector = new float[size];
                double len = 0.0;
                for (int j = 0; j < size; ++j) {
                    float component = readFloat(dis);
                    len += component * component;
                    vector[j] = component;
                }
                len = Math.sqrt(len);
                // An all-zero vector has no direction; leave it untouched rather than fill it with NaN.
                if (len > 0.0) {
                    for (int j = 0; j < size; ++j) {
                        vector[j] = (float) ((double) vector[j] / len);
                    }
                }
                File spill = new File(spillDir, String.valueOf(i));
                vectorPaths.add(spill);
                writeVector(vector, spill);
                cache.addWordToIndex(cache.numWords(), word);
                cache.addToken(new VocabWord(1.0, word));
                cache.putVocabWord(word);
                // Skip the single byte that follows every vector.
                dis.read();
            }
        }
        return buildModel(cache, lookupTable, vectorPaths, size);
    }

    /**
     * Parses the text layout: a header line {@code "<words> <layerSize>"}
     * followed by one {@code word v1 v2 ...} line per entry. Lines whose first
     * token is empty are skipped.
     */
    private static Word2Vec loadTextModel(String path, File spillDir) throws IOException {
        List<File> vectorPaths = new ArrayList<File>();
        InMemoryLookupCache cache = new InMemoryLookupCache();
        int declaredWords;
        int layerSize;
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(path)))) {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("Word vector file is empty: " + path);
            }
            String[] initial = header.split(" ");
            if (initial.length < 2) {
                throw new IOException("Malformed header line in " + path + ": " + header);
            }
            declaredWords = Integer.parseInt(initial[0]);
            layerSize = Integer.parseInt(initial[1]);
            String line;
            while ((line = reader.readLine()) != null) {
                String[] split = line.split(" ");
                String word = split[0];
                if (word.isEmpty()) {
                    continue;
                }
                float[] buffer = new float[layerSize];
                for (int i = 1; i < split.length; ++i) {
                    buffer[i - 1] = Float.parseFloat(split[i]);
                }
                File spill = new File(spillDir, String.valueOf(cache.numWords()));
                writeVector(buffer, spill);
                vectorPaths.add(spill);
                cache.addWordToIndex(cache.numWords(), word);
                cache.addToken(new VocabWord(1.0, word));
                cache.putVocabWord(word);
            }
        }
        if (declaredWords != vectorPaths.size()) {
            // The header count is advisory; the vectors actually read are what gets loaded.
            log.warn("Header declares " + declaredWords + " words but " + vectorPaths.size()
                    + " were read from " + path);
        }
        WeightLookupTable lookupTable = new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
        return buildModel(cache, lookupTable, vectorPaths, layerSize);
    }

    /** Reads every spilled vector back into {@code lookupTable} and wraps the result in a {@link Word2Vec}. */
    private static Word2Vec buildModel(InMemoryLookupCache cache, WeightLookupTable lookupTable,
                                       List<File> vectorPaths, int vectorLength) throws IOException {
        lookupTable.resetWeights();
        for (int i = 0; i < vectorPaths.size(); ++i) {
            File spill = vectorPaths.get(i);
            float[] read = readVec(spill, vectorLength);
            lookupTable.putVector(cache.wordAtIndex(i), Nd4j.create(read));
            spill.delete();
        }
        Word2Vec ret = new Word2Vec();
        ret.setVocab(cache);
        ret.setLookupTable(lookupTable);
        return ret;
    }

    /** Best-effort removal of the spill directory and any files still inside it. */
    private static void deleteSpillDir(File dir) {
        File[] leftovers = dir.listFiles();
        if (leftovers != null) {
            for (File leftover : leftovers) {
                leftover.delete();
            }
        }
        dir.delete();
    }

    /**
     * Reads {@code length} floats, in {@link DataInputStream#readFloat()}
     * encoding, from the start of {@code from}.
     */
    private static float[] readVec(File from, int length) throws IOException {
        float[] ret = new float[length];
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(from)))) {
            for (int i = 0; i < length; ++i) {
                ret[i] = dis.readFloat();
            }
        }
        return ret;
    }

    /**
     * Writes every element of {@code vec} to {@code to} in
     * {@link DataOutputStream#writeFloat(float)} encoding, replacing any
     * existing file content.
     */
    private static void writeVector(float[] vec, File to) throws IOException {
        // Closing the outermost stream flushes the buffered bytes, also on failure.
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(to)))) {
            for (float value : vec) {
                dos.writeFloat(value);
            }
        }
    }

    /**
     * Reads bytes up to the next space (32) or newline (10) and decodes them
     * as UTF-8. The delimiter is consumed and not part of the result.
     *
     * @throws IOException if the stream ends before a delimiter is found
     */
    private static String readString(DataInputStream dis) throws IOException {
        // Collect the whole token before decoding so a multi-byte character is never split.
        ByteArrayOutputStream token = new ByteArrayOutputStream(MAX_SIZE);
        byte b = dis.readByte();
        while (b != ' ' && b != '\n') {
            token.write(b);
            b = dis.readByte();
        }
        return new String(token.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Reads the next four bytes from {@code is} and decodes them as a
     * little-endian IEEE 754 float.
     *
     * @param is stream positioned at the first byte of the float
     * @return the decoded value
     * @throws IOException if fewer than four bytes are available
     */
    public static float readFloat(InputStream is) throws IOException {
        byte[] bytes = new byte[4];
        // readFully keeps reading until all four bytes arrive and fails loudly at end of stream.
        new DataInputStream(is).readFully(bytes);
        return getFloat(bytes);
    }

    /**
     * Decodes the first four bytes of {@code b} as a little-endian IEEE 754 float.
     *
     * @param b array holding at least four bytes, least significant first
     * @return the decoded value
     */
    public static float getFloat(byte[] b) {
        int bits = (b[0] & 0xFF)
                | (b[1] & 0xFF) << 8
                | (b[2] & 0xFF) << 16
                | (b[3] & 0xFF) << 24;
        return Float.intBitsToFloat(bits);
    }

    /** Appends the elements of {@code vector} to {@code sb}, separated by {@code separator}. */
    private static void appendVector(StringBuilder sb, INDArray vector, String separator) {
        for (int j = 0; j < vector.length(); ++j) {
            if (j > 0) {
                sb.append(separator);
            }
            sb.append(vector.getDouble(j));
        }
    }

    /**
     * Writes one {@code word v1 v2 ...} line per row of the table's syn0
     * matrix, overwriting {@code path}. Row indices for which the cache has
     * no word are skipped.
     *
     * @param l     lookup table supplying the vectors
     * @param cache vocabulary mapping row indices to words
     * @param path  destination file
     * @throws IOException if the file cannot be written
     */
    public static void writeWordVectors(InMemoryLookupTable l, InMemoryLookupCache cache, String path) throws IOException {
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
            for (int i = 0; i < l.getSyn0().rows(); ++i) {
                String word = cache.wordAtIndex(i);
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                sb.append(word).append(" ");
                appendVector(sb, l.vector(word), " ");
                sb.append("\n");
                write.write(sb.toString());
            }
        }
    }

    /**
     * Writes one {@code word v1 v2 ...} line per vocabulary word of
     * {@code vec}, overwriting {@code path}, and logs how many were written.
     *
     * @param vec  model supplying the vocabulary and vectors
     * @param path destination file
     * @throws IOException if the file cannot be written
     */
    public static void writeWordVectors(Word2Vec vec, String path) throws IOException {
        int words = 0;
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
            for (String word : vec.vocab().words()) {
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                sb.append(word).append(" ");
                appendVector(sb, vec.getWordVectorMatrix(word), " ");
                sb.append("\n");
                write.write(sb.toString());
                ++words;
            }
        }
        log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
    }

    /**
     * Loads a plain text vector file and wraps it in a {@link WordVectors}.
     *
     * @param path file in the layout accepted by {@link #loadTxt(File)}
     * @return word vectors backed by the loaded lookup table and vocabulary
     * @throws FileNotFoundException if {@code path} cannot be opened
     */
    public static WordVectors loadTxtVectors(File path) throws FileNotFoundException {
        Pair<WeightLookupTable, VocabCache> pair = WordVectorSerializer.loadTxt(path);
        WordVectorsImpl vectors = new WordVectorsImpl();
        vectors.setLookupTable((WeightLookupTable)pair.getFirst());
        vectors.setVocab((VocabCache)pair.getSecond());
        return vectors;
    }

    /**
     * Loads a text file with one {@code word v1 v2 ...} entry per line.
     * Lines whose first token is empty (for example blank lines) are skipped.
     * The vector length is taken from the first entry.
     *
     * @param path file to read
     * @return the lookup table (with syn0 set to the loaded rows) and the vocabulary
     * @throws FileNotFoundException    if {@code path} cannot be opened
     * @throws IllegalArgumentException if the file contains no entries
     */
    public static Pair<WeightLookupTable, VocabCache> loadTxt(File path) throws FileNotFoundException {
        BufferedReader reader = new BufferedReader(new FileReader(path));
        LineIterator iter = IOUtils.lineIterator((Reader)reader);
        InMemoryLookupCache cache = new InMemoryLookupCache();
        List<INDArray> arrays = new ArrayList<INDArray>();
        try {
            while (iter.hasNext()) {
                String line = iter.nextLine();
                String[] split = line.split(" ");
                String word = split[0];
                if (word.isEmpty()) {
                    continue;
                }
                VocabWord word1 = new VocabWord(1.0, word);
                cache.addToken(word1);
                cache.addWordToIndex(cache.numWords(), word);
                cache.putVocabWord(word);
                INDArray row = Nd4j.create((DataBuffer)Nd4j.createBuffer((long)(split.length - 1)));
                for (int i = 1; i < split.length; ++i) {
                    row.putScalar(i - 1, Float.parseFloat(split[i]));
                }
                arrays.add(row);
            }
        } finally {
            // Closing the iterator releases the underlying reader, also when parsing fails.
            iter.close();
        }
        if (arrays.isEmpty()) {
            throw new IllegalArgumentException("No word vectors found in " + path);
        }
        int columns = arrays.get(0).columns();
        INDArray syn = Nd4j.create(new int[]{arrays.size(), columns});
        for (int i = 0; i < syn.rows(); ++i) {
            syn.putRow(i, arrays.get(i));
        }
        InMemoryLookupTable l = (InMemoryLookupTable)new InMemoryLookupTable.Builder().vectorLength(columns).useAdaGrad(false).cache(cache).build();
        Nd4j.clearNans(syn);
        l.setSyn0(syn);
        return new Pair<WeightLookupTable, VocabCache>(l, cache);
    }

    /**
     * Writes one CSV line per vocabulary word: the elements of the word's row
     * in {@code tsne}, then the word, then a space.
     *
     * @return the number of lines written
     */
    private static int writeTsneRows(InMemoryLookupCache cache, INDArray tsne, File csv) throws IOException {
        int written = 0;
        try (BufferedWriter write = new BufferedWriter(new FileWriter(csv))) {
            for (String word : cache.words()) {
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                appendVector(sb, tsne.getRow(cache.wordFor(word).getIndex()), ",");
                sb.append(",").append(word).append(" ").append("\n");
                write.write(sb.toString());
                ++written;
            }
        }
        return written;
    }

    /**
     * Writes the rows of {@code tsne} as CSV, one line per vocabulary word of
     * {@code vec}, and logs how many were written. The row for a word is
     * selected by the word's index in the vocabulary.
     *
     * @param vec  model supplying the vocabulary (must be an {@link InMemoryLookupCache})
     * @param tsne matrix indexed by vocabulary word index
     * @param csv  destination file, overwritten
     * @throws Exception if the file cannot be written
     */
    public static void writeTsneFormat(Glove vec, INDArray tsne, File csv) throws Exception {
        int words = writeTsneRows((InMemoryLookupCache)vec.vocab(), tsne, csv);
        log.info("Wrote " + words + " with size of " + vec.lookupTable().getVectorLength());
    }

    /**
     * Writes the rows of {@code tsne} as CSV, one line per vocabulary word of
     * {@code vec}, and logs how many were written. The row for a word is
     * selected by the word's index in the vocabulary.
     *
     * @param vec  model supplying the vocabulary (must be an {@link InMemoryLookupCache})
     * @param tsne matrix indexed by vocabulary word index
     * @param csv  destination file, overwritten
     * @throws Exception if the file cannot be written
     */
    public static void writeTsneFormat(Word2Vec vec, INDArray tsne, File csv) throws Exception {
        int words = writeTsneRows((InMemoryLookupCache)vec.vocab(), tsne, csv);
        log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
    }
}

