package org.deeplearning4j.models.word2vec.wordstore;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import lombok.NonNull;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.class */
/**
 * Builds a single joint vocabulary from one or more sentence sources.
 *
 * <p>Each source is tokenized with the configured {@link TokenizerFactory}. Tokens are counted in a
 * per-source {@link VocabularyHolder} built with that source's minimum word frequency. That holder is
 * truncated and then merged into a joint holder, which is finally transferred into the target
 * {@link VocabCache}.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class VocabConstructor {
    private List<VocabSource> sources;
    private TokenizerFactory tokenizerFactory;
    private VocabCache cache;
    private List<String> stopWords;
    // NOTE(review): stored by the builder but not read anywhere in this class — confirm intended use.
    private boolean useAdaGrad;
    protected static final Logger log = LoggerFactory.getLogger(VocabConstructor.class);

    /** Fluent builder for {@link VocabConstructor}. */
    public static class Builder {
        private TokenizerFactory tokenizerFactory;
        private VocabCache cache;
        private List<VocabSource> sources = new ArrayList<>();
        private List<String> stopWords = new ArrayList<>();
        private boolean useAdaGrad = false;

        /**
         * Sets the AdaGrad flag copied into the built instance.
         *
         * @param reallyUse the flag value
         * @return this builder
         */
        public Builder useAdaGrad(boolean reallyUse) {
            this.useAdaGrad = reallyUse;
            return this;
        }

        /**
         * Sets the cache that receives the built vocabulary.
         *
         * @param vocabCache target cache, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code vocabCache} is {@code null}
         */
        public Builder setTargetVocabCache(@NonNull VocabCache vocabCache) {
            if (vocabCache == null) {
                throw new NullPointerException("cache");
            }
            this.cache = vocabCache;
            return this;
        }

        /**
         * Adds a sentence source.
         *
         * @param iterator sentence source, must not be {@code null}
         * @param minWordFrequency minimum word frequency used when building this source's holder
         * @return this builder
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        public Builder addSource(SentenceIterator iterator, int minWordFrequency) {
            this.sources.add(new VocabSource(iterator, minWordFrequency));
            return this;
        }

        /**
         * Sets the tokenizer factory used to split sentences into tokens.
         *
         * @param tokenizerFactory the factory, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code tokenizerFactory} is {@code null}
         */
        public Builder setTokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            if (tokenizerFactory == null) {
                throw new NullPointerException("factory");
            }
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets the tokens that are skipped while counting.
         *
         * @param stopWords stop-word list, must not be {@code null}; it is copied in {@link #build()}
         * @return this builder
         * @throws NullPointerException if {@code stopWords} is {@code null}
         */
        public Builder setStopWords(@NonNull List<String> stopWords) {
            if (stopWords == null) {
                throw new NullPointerException("stopWords");
            }
            this.stopWords = stopWords;
            return this;
        }

        /**
         * Creates a {@link VocabConstructor} from the current builder state.
         *
         * @return a new constructor instance
         */
        public VocabConstructor build() {
            VocabConstructor vocabConstructor = new VocabConstructor();
            // Copy the lists. Later changes to this builder, or to the caller's stop-word list,
            // must not leak into an already-built instance.
            vocabConstructor.sources = new ArrayList<>(this.sources);
            vocabConstructor.tokenizerFactory = this.tokenizerFactory;
            vocabConstructor.cache = this.cache;
            vocabConstructor.stopWords = new ArrayList<>(this.stopWords);
            vocabConstructor.useAdaGrad = this.useAdaGrad;
            return vocabConstructor;
        }
    }

    /**
     * A sentence source paired with the minimum word frequency used when building its per-source
     * vocabulary.
     *
     * <p>NOTE(review): the decompiler reported this class as originally private. It is kept public
     * so existing callers still compile.
     */
    public static class VocabSource {

        @NonNull
        private SentenceIterator iterator;

        private int minWordFrequency;

        /**
         * @param iterator sentence source, must not be {@code null}
         * @param minWordFrequency minimum word frequency for this source
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        @ConstructorProperties({"iterator", "minWordFrequency"})
        public VocabSource(@NonNull SentenceIterator iterator, int minWordFrequency) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
            this.minWordFrequency = minWordFrequency;
        }

        /** @return the sentence source */
        @NonNull
        public SentenceIterator getIterator() {
            return this.iterator;
        }

        /** @return the minimum word frequency for this source */
        public int getMinWordFrequency() {
            return this.minWordFrequency;
        }

        /**
         * @param iterator new sentence source, must not be {@code null}
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        public void setIterator(@NonNull SentenceIterator iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
        }

        /** @param minWordFrequency new minimum word frequency */
        public void setMinWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof VocabSource)) {
                return false;
            }
            VocabSource other = (VocabSource) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            SentenceIterator mine = getIterator();
            SentenceIterator theirs = other.getIterator();
            boolean sameIterator = (mine == null) ? (theirs == null) : mine.equals(theirs);
            return sameIterator && getMinWordFrequency() == other.getMinWordFrequency();
        }

        /** Lombok-style symmetry hook for {@link #equals(Object)} in subclasses. */
        protected boolean canEqual(Object obj) {
            return obj instanceof VocabSource;
        }

        @Override
        public int hashCode() {
            SentenceIterator it = getIterator();
            int result = 1;
            result = result * 59 + (it == null ? 0 : it.hashCode());
            result = result * 59 + getMinWordFrequency();
            return result;
        }

        @Override
        public String toString() {
            return "VocabConstructor.VocabSource(iterator=" + getIterator() + ", minWordFrequency=" + getMinWordFrequency() + ")";
        }
    }

    private VocabConstructor() {
        this.sources = new ArrayList<>();
        this.useAdaGrad = false;
    }

    /**
     * Builds one vocabulary from all configured sources and transfers it into the target cache.
     *
     * <p>If no target cache was configured, an {@code InMemoryLookupCache(false)} is created. It is
     * kept for subsequent calls.
     *
     * @param resetCounters if {@code true}, calls {@code resetWordCounters()} on the joint vocabulary
     *     before transfer
     * @param buildHuffmanTree if {@code true}, calls {@code updateHuffmanCodes()} on the joint
     *     vocabulary before transfer
     * @return the target cache after the joint vocabulary has been transferred into it
     * @throws IllegalStateException if both flags are {@code true}, or if sources are configured but
     *     no tokenizer factory was set
     */
    public VocabCache buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
        if (resetCounters && buildHuffmanTree) {
            throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
        }
        // Fail fast with a clear message instead of an NPE deep inside the token loop.
        if (!this.sources.isEmpty() && this.tokenizerFactory == null) {
            throw new IllegalStateException("TokenizerFactory must be set before building a vocabulary");
        }
        if (this.cache == null) {
            this.cache = new InMemoryLookupCache(false);
        }

        // Hash lookup: the stop-word check runs once per token, so avoid List.contains.
        Set<String> stopWordSet = new HashSet<>();
        if (this.stopWords != null) {
            stopWordSet.addAll(this.stopWords);
        }

        VocabularyHolder jointHolder = new VocabularyHolder.Builder().externalCache(this.cache).minWordFrequency(0).build();
        for (VocabSource source : this.sources) {
            VocabularyHolder sourceHolder = countSourceTokens(source, stopWordSet);
            log.info("Vocab size before truncation: {}", sourceHolder.numWords());
            sourceHolder.truncateVocabulary();
            log.info("Vocab size after truncation: {}", sourceHolder.numWords());
            jointHolder.consumeVocabulary(sourceHolder);
        }
        if (resetCounters) {
            jointHolder.resetWordCounters();
        }
        if (buildHuffmanTree) {
            jointHolder.updateHuffmanCodes();
        }
        jointHolder.transferBackToVocabCache(this.cache);
        return this.cache;
    }

    /**
     * Resets the source's iterator, tokenizes every sentence, and counts each non-empty,
     * non-stop-word token.
     *
     * <p>Counting happens in a fresh holder built with the source's minimum word frequency.
     */
    private VocabularyHolder countSourceTokens(VocabSource source, Set<String> stopWordSet) {
        SentenceIterator iterator = source.getIterator();
        iterator.reset();
        VocabularyHolder holder = new VocabularyHolder.Builder().minWordFrequency(source.getMinWordFrequency()).build();
        while (iterator.hasNext()) {
            for (String token : this.tokenizerFactory.create(iterator.nextSentence()).getTokens()) {
                if (token == null || token.isEmpty() || stopWordSet.contains(token)) {
                    continue;
                }
                if (holder.containsWord(token)) {
                    holder.incrementWordCounter(token);
                } else {
                    holder.addWord(token);
                }
            }
        }
        return holder;
    }
}
