package org.deeplearning4j.iterator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/iterator/BertIterator.class */
/**
 * A {@link MultiDataSetIterator} that turns sentences into minibatches of token-index arrays.
 *
 * <p>Sentences are read from {@link #sentenceProvider}, split into tokens with a tokenizer created
 * by {@link #tokenizerFactory}, and each token is translated to an integer through
 * {@link #vocabMap}. Labels are built according to {@link #task}. Instances are created through
 * the nested {@code Builder}.
 */
public class BertIterator implements MultiDataSetIterator {
    /** Selects how labels are built in {@code next(int)}. */
    protected Task task;
    /** Creates one tokenizer per sentence. */
    protected TokenizerFactory tokenizerFactory;
    /** Sequence length for FIXED_LENGTH; upper bound on sequence length for CLIP_ONLY. */
    protected int maxTokens;
    /** Number of examples requested by the no-argument next method; also the row count when padding is enabled. */
    protected int minibatchSize;
    /** When true, the arrays are allocated with {@link #minibatchSize} rows even if fewer sentences were read. */
    protected boolean padMinibatches;
    /** Optional; when non-null it is applied to every minibatch before it is returned. */
    protected MultiDataSetPreProcessor preProcessor;
    /** Source of the (sentence, label) pairs. */
    protected LabeledSentenceProvider sentenceProvider;
    /** Selects how the sequence length of each minibatch is chosen. */
    protected LengthHandling lengthHandling;
    /** Selects which feature arrays are placed in each minibatch. */
    protected FeatureArrays featureArrays;
    /** Maps a token to its integer index. */
    protected Map<String, Integer> vocabMap;
    /** Masks token sequences when the task is UNSUPERVISED. */
    protected BertSequenceMasker masker;
    /** Layout of the label array when the task is UNSUPERVISED. */
    protected UnsupervisedLabelFormat unsupervisedLabelFormat;
    /** Token handed to the masker as the mask token. */
    protected String maskToken;
    /** When non-null, inserted as the first token of every tokenized sentence. */
    protected String prependToken;
    /** Vocabulary tokens positioned by their index; built lazily on the first UNSUPERVISED minibatch. */
    protected List<String> vocabKeysAsList;

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$Builder.class */
    /**
     * Fluent builder for {@link BertIterator}.
     *
     * <p>Every setter stores its argument and returns this builder so calls can be chained.
     * {@link #build()} validates the configuration before creating the iterator.
     */
    public static class Builder {
        /** Required: the task the iterator builds labels for. */
        protected Task task;
        /** Required: factory used to tokenize each sentence. */
        protected TokenizerFactory tokenizerFactory;
        /** Optional pre-processor applied to each minibatch. */
        protected MultiDataSetPreProcessor preProcessor;
        /** Required: token to integer index mapping. */
        protected Map<String, Integer> vocabMap;
        /** Required for UNSUPERVISED: layout of the label array. */
        protected UnsupervisedLabelFormat unsupervisedLabelFormat;
        /** Required for UNSUPERVISED: the mask token handed to the masker. */
        protected String maskToken;
        /** Optional token inserted at the start of every tokenized sentence. */
        protected String prependToken;
        /** Defaults to FIXED_LENGTH. */
        protected LengthHandling lengthHandling = LengthHandling.FIXED_LENGTH;
        /** Defaults to -1 (no length configured). */
        protected int maxTokens = -1;
        /** Defaults to 32. */
        protected int minibatchSize = 32;
        /** Defaults to false. */
        protected boolean padMinibatches = false;
        /** Source of (sentence, label) pairs; defaults to null. */
        protected LabeledSentenceProvider sentenceProvider = null;
        /** Defaults to INDICES_MASK_SEGMENTID. */
        protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
        /** Defaults to a {@link BertMaskedLMMasker}. */
        protected BertSequenceMasker masker = new BertMaskedLMMasker();

        /**
         * @param task the task to build labels for
         * @return this builder
         */
        public Builder task(Task task) {
            this.task = task;
            return this;
        }

        /**
         * @param tokenizerFactory factory used to tokenize each sentence
         * @return this builder
         */
        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets how sequence length is chosen together with the maximum token count.
         *
         * @param lengthHandling the length handling mode; must not be null
         * @param i the maximum number of tokens (stored as {@code maxTokens})
         * @return this builder
         * @throws NullPointerException if {@code lengthHandling} is null
         */
        public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int i) {
            if (lengthHandling == null) {
                throw new NullPointerException("lengthHandling is marked @NonNull but is null");
            }
            this.lengthHandling = lengthHandling;
            this.maxTokens = i;
            return this;
        }

        /**
         * @param i the minibatch size
         * @return this builder
         */
        public Builder minibatchSize(int i) {
            this.minibatchSize = i;
            return this;
        }

        /**
         * @param z whether minibatches with fewer sentences are padded up to the minibatch size
         * @return this builder
         */
        public Builder padMinibatches(boolean z) {
            this.padMinibatches = z;
            return this;
        }

        /**
         * @param multiDataSetPreProcessor pre-processor applied to each minibatch; may be null
         * @return this builder
         */
        public Builder preProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
            this.preProcessor = multiDataSetPreProcessor;
            return this;
        }

        /**
         * @param labeledSentenceProvider source of (sentence, label) pairs
         * @return this builder
         */
        public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
            this.sentenceProvider = labeledSentenceProvider;
            return this;
        }

        /**
         * @param featureArrays which feature arrays each minibatch contains
         * @return this builder
         */
        public Builder featureArrays(FeatureArrays featureArrays) {
            this.featureArrays = featureArrays;
            return this;
        }

        /**
         * @param map token to integer index mapping
         * @return this builder
         */
        public Builder vocabMap(Map<String, Integer> map) {
            this.vocabMap = map;
            return this;
        }

        /**
         * @param bertSequenceMasker masker used for the UNSUPERVISED task
         * @return this builder
         */
        public Builder masker(BertSequenceMasker bertSequenceMasker) {
            this.masker = bertSequenceMasker;
            return this;
        }

        /**
         * @param unsupervisedLabelFormat layout of the label array for the UNSUPERVISED task
         * @return this builder
         */
        public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat unsupervisedLabelFormat) {
            this.unsupervisedLabelFormat = unsupervisedLabelFormat;
            return this;
        }

        /**
         * @param str the mask token handed to the masker
         * @return this builder
         */
        public Builder maskToken(String str) {
            this.maskToken = str;
            return this;
        }

        /**
         * @param str token inserted at the start of every tokenized sentence; null disables it
         * @return this builder
         */
        public Builder prependToken(String str) {
            this.prependToken = str;
            return this;
        }

        /**
         * Validates the configuration and creates the iterator.
         *
         * @return a new {@link BertIterator} configured from this builder
         * @throws IllegalStateException if the task, tokenizer factory or vocabulary map is missing, or if
         *         the task is UNSUPERVISED and the masker, label format or mask token is missing
         */
        public BertIterator build() {
            Preconditions.checkState(this.task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
            Preconditions.checkState(this.tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
            Preconditions.checkState(this.vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");

            // The remaining requirements only apply to unsupervised training.
            boolean unsupervised = this.task == Task.UNSUPERVISED;
            Preconditions.checkState(!unsupervised || this.masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
            // The message names the setter that actually configures the label format.
            Preconditions.checkState(!unsupervised || this.unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
            Preconditions.checkState(!unsupervised || this.maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
            return new BertIterator(this);
        }
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$FeatureArrays.class */
    /** Selects which feature arrays each minibatch contains. */
    public enum FeatureArrays {
        /** One feature array (token indices) with its mask array. */
        INDICES_MASK,
        /** Token indices with their mask, plus a second all-zero INT array of the same shape that has no mask. */
        INDICES_MASK_SEGMENTID
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$LengthHandling.class */
    /** Selects how the sequence length of a minibatch is chosen. */
    public enum LengthHandling {
        /** Always use the configured maximum token count. */
        FIXED_LENGTH,
        /** Use the token count of the longest sentence in the minibatch. */
        ANY_LENGTH,
        /** Use the smaller of the configured maximum and the longest sentence in the minibatch. */
        CLIP_ONLY
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$Task.class */
    /** The kind of labels the iterator produces. */
    public enum Task {
        /** Labels come from masking tokens with the configured masker. */
        UNSUPERVISED,
        /** One label per sentence, encoded one-hot over the sentence provider's label list. */
        SEQ_CLASSIFICATION
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$UnsupervisedLabelFormat.class */
    /** Layout of the label array produced for the UNSUPERVISED task. */
    public enum UnsupervisedLabelFormat {
        /** INT array allocated as [rows, sequence length], holding token indices. */
        RANK2_IDX,
        /** FLOAT array allocated as [rows, vocabulary size, sequence length]. */
        RANK3_NCL,
        /** FLOAT array allocated as [sequence length, rows, vocabulary size]. */
        RANK3_LNC
    }

    /**
     * Creates an iterator by copying every setting from the given builder.
     *
     * @param builder the builder holding the configuration
     */
    protected BertIterator(Builder builder) {
        // What to produce.
        this.task = builder.task;
        this.featureArrays = builder.featureArrays;
        this.unsupervisedLabelFormat = builder.unsupervisedLabelFormat;

        // Where the text comes from and how it is tokenized.
        this.sentenceProvider = builder.sentenceProvider;
        this.tokenizerFactory = builder.tokenizerFactory;
        this.vocabMap = builder.vocabMap;
        this.prependToken = builder.prependToken;

        // Minibatch shape.
        this.lengthHandling = builder.lengthHandling;
        this.maxTokens = builder.maxTokens;
        this.minibatchSize = builder.minibatchSize;
        this.padMinibatches = builder.padMinibatches;

        // Masking and post-processing.
        this.masker = builder.masker;
        this.maskToken = builder.maskToken;
        this.preProcessor = builder.preProcessor;
    }

    /**
     * Reports whether another minibatch can be produced.
     *
     * <p>An iterator without a sentence provider has nothing to iterate, so this returns
     * {@code false} instead of failing with a {@link NullPointerException}; {@code reset()}
     * tolerates a missing provider in the same way.
     *
     * @return {@code true} if a sentence provider is set and it has more sentences
     */
    public boolean hasNext() {
        return this.sentenceProvider != null && this.sentenceProvider.hasNext();
    }

    /**
     * Returns the next minibatch using the configured minibatch size.
     *
     * <p>The decompiler renamed this method to {@code m1next} when it merged it with its bridge
     * method; the original name is restored here so the no-argument iterator method exists again.
     *
     * @return the next minibatch, as produced by {@code next(minibatchSize)}
     */
    public MultiDataSet next() {
        return next(this.minibatchSize);
    }

    /**
     * Decompiler-generated name for {@link #next()}, kept so existing callers keep working.
     *
     * @return the next minibatch
     * @deprecated use {@link #next()}
     */
    @Deprecated
    public MultiDataSet m1next() {
        return next();
    }

    /**
     * Removal is not supported by this iterator.
     *
     * @throws UnsupportedOperationException always
     */
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Builds a minibatch from up to {@code i} sentences.
     *
     * <p>Steps: read sentences from the provider, tokenize them, choose the sequence length from the
     * length handling mode, convert tokens to vocabulary indices (with a 0/1 mask marking real tokens),
     * then build labels for the configured task and apply the pre-processor if one is set.
     *
     * <p>When padding is enabled the arrays have at least {@code minibatchSize} rows; if more than
     * {@code minibatchSize} sentences were requested and read, the row count grows to fit them.
     *
     * @param i the maximum number of sentences to read for this minibatch
     * @return the minibatch (features, labels, feature masks, label masks)
     * @throws UnsupportedOperationException if no sentence provider is set
     * @throws IllegalStateException if no sentence is available, a token or label is unknown, the
     *         sequence length is negative, or the task / label format is not handled
     */
    public MultiDataSet next(int i) {
        // Check the provider first: hasNext() would otherwise hide this more specific failure.
        if (this.sentenceProvider == null) {
            throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
        }
        Preconditions.checkState(hasNext(), "No next element available");

        // Read up to i (sentence, label) pairs.
        List<Pair<String, String>> sentences = new ArrayList<>(i);
        int pulled = 0;
        while (this.sentenceProvider.hasNext() && pulled < i) {
            sentences.add(this.sentenceProvider.nextSentence());
            pulled++;
        }

        // Tokenize every sentence and remember the longest token sequence.
        List<Pair<List<String>, String>> tokenized = new ArrayList<>(i);
        int longest = -1;
        for (Pair<String, String> sentence : sentences) {
            List<String> tokens = tokenizeSentence(sentence.getFirst());
            tokenized.add(new Pair<>(tokens, sentence.getSecond()));
            longest = Math.max(longest, tokens.size());
        }

        int outLength;
        switch (this.lengthHandling) {
            case FIXED_LENGTH:
                outLength = this.maxTokens;
                break;
            case ANY_LENGTH:
                outLength = longest;
                break;
            case CLIP_ONLY:
                outLength = Math.min(this.maxTokens, longest);
                break;
            default:
                throw new RuntimeException("Not implemented length handling mode: " + this.lengthHandling);
        }
        // A negative length (e.g. the unset default of -1) would otherwise surface as a bare NegativeArraySizeException.
        Preconditions.checkState(outLength >= 0, "Invalid output sequence length %s (for FIXED_LENGTH or CLIP_ONLY, set a maximum length via Builder.lengthHandling(LengthHandling, int))", outLength);

        int numExamples = tokenized.size();
        // Padding targets minibatchSize rows, but never fewer rows than sentences actually read.
        int numRows = this.padMinibatches ? Math.max(this.minibatchSize, numExamples) : numExamples;

        // Token indices and a 0/1 mask; positions past the end of a sentence stay 0 in both.
        int[][] tokenIndices = new int[numRows][outLength];
        int[][] tokenMask = new int[numRows][outLength];
        for (int row = 0; row < numExamples; row++) {
            List<String> tokens = tokenized.get(row).getFirst();
            int limit = Math.min(outLength, tokens.size());
            for (int col = 0; col < limit; col++) {
                String token = tokens.get(col);
                Integer index = this.vocabMap.get(token);
                Preconditions.checkState(index != null, "Unknown token encontered: token \"%s\" is not in vocabulary", token);
                tokenIndices[row][col] = index.intValue();
                tokenMask[row][col] = 1;
            }
        }

        INDArray indices = Nd4j.createFromArray(tokenIndices);
        INDArray mask = Nd4j.createFromArray(tokenMask);
        INDArray[] features;
        INDArray[] featureMasks;
        if (this.featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
            // Second feature array is all zeros and has no mask of its own.
            features = new INDArray[]{indices, Nd4j.zeros(DataType.INT, new long[]{numRows, outLength})};
            featureMasks = new INDArray[]{mask, null};
        } else {
            features = new INDArray[]{indices};
            featureMasks = new INDArray[]{mask};
        }

        INDArray[] labels = new INDArray[1];
        INDArray[] labelMasks;
        if (this.task == Task.SEQ_CLASSIFICATION) {
            int numClasses = this.sentenceProvider.numLabelClasses();
            List<String> allLabels = this.sentenceProvider.allLabels();

            // Resolve every label before allocating, so an unknown label fails fast.
            int[] classIdx = new int[numRows];
            for (int row = 0; row < numExamples; row++) {
                String label = tokenized.get(row).getSecond();
                classIdx[row] = allLabels.indexOf(label);
                Preconditions.checkState(classIdx[row] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", label);
            }

            // One-hot labels; padded rows stay all-zero.
            labels[0] = Nd4j.create(DataType.FLOAT, new long[]{numRows, numClasses});
            for (int row = 0; row < numExamples; row++) {
                labels[0].putScalar(row, classIdx[row], 1.0d);
            }

            // A label mask is only needed when padded rows exist: 1 for real rows, 0 for padding.
            labelMasks = null;
            if (this.padMinibatches && numExamples != numRows) {
                INDArray labelMask = Nd4j.zeros(DataType.FLOAT, new long[]{numRows, 1});
                labelMasks = new INDArray[]{labelMask};
                labelMask.get(new INDArrayIndex[]{NDArrayIndex.interval(0, numExamples), NDArrayIndex.all()}).assign(1);
            }
        } else if (this.task == Task.UNSUPERVISED) {
            if (this.vocabKeysAsList == null) {
                // Invert the vocabulary once: position = index, value = token.
                // NOTE(review): assumes vocabulary indices are exactly 0..size-1 - confirm against the vocabulary loader.
                String[] byIndex = new String[this.vocabMap.size()];
                for (Map.Entry<String, Integer> entry : this.vocabMap.entrySet()) {
                    byIndex[entry.getValue().intValue()] = entry.getKey();
                }
                this.vocabKeysAsList = Arrays.asList(byIndex);
            }

            int vocabSize = this.vocabMap.size();
            INDArray labelMask = Nd4j.zeros(DataType.INT, new long[]{numRows, outLength});
            INDArray labelArr;
            if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                labelArr = Nd4j.create(DataType.INT, new long[]{numRows, outLength});
            } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                labelArr = Nd4j.create(DataType.FLOAT, new long[]{numRows, vocabSize, outLength});
            } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                labelArr = Nd4j.create(DataType.FLOAT, new long[]{outLength, numRows, vocabSize});
            } else {
                throw new IllegalStateException("Unknown unsupervised label format: " + this.unsupervisedLabelFormat);
            }

            for (int row = 0; row < numExamples; row++) {
                List<String> originalTokens = tokenized.get(row).getFirst();
                Pair<List<String>, boolean[]> masked = this.masker.maskSequence(originalTokens, this.maskToken, this.vocabKeysAsList);
                List<String> maskedTokens = masked.getFirst();
                boolean[] wasMasked = masked.getSecond();
                int limit = Math.min(wasMasked.length, outLength);
                for (int col = 0; col < limit; col++) {
                    if (!wasMasked[col]) {
                        continue;
                    }
                    // The label is the ORIGINAL token; the feature is overwritten with the MASKED token.
                    int origIdx = this.vocabMap.get(originalTokens.get(col)).intValue();
                    if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                        labelArr.putScalar(row, col, origIdx);
                    } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                        // Array is [rows, vocab, length]: the vocabulary index goes in dimension 1.
                        labelArr.putScalar(row, origIdx, col, 1.0d);
                    } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                        labelArr.putScalar(col, row, origIdx, 1.0d);
                    }
                    labelMask.putScalar(row, col, 1.0d);

                    String maskedToken = maskedTokens.get(col);
                    Integer maskedIdx = this.vocabMap.get(maskedToken);
                    Preconditions.checkState(maskedIdx != null, "Masked token \"%s\" is not in vocabulary", maskedToken);
                    indices.putScalar(row, col, maskedIdx.intValue());
                }
            }
            labels[0] = labelArr;
            labelMasks = new INDArray[]{labelMask};
        } else {
            throw new IllegalStateException("Task not yet implemented: " + this.task);
        }

        org.nd4j.linalg.dataset.MultiDataSet multiDataSet = new org.nd4j.linalg.dataset.MultiDataSet(features, labels, featureMasks, labelMasks);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(multiDataSet);
        }
        return multiDataSet;
    }

    /**
     * Splits a sentence into tokens using a tokenizer from the configured factory.
     *
     * <p>If a prepend token is configured it becomes the first element of the result.
     *
     * @param str the sentence to tokenize
     * @return the tokens in order, starting with the prepend token when one is set
     */
    private List<String> tokenizeSentence(String str) {
        final Tokenizer tokenizer = this.tokenizerFactory.create(str);
        final List<String> tokens = new ArrayList<>();
        final String leading = this.prependToken;
        if (leading != null) {
            tokens.add(leading);
        }
        for (; tokenizer.hasMoreTokens(); ) {
            tokens.add(tokenizer.nextToken());
        }
        return tokens;
    }

    /**
     * Reports whether this iterator can be reset.
     *
     * @return always {@code true}
     */
    public boolean resetSupported() {
        return true;
    }

    /**
     * Reports whether this iterator supports asynchronous use.
     *
     * @return always {@code true}
     */
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Resets the underlying sentence provider; does nothing when no provider is set.
     */
    public void reset() {
        if (sentenceProvider == null) {
            return;
        }
        sentenceProvider.reset();
    }

    /**
     * Creates a new builder with default settings.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns the pre-processor applied to each minibatch.
     *
     * @return the current pre-processor, or {@code null} if none is set
     */
    public MultiDataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    /**
     * Replaces the pre-processor applied to each minibatch.
     *
     * @param multiDataSetPreProcessor the new pre-processor; {@code null} disables pre-processing
     */
    public void setPreProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
        preProcessor = multiDataSetPreProcessor;
    }
}
