/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.wordstore;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Assembles a single {@link VocabCache} out of one or more {@link SequenceIterator} sources.
 *
 * <p>Each source is scanned into its own temporary cache, optionally truncated by the
 * per-source minimum element frequency, and then merged into the target cache.
 * Instances are created through {@link Builder}.
 *
 * @param <T> element type stored in the vocabulary
 */
public class VocabConstructor<T extends SequenceElement> {
    protected static final Logger log = LoggerFactory.getLogger(VocabConstructor.class);

    /** A progress line is logged at INFO level every time this many sequences have been read. */
    private static final long PROGRESS_LOG_INTERVAL = 100000L;

    private List<VocabSource<T>> sources = new ArrayList<VocabSource<T>>();
    private VocabCache<T> cache;
    private List<String> stopWords;
    // Copied from the builder; not read by any code in this class.
    private boolean useAdaGrad = false;
    private boolean fetchLabels = false;

    /** Instances are only created via {@link Builder#build()}. */
    private VocabConstructor() {
    }

    /**
     * Placeholder; currently always returns {@code null}.
     *
     * @return {@code null}
     */
    protected WeightLookupTable<T> buildExtendedLookupTable() {
        return null;
    }

    /**
     * Placeholder; currently always returns {@code null}.
     *
     * @return {@code null}
     */
    protected VocabCache<T> buildExtendedVocabulary() {
        return null;
    }

    /**
     * Scans every registered source and merges the discovered elements into the target cache.
     *
     * <p>If no target cache was supplied through the builder, a fresh {@link AbstractCache}
     * is created and used as the target.
     *
     * @param resetCounters    if {@code true}, every element frequency in the resulting cache
     *                         is set to zero after the merge
     * @param buildHuffmanTree if {@code true}, a Huffman tree is built over the resulting
     *                         vocabulary and its indexes are applied to the cache
     * @return the target cache, populated with the joint vocabulary
     * @throws IllegalStateException if both flags are {@code true}
     */
    public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
        if (resetCounters && buildHuffmanTree) {
            throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
        }
        if (this.cache == null) {
            // No target cache configured: fall back to a fresh one instead of failing.
            log.debug("Cache is null, building fresh one");
            this.cache = new AbstractCache.Builder<T>().build();
        }
        log.debug("Target vocab size before building: [{}]", this.cache.numWords());

        long sequenceCounter = 0L;
        long elementsCounter = 0L;
        // Accumulates the (already truncated) vocabularies of all sources before the final merge.
        AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();

        int sourceIndex = 0;
        for (VocabSource<T> source : this.sources) {
            SequenceIterator<T> iterator = source.getIterator();
            iterator.reset();
            log.debug("Trying source iterator: [{}]", sourceIndex);
            log.debug("Target vocab size before building: [{}]", this.cache.numWords());
            ++sourceIndex;

            AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
            int sequences = 0;
            long counter = 0L;
            while (iterator.hasMoreSequences()) {
                Sequence<T> document = iterator.nextSequence();
                ++sequenceCounter;

                if (this.fetchLabels) {
                    T labelWord = document.getSequenceLabel();
                    // A sequence may carry no label; skip it rather than fail with an NPE.
                    if (labelWord != null) {
                        labelWord.setSpecial(true);
                        labelWord.setElementFrequency(1L);
                        tempHolder.addToken(labelWord);
                    }
                }

                for (String token : document.asLabels()) {
                    if (isSkipped(token)) {
                        continue;
                    }
                    if (tempHolder.containsWord(token)) {
                        tempHolder.incrementWordCount(token);
                    } else {
                        tempHolder.addToken(document.getElementByLabel(token));
                        ++elementsCounter;
                    }
                    ++counter;
                }

                ++sequences;
                if (sequenceCounter % PROGRESS_LOG_INTERVAL == 0L) {
                    log.info("Sequences checked: [{}], Current vocabulary size: [{}]", sequenceCounter, elementsCounter);
                }
            }

            log.debug("Vocab size before truncation: [{}],  NumWords: [{}], sequences parsed: [{}], counter: [{}]",
                    tempHolder.numWords(), tempHolder.totalWordOccurrences(), sequences, counter);
            truncateByFrequency(tempHolder, source.getMinWordFrequency());
            log.debug("Vocab size after truncation: [{}],  NumWords: [{}], sequences parsed: [{}], counter: [{}]",
                    tempHolder.numWords(), tempHolder.totalWordOccurrences(), sequences, counter);

            topHolder.importVocabulary(tempHolder);
        }

        this.cache.importVocabulary(topHolder);

        if (resetCounters) {
            for (SequenceElement element : this.cache.vocabWords()) {
                element.setElementFrequency(0L);
            }
            this.cache.updateWordsOccurencies();
        }
        if (buildHuffmanTree) {
            Huffman huffman = new Huffman(this.cache.vocabWords());
            huffman.build();
            huffman.applyIndexes(this.cache);
        }

        log.info("Sequences checked: [{}], Current vocabulary size: [{}]", sequenceCounter, this.cache.numWords());
        return this.cache;
    }

    /** Returns {@code true} for tokens that must not enter the vocabulary: null, empty, or stop words. */
    private boolean isSkipped(String token) {
        // Null/empty is tested first so that a null-hostile stop-word list is never asked about null.
        if (token == null || token.isEmpty()) {
            return true;
        }
        return this.stopWords != null && this.stopWords.contains(token);
    }

    /** Removes every non-special element whose frequency is below {@code minWordFrequency}; no-op if it is not positive. */
    private void truncateByFrequency(AbstractCache<T> holder, int minWordFrequency) {
        if (minWordFrequency <= 0) {
            return;
        }
        // Labels are collected first so the cache is not modified while it is being iterated.
        List<String> labelsToRemove = new ArrayList<String>();
        for (T element : holder.vocabWords()) {
            if (element.getElementFrequency() < minWordFrequency && !element.isSpecial()) {
                labelsToRemove.add(element.getLabel());
            }
        }
        for (String label : labelsToRemove) {
            holder.removeElement(label);
        }
    }

    /**
     * Pairs a sequence iterator with the minimum element frequency applied to its vocabulary.
     *
     * @param <T> element type produced by the iterator
     */
    private static class VocabSource<T extends SequenceElement> {
        @NonNull
        private SequenceIterator<T> iterator;
        private int minWordFrequency;

        /**
         * @param iterator         source of sequences; must not be {@code null}
         * @param minWordFrequency frequency threshold used when truncating this source's vocabulary
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        @ConstructorProperties(value={"iterator", "minWordFrequency"})
        public VocabSource(@NonNull SequenceIterator<T> iterator, int minWordFrequency) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
            this.minWordFrequency = minWordFrequency;
        }

        /** @return the sequence iterator of this source */
        @NonNull
        public SequenceIterator<T> getIterator() {
            return this.iterator;
        }

        /** @return the frequency threshold of this source */
        public int getMinWordFrequency() {
            return this.minWordFrequency;
        }

        /**
         * @param iterator new sequence iterator; must not be {@code null}
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        public void setIterator(@NonNull SequenceIterator<T> iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
        }

        /** @param minWordFrequency new frequency threshold */
        public void setMinWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
        }

        /** Two sources are equal when their iterators are equal and their thresholds match. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof VocabSource)) {
                return false;
            }
            VocabSource<?> other = (VocabSource<?>) o;
            if (!other.canEqual(this)) {
                return false;
            }
            SequenceIterator<?> mine = this.getIterator();
            SequenceIterator<?> theirs = other.getIterator();
            if (mine == null ? theirs != null : !mine.equals(theirs)) {
                return false;
            }
            return this.getMinWordFrequency() == other.getMinWordFrequency();
        }

        /** Hook used by {@link #equals(Object)} so both sides agree that they are comparable. */
        protected boolean canEqual(Object other) {
            return other instanceof VocabSource;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            SequenceIterator<T> it = this.getIterator();
            result = result * prime + (it == null ? 0 : it.hashCode());
            result = result * prime + this.getMinWordFrequency();
            return result;
        }

        @Override
        public String toString() {
            return "VocabConstructor.VocabSource(iterator=" + this.getIterator() + ", minWordFrequency=" + this.getMinWordFrequency() + ")";
        }
    }

    /**
     * Fluent builder for {@link VocabConstructor}.
     *
     * @param <T> element type stored in the vocabulary
     */
    public static class Builder<T extends SequenceElement> {
        private List<VocabSource<T>> sources = new ArrayList<VocabSource<T>>();
        private VocabCache<T> cache;
        private List<String> stopWords = new ArrayList<String>();
        private boolean useAdaGrad = false;
        private boolean fetchLabels = false;

        /**
         * Stores the AdaGrad flag that is handed to the constructed instance.
         *
         * @param useAdaGrad flag value
         * @return this builder
         */
        protected Builder<T> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /**
         * Sets the cache the joint vocabulary is merged into.
         *
         * @param cache target cache; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code cache} is {@code null}
         */
        public Builder<T> setTargetVocabCache(@NonNull VocabCache<T> cache) {
            if (cache == null) {
                throw new NullPointerException("cache");
            }
            this.cache = cache;
            return this;
        }

        /**
         * Registers one more source of sequences.
         *
         * @param iterator            source of sequences; must not be {@code null}
         * @param minElementFrequency frequency threshold used when truncating this source's vocabulary
         * @return this builder
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        public Builder<T> addSource(@NonNull SequenceIterator<T> iterator, int minElementFrequency) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.sources.add(new VocabSource<T>(iterator, minElementFrequency));
            return this;
        }

        /**
         * Sets the tokens that are skipped while scanning sources.
         *
         * @param stopWords stop-word list; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code stopWords} is {@code null}
         */
        public Builder<T> setStopWords(@NonNull List<String> stopWords) {
            if (stopWords == null) {
                throw new NullPointerException("stopWords");
            }
            this.stopWords = stopWords;
            return this;
        }

        /**
         * Controls whether sequence labels are added to the vocabulary as special elements.
         *
         * @param reallyFetch {@code true} to add sequence labels
         * @return this builder
         */
        public Builder<T> fetchLabels(boolean reallyFetch) {
            this.fetchLabels = reallyFetch;
            return this;
        }

        /**
         * Creates a {@link VocabConstructor} configured with the current builder state.
         * The source and stop-word lists are passed by reference, not copied.
         *
         * @return the configured constructor
         */
        public VocabConstructor<T> build() {
            VocabConstructor<T> constructor = new VocabConstructor<T>();
            constructor.sources = this.sources;
            constructor.cache = this.cache;
            constructor.stopWords = this.stopWords;
            constructor.useAdaGrad = this.useAdaGrad;
            constructor.fetchLabels = this.fetchLabels;
            return constructor;
        }
    }
}

