package org.deeplearning4j.bagofwords.vectorizer;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizer.class */
/**
 * Bag-of-words text vectorizer: turns a piece of text into a single-row {@link INDArray} with one
 * column per word known to the vocabulary cache.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class BagOfWordsVectorizer extends BaseTextVectorizer {

    /** Fluent builder for {@link BagOfWordsVectorizer}. */
    public static class Builder {
        protected TokenizerFactory tokenizerFactory;
        protected LabelAwareIterator iterator;
        protected int minWordFrequency;
        protected VocabCache<VocabWord> vocabCache;
        protected LabelsSource labelsSource = new LabelsSource();
        protected List<String> stopWords = new ArrayList<>();

        /**
         * Sets the tokenizer factory used to split text into tokens.
         *
         * @param tokenizerFactory tokenizer factory, must not be null
         * @return this builder
         * @throws NullPointerException if {@code tokenizerFactory} is null
         */
        public Builder setTokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            if (tokenizerFactory == null) {
                throw new NullPointerException("tokenizerFactory");
            }
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets the document source directly from a label-aware iterator.
         *
         * @param labelAwareIterator document iterator, must not be null
         * @return this builder
         * @throws NullPointerException if {@code labelAwareIterator} is null
         */
        public Builder setIterator(@NonNull LabelAwareIterator labelAwareIterator) {
            if (labelAwareIterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = labelAwareIterator;
            return this;
        }

        /**
         * Sets the document source from a {@link DocumentIterator}, adapting it to a
         * label-aware iterator backed by this builder's {@link LabelsSource}.
         *
         * @param documentIterator document iterator, must not be null
         * @return this builder
         * @throws NullPointerException if {@code documentIterator} is null
         */
        public Builder setIterator(@NonNull DocumentIterator documentIterator) {
            if (documentIterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = new DocumentIteratorConverter(documentIterator, this.labelsSource);
            return this;
        }

        /**
         * Sets the document source from a {@link SentenceIterator}, adapting it to a
         * label-aware iterator backed by this builder's {@link LabelsSource}.
         *
         * @param sentenceIterator sentence iterator, must not be null
         * @return this builder
         * @throws NullPointerException if {@code sentenceIterator} is null
         */
        public Builder setIterator(@NonNull SentenceIterator sentenceIterator) {
            if (sentenceIterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = new SentenceIteratorConverter(sentenceIterator, this.labelsSource);
            return this;
        }

        /**
         * Supplies an existing vocabulary cache. If none is supplied, {@link #build()} creates an
         * empty {@link AbstractCache}.
         *
         * @param vocabCache vocabulary cache, must not be null
         * @return this builder
         * @throws NullPointerException if {@code vocabCache} is null
         */
        public Builder setVocab(@NonNull VocabCache<VocabWord> vocabCache) {
            if (vocabCache == null) {
                throw new NullPointerException("vocab");
            }
            this.vocabCache = vocabCache;
            return this;
        }

        /**
         * Sets the minimum word frequency handed to the built vectorizer.
         *
         * @param i minimum frequency
         * @return this builder
         */
        public Builder setMinWordFrequency(int i) {
            this.minWordFrequency = i;
            return this;
        }

        /**
         * Records stop words on this builder. Previously the argument was silently discarded.
         *
         * <p>A {@code null} collection is ignored, which preserves the old tolerant behavior.
         *
         * <p>NOTE(review): {@link #build()} does not forward these words to the vectorizer,
         * because no stop-word field on {@code BaseTextVectorizer} is visible from this file.
         * Wire them through once that field is confirmed.
         *
         * @param collection stop words to add; may be null
         * @return this builder
         */
        public Builder setStopWords(Collection<String> collection) {
            if (collection != null) {
                this.stopWords.addAll(collection);
            }
            return this;
        }

        /**
         * Creates the vectorizer from the configured state.
         *
         * <p>When no vocabulary cache was set, a new empty {@link AbstractCache} is created and
         * also stored back on this builder. Later {@code build()} calls therefore share it.
         *
         * @return a new {@link BagOfWordsVectorizer}
         */
        public BagOfWordsVectorizer build() {
            BagOfWordsVectorizer bagOfWordsVectorizer = new BagOfWordsVectorizer();
            bagOfWordsVectorizer.tokenizerFactory = this.tokenizerFactory;
            bagOfWordsVectorizer.iterator = this.iterator;
            bagOfWordsVectorizer.minWordFrequency = this.minWordFrequency;
            bagOfWordsVectorizer.labelsSource = this.labelsSource;
            if (this.vocabCache == null) {
                this.vocabCache = new AbstractCache.Builder().build();
            }
            bagOfWordsVectorizer.vocabCache = this.vocabCache;
            return bagOfWordsVectorizer;
        }
    }

    protected BagOfWordsVectorizer() {
    }

    /**
     * Reads the whole stream as UTF-8 text and vectorizes it with the given label.
     *
     * <p>Lines are re-joined with {@code '\n'}. {@link BufferedReader#readLine()} strips line
     * terminators, so concatenating lines directly would fuse the last word of one line with
     * the first word of the next.
     *
     * <p>The stream is not closed by this method; the caller keeps ownership of it.
     *
     * @param inputStream source of UTF-8 text
     * @param str label for the resulting data set
     * @return the vectorized data set
     * @throws UncheckedIOException if reading the stream fails
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public DataSet vectorize(InputStream inputStream, String str) {
        // Not wrapped in try-with-resources: closing the reader would close the caller's stream.
        BufferedReader reader =
                new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
        StringBuilder text = new StringBuilder();
        try {
            boolean first = true;
            String line;
            while ((line = reader.readLine()) != null) {
                if (!first) {
                    text.append('\n');
                }
                text.append(line);
                first = false;
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return vectorize(text.toString(), str);
    }

    /**
     * Vectorizes {@code str} and pairs it with a one-hot outcome vector for label {@code str2}.
     *
     * <p>The outcome index is {@code labelsSource.indexOf(str2)}, and the vector width is
     * {@code labelsSource.size()}. NOTE(review): behavior for a label unknown to the labels
     * source depends on {@code FeatureUtil.toOutcomeVector} — confirm.
     *
     * @param str text to vectorize
     * @param str2 label of the text
     * @return data set of features and outcome
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public DataSet vectorize(String str, String str2) {
        return new DataSet(transform(str), FeatureUtil.toOutcomeVector(this.labelsSource.indexOf(str2), this.labelsSource.size()));
    }

    /**
     * Tokenizes {@code str} and builds a {@code 1 x numWords} row vector.
     *
     * <p>For every token known to the vocabulary cache (index {@code >= 0}), the column at that
     * index is set to the cache's {@code wordFrequency} for the token. Unknown tokens are
     * skipped. The value written is the cache's word frequency, not the number of occurrences
     * in {@code str}.
     *
     * @param str text to transform
     * @return single-row feature vector
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public INDArray transform(String str) {
        List<String> tokens = this.tokenizerFactory.create(str).getTokens();
        INDArray create = Nd4j.create(1, this.vocabCache.numWords());
        for (String token : tokens) {
            // Look the token up once and reuse the index (previously computed twice).
            int index = this.vocabCache.indexOf(token);
            if (index >= 0) {
                create.putScalar(index, this.vocabCache.wordFrequency(token));
            }
        }
        return create;
    }

    /**
     * Reads {@code file} as UTF-8 text and vectorizes it with the given label.
     *
     * <p>UTF-8 is used explicitly, matching {@link #vectorize(InputStream, String)}. This avoids
     * the deprecated platform-default-charset overload of {@code FileUtils.readFileToString}.
     *
     * @param file file to read
     * @param str label for the resulting data set
     * @return the vectorized data set
     * @throws UncheckedIOException if the file cannot be read
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public DataSet vectorize(File file, String str) {
        String text;
        try {
            text = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return vectorize(text, str);
    }

    /**
     * Always fails: there is no input to vectorize.
     *
     * @throws UnsupportedOperationException always
     */
    public DataSet vectorize() {
        throw new UnsupportedOperationException("Can't vectorize empty input");
    }
}
