package org.deeplearning4j.iterator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import lombok.NonNull;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.documentiterator.LabelAwareDocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/iterator/CnnSentenceDataSetIterator.class */
/**
 * A {@link DataSetIterator} that converts labeled sentences into word-vector features for sentence classification.
 * <p>
 * Each sentence obtained from the {@link LabeledSentenceProvider} is tokenized with the configured
 * {@link TokenizerFactory}. Every remaining token is replaced by its vector from {@link WordVectors}, using the
 * normalized vector if so configured. Labels are one-hot encoded; the class index of each label is its position in
 * the sorted list of the provider's labels. When the sentences of a minibatch have different lengths, a features mask
 * marking the real (non-padding) positions is attached to the {@link DataSet}. The feature layouts are described on
 * {@link Format}.
 * <p>
 * Normally created through {@link Builder}.
 */
public class CnnSentenceDataSetIterator implements DataSetIterator {
    /**
     * Placeholder substituted for out-of-vocabulary tokens under {@link UnknownWordHandling#UseUnknownVector}.
     * It is recognised by reference ({@code ==}) in {@link #getVector(String)}. As a result, only this exact
     * instance, inserted by {@link #tokenizeSentence(String)}, selects the unknown vector.
     */
    private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";

    private Format format;
    /** Source of (sentence, label) pairs; may be null, in which case only {@link #loadSingleSentence} is usable. */
    private LabeledSentenceProvider sentenceProvider;
    private WordVectors wordVectors;
    private TokenizerFactory tokenizerFactory;
    private UnknownWordHandling unknownWordHandling;
    private boolean useNormalizedWordVectors;
    private int minibatchSize;
    /** Upper bound on the number of tokens used per sentence; values {@code <= 0} mean "no limit". */
    private int maxSentenceLength;
    private boolean sentencesAlongHeight;
    private DataSetPreProcessor dataSetPreProcessor;
    private int wordVectorSize;
    private int numClasses;
    /** Label -> class index (index into the one-hot label vector). */
    private Map<String, Integer> labelClassMap;
    /** Vector used for {@link #UNKNOWN_WORD_SENTINEL}; only set under UseUnknownVector handling. */
    private INDArray unknown;
    /** Number of examples returned since construction or the last {@link #reset()}. */
    private int cursor;
    /** A sentence tokenized by {@link #hasNext()} but not yet returned by {@link #next(int)}; null if none. */
    private Pair<List<String>, String> preLoadedTokens;

    /** Fluent builder for {@link CnnSentenceDataSetIterator}. */
    public static class Builder {
        private Format format;
        private LabeledSentenceProvider sentenceProvider = null;
        private WordVectors wordVectors;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
        private boolean useNormalizedWordVectors = true;
        private int maxSentenceLength = -1;
        private int minibatchSize = 32;
        private boolean sentencesAlongHeight = true;
        private DataSetPreProcessor dataSetPreProcessor;

        /**
         * Creates a builder for the {@link Format#CNN2D} layout.
         *
         * @deprecated use {@link #Builder(Format)} to state the output format explicitly
         */
        @Deprecated
        public Builder() {
            this(Format.CNN2D);
        }

        /**
         * Creates a builder producing features in the given layout.
         *
         * @param format output layout; must not be null
         * @throws NullPointerException if {@code format} is null
         */
        public Builder(@NonNull Format format) {
            if (format == null) {
                throw new NullPointerException("format is marked @NonNull but is null");
            }
            this.format = format;
        }

        /** Sets the source of labeled sentences. */
        public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
            this.sentenceProvider = labeledSentenceProvider;
            return this;
        }

        /**
         * Uses a {@link LabelAwareIterator} as the sentence source.
         *
         * @param labelAwareIterator iterator supplying labeled documents
         * @param list               the complete list of labels
         * @throws NullPointerException if {@code list} is null
         */
        public Builder sentenceProvider(LabelAwareIterator labelAwareIterator, @NonNull List<String> list) {
            if (list == null) {
                throw new NullPointerException("labels is marked @NonNull but is null");
            }
            return sentenceProvider(new LabelAwareConverter(labelAwareIterator, list));
        }

        /**
         * Uses a {@link LabelAwareDocumentIterator} as the sentence source (adapted via {@link DocumentIteratorConverter}).
         *
         * @throws NullPointerException if {@code list} is null
         */
        public Builder sentenceProvider(LabelAwareDocumentIterator labelAwareDocumentIterator, @NonNull List<String> list) {
            if (list == null) {
                throw new NullPointerException("labels is marked @NonNull but is null");
            }
            return sentenceProvider(new DocumentIteratorConverter(labelAwareDocumentIterator), list);
        }

        /**
         * Uses a {@link LabelAwareSentenceIterator} as the sentence source (adapted via {@link SentenceIteratorConverter}).
         *
         * @throws NullPointerException if {@code list} is null
         */
        public Builder sentenceProvider(LabelAwareSentenceIterator labelAwareSentenceIterator, @NonNull List<String> list) {
            if (list == null) {
                throw new NullPointerException("labels is marked @NonNull but is null");
            }
            return sentenceProvider(new SentenceIteratorConverter(labelAwareSentenceIterator), list);
        }

        /** Sets the word vectors used to embed tokens (required). */
        public Builder wordVectors(WordVectors wordVectors) {
            this.wordVectors = wordVectors;
            return this;
        }

        /** Sets the tokenizer factory; defaults to {@link DefaultTokenizerFactory}. */
        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /** Sets how out-of-vocabulary tokens are handled; defaults to {@link UnknownWordHandling#RemoveWord}. */
        public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
            this.unknownWordHandling = unknownWordHandling;
            return this;
        }

        /** Sets the default minibatch size used by {@code next()}; defaults to 32. */
        public Builder minibatchSize(int i) {
            this.minibatchSize = i;
            return this;
        }

        /** Whether to use normalized word vectors; defaults to true. */
        public Builder useNormalizedWordVectors(boolean z) {
            this.useNormalizedWordVectors = z;
            return this;
        }

        /** Sets the maximum number of tokens used per sentence; values {@code <= 0} (default -1) mean no limit. */
        public Builder maxSentenceLength(int i) {
            this.maxSentenceLength = i;
            return this;
        }

        /** For {@link Format#CNN2D}: whether sentence positions run along the height dimension; defaults to true. */
        public Builder sentencesAlongHeight(boolean z) {
            this.sentencesAlongHeight = z;
            return this;
        }

        /** Sets an optional pre-processor applied to every returned {@link DataSet}. */
        public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
            this.dataSetPreProcessor = dataSetPreProcessor;
            return this;
        }

        /**
         * Builds the iterator.
         *
         * @throws IllegalStateException if no {@link WordVectors} were supplied
         */
        public CnnSentenceDataSetIterator build() {
            if (this.wordVectors == null) {
                throw new IllegalStateException("Cannot build CnnSentenceDataSetIterator without a WordVectors instance");
            }
            return new CnnSentenceDataSetIterator(this);
        }
    }

    /** Layout of the feature arrays produced by this iterator. */
    public enum Format {
        /** Rank 3 features {@code [minibatch, wordVectorSize, sentenceLength]} in 'f' order. */
        RNN,
        /** Rank 3 features {@code [minibatch, wordVectorSize, sentenceLength]} in 'c' order. */
        CNN1D,
        /**
         * Rank 4 features {@code [minibatch, 1, sentenceLength, wordVectorSize]} when sentences run along the height,
         * otherwise {@code [minibatch, 1, wordVectorSize, sentenceLength]}.
         */
        CNN2D
    }

    /**
     * How tokens missing from the {@link WordVectors} vocabulary are treated.
     * This applies only when the word vectors do not support out-of-vocabulary lookups.
     */
    public enum UnknownWordHandling {
        /** Drop the token from the sentence. */
        RemoveWord,
        /** Keep the token's position but embed it with the UNK vector (or a same-shaped fallback array). */
        UseUnknownVector
    }

    /**
     * Creates an iterator from a builder's configuration.
     * <p>
     * The builder may omit the sentence provider. In that case the iterator has no label classes and only
     * {@link #loadSingleSentence(String)} can be used; {@code hasNext()}/{@code next()} throw
     * {@link UnsupportedOperationException}.
     */
    protected CnnSentenceDataSetIterator(Builder builder) {
        this.cursor = 0;
        this.format = builder.format;
        this.sentenceProvider = builder.sentenceProvider;
        this.wordVectors = builder.wordVectors;
        this.tokenizerFactory = builder.tokenizerFactory;
        this.unknownWordHandling = builder.unknownWordHandling;
        this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
        this.minibatchSize = builder.minibatchSize;
        this.maxSentenceLength = builder.maxSentenceLength;
        this.sentencesAlongHeight = builder.sentencesAlongHeight;
        this.dataSetPreProcessor = builder.dataSetPreProcessor;

        // The vector size is read from the first word in the vocabulary.
        this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;

        // Sorting makes the label -> class index assignment independent of the provider's label order.
        this.labelClassMap = new HashMap<>();
        if (sentenceProvider != null) {
            this.numClasses = sentenceProvider.numLabelClasses();
            List<String> sortedLabels = new ArrayList<>(sentenceProvider.allLabels());
            Collections.sort(sortedLabels);
            for (int classIndex = 0; classIndex < sortedLabels.size(); classIndex++) {
                labelClassMap.put(sortedLabels.get(classIndex), classIndex);
            }
        }

        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
            String unkWord = wordVectors.getUNK();
            unknown = useNormalizedWordVectors
                    ? wordVectors.getWordVectorMatrixNormalized(unkWord)
                    : wordVectors.getWordVectorMatrix(unkWord);
            if (unknown == null) {
                // No vector for the UNK token: fall back to a new array shaped like an in-vocabulary word vector.
                unknown = wordVectors.getWordVectorMatrix(wordVectors.vocab().wordAtIndex(0)).like();
            }
        }
    }

    /**
     * Converts a single sentence into a feature array with a minibatch dimension of 1, using the configured format.
     *
     * @param str the raw sentence
     * @return features of shape {@code [1, wordVectorSize, length]} (RNN/CNN1D) or the rank 4 CNN2D equivalent, where
     *         {@code length} is the token count capped at {@code maxSentenceLength} (when that limit is positive)
     * @throws IllegalStateException if the sentence yields no tokens
     */
    public INDArray loadSingleSentence(String str) {
        List<String> tokens = tokenizeSentence(str);
        if (tokens.isEmpty()) {
            throw new IllegalStateException("No tokens available for input sentence - empty string or no words in vocabulary with RemoveWord unknown handling? Sentence = \"" + str + "\"");
        }
        // Clamp through the helper so a non-positive maxSentenceLength (the -1 default) means "no limit"
        // rather than a negative dimension.
        int length = clampToMaxLength(tokens.size());
        INDArray features = createFeatureArray(1, length);
        putSentence(features, 0, tokens, length);
        return features;
    }

    /** True for the rank 3 layouts (RNN, CNN1D); false means the rank 4 CNN2D layout. */
    private boolean isSequenceFormat() {
        return format == Format.CNN1D || format == Format.RNN;
    }

    /** Caps {@code length} at {@code maxSentenceLength} when that limit is positive; otherwise returns it unchanged. */
    private int clampToMaxLength(int length) {
        return maxSentenceLength > 0 ? Math.min(length, maxSentenceLength) : length;
    }

    /** Allocates a feature array for {@code numExamples} sentences of {@code length} positions in the configured layout. */
    private INDArray createFeatureArray(int numExamples, int length) {
        if (isSequenceFormat()) {
            return Nd4j.create(new int[] {numExamples, wordVectorSize, length}, format == Format.CNN1D ? 'c' : 'f');
        }
        int[] shape = sentencesAlongHeight
                ? new int[] {numExamples, 1, length, wordVectorSize}
                : new int[] {numExamples, 1, wordVectorSize, length};
        return Nd4j.create(shape);
    }

    /** Index selecting the whole word vector of token {@code position} within example {@code example}. */
    private INDArrayIndex[] tokenIndex(int example, int position) {
        INDArrayIndex exampleIdx = NDArrayIndex.point(example);
        INDArrayIndex positionIdx = NDArrayIndex.point(position);
        if (isSequenceFormat()) {
            return new INDArrayIndex[] {exampleIdx, NDArrayIndex.all(), positionIdx};
        }
        if (sentencesAlongHeight) {
            return new INDArrayIndex[] {exampleIdx, NDArrayIndex.point(0L), positionIdx, NDArrayIndex.all()};
        }
        return new INDArrayIndex[] {exampleIdx, NDArrayIndex.point(0L), NDArrayIndex.all(), positionIdx};
    }

    /** Writes the vectors of the first {@code min(tokens.size(), length)} tokens into row {@code example} of {@code features}. */
    private void putSentence(INDArray features, int example, List<String> tokens, int length) {
        int count = Math.min(tokens.size(), length);
        for (int position = 0; position < count; position++) {
            features.put(tokenIndex(example, position), getVector(tokens.get(position)));
        }
    }

    /**
     * Builds a features mask for a minibatch.
     * For each example, positions holding real tokens are set to 1 and padding positions are left as created.
     */
    private INDArray createFeatureMask(List<Pair<List<String>, String>> batch, int maxLength) {
        int batchSize = batch.size();
        INDArray mask;
        if (isSequenceFormat()) {
            mask = Nd4j.create(batchSize, maxLength);
        } else if (sentencesAlongHeight) {
            mask = Nd4j.create(new int[] {batchSize, 1, maxLength, 1});
        } else {
            mask = Nd4j.create(new int[] {batchSize, 1, 1, maxLength});
        }
        for (int i = 0; i < batchSize; i++) {
            int length = batch.get(i).getFirst().size();
            if (length >= maxLength) {
                // Sentence fills (or overflows) the whole time axis: mark the entire example.
                if (isSequenceFormat()) {
                    mask.getRow(i).assign(1.0);
                } else {
                    mask.slice(i).assign(1.0);
                }
                continue;
            }
            INDArrayIndex filled = NDArrayIndex.interval(0, length);
            if (isSequenceFormat()) {
                mask.get(NDArrayIndex.point(i), filled).assign(1.0);
            } else if (sentencesAlongHeight) {
                mask.get(NDArrayIndex.point(i), NDArrayIndex.point(0L), filled, NDArrayIndex.point(0L)).assign(1.0);
            } else {
                mask.get(NDArrayIndex.point(i), NDArrayIndex.point(0L), NDArrayIndex.point(0L), filled).assign(1.0);
            }
        }
        return mask;
    }

    /** Returns the embedding for a token, honouring the unknown-word sentinel and the normalization setting. */
    private INDArray getVector(String str) {
        // Reference comparison is deliberate: only the sentinel instance inserted by tokenizeSentence should
        // match, not a real token that happens to have the same text.
        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector && str == UNKNOWN_WORD_SENTINEL) {
            return unknown;
        }
        return useNormalizedWordVectors
                ? wordVectors.getWordVectorMatrixNormalized(str)
                : wordVectors.getWordVectorMatrix(str);
    }

    /**
     * Tokenizes a sentence and applies the unknown-word policy.
     * Under RemoveWord, out-of-vocabulary tokens are dropped. Under UseUnknownVector, they are replaced by
     * {@link #UNKNOWN_WORD_SENTINEL}. The policy applies only when the word vectors do not support OOV lookups.
     */
    private List<String> tokenizeSentence(String str) {
        Tokenizer tokenizer = tokenizerFactory.create(str);
        List<String> tokens = new ArrayList<>();
        while (tokenizer.hasMoreTokens()) {
            String token = tokenizer.nextToken();
            if (!wordVectors.outOfVocabularySupported() && !wordVectors.hasWord(token)) {
                switch (unknownWordHandling) {
                    case RemoveWord:
                        // Skip the token entirely; `continue` targets the while loop, not the switch.
                        continue;
                    case UseUnknownVector:
                        token = UNKNOWN_WORD_SENTINEL;
                        break;
                    default:
                        break;
                }
            }
            tokens.add(token);
        }
        return tokens;
    }

    /** Returns a copy of the label -> class index mapping. */
    public Map<String, Integer> getLabelClassMap() {
        return new HashMap<>(labelClassMap);
    }

    /** Returns the labels ordered by class index. */
    public List<String> getLabels() {
        String[] labelsByIndex = new String[labelClassMap.size()];
        for (Map.Entry<String, Integer> entry : labelClassMap.entrySet()) {
            labelsByIndex[entry.getValue()] = entry.getKey();
        }
        return Arrays.asList(labelsByIndex);
    }

    /**
     * Returns whether another non-empty sentence is available.
     * Sentences that tokenize to nothing are skipped; the next usable one is buffered for {@link #next(int)}.
     *
     * @throws UnsupportedOperationException if no sentence provider was configured
     */
    @Override
    public boolean hasNext() {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        while (preLoadedTokens == null && sentenceProvider.hasNext()) {
            preLoadTokens();
        }
        return preLoadedTokens != null;
    }

    /** Reads one sentence from the provider and buffers it if it yields at least one token. */
    private void preLoadTokens() {
        if (preLoadedTokens != null) {
            return;
        }
        Pair<String, String> sentence = sentenceProvider.nextSentence();
        List<String> tokens = tokenizeSentence(sentence.getFirst());
        if (!tokens.isEmpty()) {
            preLoadedTokens = new Pair<>(tokens, sentence.getSecond());
        }
    }

    /** Returns the next minibatch of the configured default size. */
    @Override
    public DataSet next() {
        return next(minibatchSize);
    }

    /**
     * Alias of {@link #next()} retained for callers of the previously exposed method name.
     *
     * @deprecated use {@link #next()}
     */
    @Deprecated
    public DataSet m7next() {
        return next();
    }

    /**
     * Returns a minibatch of up to {@code i} non-empty sentences.
     * At least one sentence is always included, even if {@code i} is 0.
     *
     * @param i requested minibatch size
     * @return features, one-hot labels and, if sentence lengths differ, a features mask
     * @throws UnsupportedOperationException if no sentence provider was configured
     * @throws NoSuchElementException        if no further sentences are available
     * @throws IllegalStateException         if a sentence carries a label unknown to the provider's label list
     */
    public DataSet next(int i) {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        if (!hasNext()) {
            throw new NoSuchElementException("No next element");
        }

        List<Pair<List<String>, String>> batch = collectMinibatch(i);
        int batchSize = batch.size();

        int longest = 0;
        int shortest = Integer.MAX_VALUE;
        for (Pair<List<String>, String> example : batch) {
            int length = example.getFirst().size();
            longest = Math.max(longest, length);
            shortest = Math.min(shortest, length);
        }
        int maxLength = clampToMaxLength(longest);

        INDArray labels = createLabels(batch);

        INDArray features = createFeatureArray(batchSize, maxLength);
        for (int example = 0; example < batchSize; example++) {
            // Bound by the clamped maxLength, not the raw maxSentenceLength (which is -1 when unlimited).
            putSentence(features, example, batch.get(example).getFirst(), maxLength);
        }
        INDArray featuresMask = shortest != maxLength ? createFeatureMask(batch, maxLength) : null;

        DataSet dataSet = new DataSet(features, labels, featuresMask, (INDArray) null);
        if (dataSetPreProcessor != null) {
            dataSetPreProcessor.preProcess(dataSet);
        }
        cursor += dataSet.numExamples();
        return dataSet;
    }

    /**
     * Takes the buffered sentence plus further non-empty sentences until {@code num} are collected or the provider
     * is exhausted. Callers must ensure {@link #hasNext()} returned true.
     */
    private List<Pair<List<String>, String>> collectMinibatch(int num) {
        List<Pair<List<String>, String>> batch = new ArrayList<>(num);
        batch.add(preLoadedTokens);
        preLoadedTokens = null;
        while (batch.size() < num && sentenceProvider.hasNext()) {
            Pair<String, String> sentence = sentenceProvider.nextSentence();
            List<String> tokens = tokenizeSentence(sentence.getFirst());
            if (!tokens.isEmpty()) {
                batch.add(new Pair<>(tokens, sentence.getSecond()));
            }
        }
        return batch;
    }

    /** Builds the {@code [batchSize, numClasses]} one-hot label array for a minibatch. */
    private INDArray createLabels(List<Pair<List<String>, String>> batch) {
        INDArray labels = Nd4j.create(batch.size(), numClasses);
        for (int example = 0; example < batch.size(); example++) {
            String label = batch.get(example).getSecond();
            Integer classIndex = labelClassMap.get(label);
            if (classIndex == null) {
                throw new IllegalStateException("Got label \"" + label + "\" that is not present in list of LabeledSentenceProvider labels");
            }
            labels.putScalar(example, classIndex.intValue(), 1.0d);
        }
        return labels;
    }

    /** Returns the word vector size (the feature dimension per token). */
    public int inputColumns() {
        return wordVectorSize;
    }

    /** Returns the number of label classes. */
    public int totalOutcomes() {
        return numClasses;
    }

    public boolean resetSupported() {
        return true;
    }

    public boolean asyncSupported() {
        return true;
    }

    /**
     * Rewinds the iterator to the beginning of the sentence provider.
     * Any sentence already buffered by {@link #hasNext()} is discarded, so it is not returned twice.
     */
    public void reset() {
        cursor = 0;
        preLoadedTokens = null;
        if (sentenceProvider != null) {
            sentenceProvider.reset();
        }
    }

    /** Returns the default minibatch size. */
    public int batch() {
        return minibatchSize;
    }

    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.dataSetPreProcessor = dataSetPreProcessor;
    }

    public DataSetPreProcessor getPreProcessor() {
        return dataSetPreProcessor;
    }

    /** Not supported. @throws UnsupportedOperationException always */
    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Creates an iterator directly from a complete set of internal state values.
     * The values are stored as given, without validation or derivation; {@code labelClassMap} is stored by
     * reference. {@link Builder} is the usual entry point.
     */
    public CnnSentenceDataSetIterator(Format format, LabeledSentenceProvider sentenceProvider, WordVectors wordVectors,
                    TokenizerFactory tokenizerFactory, UnknownWordHandling unknownWordHandling,
                    boolean useNormalizedWordVectors, int minibatchSize, int maxSentenceLength,
                    boolean sentencesAlongHeight, DataSetPreProcessor dataSetPreProcessor, int wordVectorSize,
                    int numClasses, Map<String, Integer> labelClassMap, INDArray unknown, int cursor,
                    Pair<List<String>, String> preLoadedTokens) {
        this.format = format;
        this.sentenceProvider = sentenceProvider;
        this.wordVectors = wordVectors;
        this.tokenizerFactory = tokenizerFactory;
        this.unknownWordHandling = unknownWordHandling;
        this.useNormalizedWordVectors = useNormalizedWordVectors;
        this.minibatchSize = minibatchSize;
        this.maxSentenceLength = maxSentenceLength;
        this.sentencesAlongHeight = sentencesAlongHeight;
        this.dataSetPreProcessor = dataSetPreProcessor;
        this.wordVectorSize = wordVectorSize;
        this.numClasses = numClasses;
        this.labelClassMap = labelClassMap;
        this.unknown = unknown;
        this.cursor = cursor;
        this.preLoadedTokens = preLoadedTokens;
    }
}
