package org.deeplearning4j.datasets.datavec;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.datavec.common.data.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * {@link DataSetIterator} that converts sequence records read from one or two
 * {@link SequenceRecordReader}s into time-series {@link DataSet}s. Features and labels are
 * assembled as 3d arrays of shape {@code [examples, columns, timeSteps]} in 'f' order.
 *
 * <p>Two modes are supported:
 * <ul>
 *   <li><b>Two readers</b>: features come from one reader and labels from a second reader; how
 *       feature and label sequences are lined up in time is controlled by {@link AlignmentMode}.</li>
 *   <li><b>Single reader</b>: every time step holds both features and label; the label is the
 *       value at column {@code labelIndex}, all other columns are features.</li>
 * </ul>
 *
 * <p>For classification ({@code regression == false}) a label value is an integer class index that
 * becomes a one-hot vector of length {@code numPossibleLabels}. For regression label values are
 * copied as-is. An {@link NDArrayWritable} contributes all of its elements as consecutive columns.
 *
 * <p>When sequences within a mini-batch differ in length they are padded to the longest one and
 * mask arrays of shape {@code [examples, timeSteps]} are attached: 1 marks a real time step,
 * 0 a padded one.
 */
public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
    /** Reader for feature sequences (and, in single-reader mode, for the labels as well). */
    private final SequenceRecordReader recordReader;
    /** Reader for label sequences; {@code null} in single-reader mode. */
    private final SequenceRecordReader labelsReader;
    private final int miniBatchSize;
    private final boolean regression;
    /** Column of the label in single-reader mode; -1 in two-reader mode. */
    private final int labelIndex;
    private final int numPossibleLabels;
    /** Alignment used in two-reader mode; {@code null} (unused) in single-reader mode. */
    private final AlignmentMode alignmentMode;
    private final boolean singleSequenceReaderMode;
    private int cursor = 0;
    /** Number of feature columns; -1 until the first batch has been built. */
    private int inputColumns = -1;
    /** Number of label columns; -1 until the first batch has been built. */
    private int totalOutcomes = -1;
    /** True when {@link #stored} holds a pre-loaded batch that the next call to next() must return. */
    private boolean useStored = false;
    /** Batch pre-loaded by inputColumns()/totalOutcomes(); kept un-preprocessed until returned. */
    private DataSet stored = null;
    private DataSetPreProcessor preProcessor;

    /** How feature and label sequences are lined up in time in two-reader mode. */
    public enum AlignmentMode {
        /**
         * All feature sequences in a mini-batch must share one length and all label sequences must
         * share one length; no masks are created.
         */
        EQUAL_LENGTH,
        /**
         * Features and labels both start at time step 0; shorter sequences are padded and masked at
         * the end up to the longest sequence (features or labels) in the mini-batch.
         */
        ALIGN_START,
        /**
         * Per example, the shorter of the feature/label sequences is shifted so that both end on the
         * same time step (the longer one starts at step 0); everything is then padded and masked up
         * to the longest sequence in the mini-batch.
         */
        ALIGN_END
    }

    /**
     * Two-reader constructor using {@link AlignmentMode#EQUAL_LENGTH}.
     *
     * @param featuresReader    reader supplying the feature sequences
     * @param labels            reader supplying the label sequences
     * @param miniBatchSize     number of sequences per batch returned by {@link #next()}
     * @param numPossibleLabels number of classes (only used for classification)
     * @param regression        true for regression labels, false for integer class labels
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                                               int miniBatchSize, int numPossibleLabels, boolean regression) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, regression, AlignmentMode.EQUAL_LENGTH);
    }

    /**
     * Two-reader constructor.
     *
     * @param featuresReader    reader supplying the feature sequences
     * @param labels            reader supplying the label sequences (read in lock-step with features)
     * @param miniBatchSize     number of sequences per batch returned by {@link #next()}
     * @param numPossibleLabels number of classes (only used for classification)
     * @param regression        true for regression labels, false for integer class labels
     * @param alignmentMode     how feature and label sequences are aligned in time
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                                               int miniBatchSize, int numPossibleLabels, boolean regression,
                                               AlignmentMode alignmentMode) {
        this.recordReader = featuresReader;
        this.labelsReader = labels;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.alignmentMode = alignmentMode;
        this.labelIndex = -1;
        this.singleSequenceReaderMode = false;
    }

    /**
     * Single-reader constructor: each time step contains the features plus the label at
     * {@code labelIndex}.
     *
     * @param reader            reader supplying sequences of features and labels
     * @param miniBatchSize     number of sequences per batch returned by {@link #next()}
     * @param numPossibleLabels number of classes (only used for classification)
     * @param labelIndex        column index of the label within each time step
     * @param regression        true for regression labels, false for integer class labels
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize,
                                               int numPossibleLabels, int labelIndex, boolean regression) {
        this.recordReader = reader;
        this.labelsReader = null;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.labelIndex = labelIndex;
        this.regression = regression;
        this.alignmentMode = null;
        this.singleSequenceReaderMode = true;
    }

    /**
     * Returns true if another batch is available, including a batch that was pre-loaded by
     * {@link #inputColumns()} or {@link #totalOutcomes()} and not yet returned.
     */
    @Override
    public boolean hasNext() {
        return useStored || recordReader.hasNext();
    }

    /** Returns the next batch of (up to) {@code miniBatchSize} sequences. */
    @Override
    public DataSet next() {
        return next(miniBatchSize);
    }

    /**
     * Decompiler-generated alias of {@link #next()}, kept only for source compatibility.
     *
     * @deprecated use {@link #next()}
     */
    @Deprecated
    public DataSet m11next() {
        return next();
    }

    /**
     * Returns the next batch of up to {@code num} sequences. A batch pre-loaded by
     * {@link #inputColumns()}/{@link #totalOutcomes()} is returned as-is, ignoring {@code num}.
     * The pre-processor, if set, is applied exactly once to the returned batch.
     *
     * @throws NoSuchElementException   if no more data is available
     * @throws IllegalArgumentException if {@code num} is not positive
     */
    public DataSet next(int num) {
        DataSet dataSet;
        if (useStored) {
            useStored = false;
            dataSet = stored;
            stored = null;
        } else {
            dataSet = nextUnprocessed(num);
        }
        // Pre-processing happens only here, so a pre-loaded batch is never processed twice.
        if (preProcessor != null) {
            preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    /** Reads and assembles the next batch from the reader(s) without applying the pre-processor. */
    private DataSet nextUnprocessed(int num) {
        if (!recordReader.hasNext()) {
            throw new NoSuchElementException();
        }
        if (num <= 0) {
            throw new IllegalArgumentException("Number of examples per batch must be positive, got " + num);
        }
        return singleSequenceReaderMode ? nextSingleSequenceReader(num) : nextMultipleSequenceReaders(num);
    }

    /** Builds one batch of up to {@code num} sequences in single-reader mode. */
    private DataSet nextSingleSequenceReader(int num) {
        List<INDArray> featureList = new ArrayList<>(num);
        List<INDArray> labelList = new ArrayList<>(num);
        int minLength = Integer.MAX_VALUE;
        int maxLength = 0;
        while (featureList.size() < num && recordReader.hasNext()) {
            INDArray[] pair = getFeaturesLabelsSingleReader(recordReader.sequenceRecord());
            int length = pair[0].size(0);
            minLength = Math.min(minLength, length);
            maxLength = Math.max(maxLength, length);
            featureList.add(pair[0]);
            labelList.add(pair[1]);
        }

        int n = featureList.size();
        INDArray features = Nd4j.create(new int[]{n, featureList.get(0).size(1), maxLength}, 'f');
        INDArray labels = Nd4j.create(new int[]{n, labelList.get(0).size(1), maxLength}, 'f');
        // Masks are only needed when the sequences in this batch differ in length.
        boolean needsMask = minLength != maxLength;
        INDArray featuresMask = needsMask ? Nd4j.ones(n, maxLength) : null;
        INDArray labelsMask = needsMask ? Nd4j.ones(n, maxLength) : null;
        for (int i = 0; i < n; i++) {
            INDArray exampleFeatures = featureList.get(i);
            putTimeSeries(features, i, exampleFeatures, 0);
            putTimeSeries(labels, i, labelList.get(i), 0);
            if (needsMask) {
                // Features and labels come from the same record, so they share one length.
                int length = exampleFeatures.size(0);
                maskOutside(featuresMask, i, 0, length, maxLength);
                maskOutside(labelsMask, i, 0, length, maxLength);
            }
        }
        return toDataSet(features, labels, featuresMask, labelsMask);
    }

    /** Builds one batch of up to {@code num} sequences in two-reader mode, honouring the alignment mode. */
    private DataSet nextMultipleSequenceReaders(int num) {
        List<INDArray> featureList = new ArrayList<>(num);
        List<INDArray> labelList = new ArrayList<>(num);
        while (featureList.size() < num && recordReader.hasNext()) {
            List<List<Writable>> featureSequence = recordReader.sequenceRecord();
            List<List<Writable>> labelSequence = labelsReader.sequenceRecord();
            featureList.add(getFeatures(featureSequence));
            labelList.add(getLabels(labelSequence));
        }

        int n = featureList.size();
        int featureColumns = featureList.get(0).size(1);
        int labelColumns = labelList.get(0).size(1);

        if (alignmentMode == AlignmentMode.EQUAL_LENGTH) {
            int featureLength = featureList.get(0).size(0);
            int labelLength = labelList.get(0).size(0);
            INDArray features = Nd4j.create(new int[]{n, featureColumns, featureLength}, 'f');
            INDArray labels = Nd4j.create(new int[]{n, labelColumns, labelLength}, 'f');
            for (int i = 0; i < n; i++) {
                INDArray exampleFeatures = featureList.get(i);
                INDArray exampleLabels = labelList.get(i);
                if (exampleFeatures.size(0) != featureLength || exampleLabels.size(0) != labelLength) {
                    throw new IllegalStateException("AlignmentMode.EQUAL_LENGTH requires equal sequence lengths"
                            + " within a mini-batch, but example " + i + " has feature length "
                            + exampleFeatures.size(0) + " (expected " + featureLength + ") and label length "
                            + exampleLabels.size(0) + " (expected " + labelLength
                            + "). Use ALIGN_START or ALIGN_END for variable-length data.");
                }
                putTimeSeries(features, i, exampleFeatures, 0);
                putTimeSeries(labels, i, exampleLabels, 0);
            }
            return toDataSet(features, labels, null, null);
        }

        if (alignmentMode != AlignmentMode.ALIGN_START && alignmentMode != AlignmentMode.ALIGN_END) {
            throw new UnsupportedOperationException("Unknown alignment mode: " + alignmentMode);
        }

        int maxLength = 0;
        for (INDArray f : featureList) {
            maxLength = Math.max(maxLength, f.size(0));
        }
        for (INDArray l : labelList) {
            maxLength = Math.max(maxLength, l.size(0));
        }

        INDArray features = Nd4j.create(new int[]{n, featureColumns, maxLength}, 'f');
        INDArray labels = Nd4j.create(new int[]{n, labelColumns, maxLength}, 'f');
        INDArray featuresMask = Nd4j.ones(n, maxLength);
        INDArray labelsMask = Nd4j.ones(n, maxLength);
        boolean alignEnd = alignmentMode == AlignmentMode.ALIGN_END;
        for (int i = 0; i < n; i++) {
            INDArray exampleFeatures = featureList.get(i);
            INDArray exampleLabels = labelList.get(i);
            int featureLength = exampleFeatures.size(0);
            int labelLength = exampleLabels.size(0);
            // ALIGN_END shifts the shorter sequence so both end together; ALIGN_START starts both at 0.
            int longer = Math.max(featureLength, labelLength);
            int featureStart = alignEnd ? longer - featureLength : 0;
            int labelStart = alignEnd ? longer - labelLength : 0;
            putTimeSeries(features, i, exampleFeatures, featureStart);
            putTimeSeries(labels, i, exampleLabels, labelStart);
            maskOutside(featuresMask, i, featureStart, featureStart + featureLength, maxLength);
            maskOutside(labelsMask, i, labelStart, labelStart + labelLength, maxLength);
        }
        return toDataSet(features, labels, featuresMask, labelsMask);
    }

    /**
     * Wraps the assembled arrays in a DataSet, advancing the cursor and recording the column counts
     * the first time a batch is built.
     */
    private DataSet toDataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        cursor += features.size(0);
        if (inputColumns == -1) {
            inputColumns = features.size(1);
        }
        if (totalOutcomes == -1) {
            totalOutcomes = labels.size(1);
        }
        return new DataSet(features, labels, featuresMask, labelsMask);
    }

    /**
     * Copies a {@code [steps, columns]} matrix into example {@code example} of a
     * {@code [examples, columns, timeSteps]} array, occupying time steps
     * {@code [startStep, startStep + steps)}.
     */
    private static void putTimeSeries(INDArray target, int example, INDArray source, int startStep) {
        // View of one example as [timeSteps, columns], matching the layout of source.
        INDArray view = target.tensorAlongDimension(example, new int[]{1, 2}).permutei(new int[]{1, 0});
        int steps = source.size(0);
        if (startStep == 0 && steps == view.size(0)) {
            view.assign(source);
        } else {
            view.put(new INDArrayIndex[]{NDArrayIndex.interval(startStep, startStep + steps), NDArrayIndex.all()},
                    source);
        }
    }

    /** Sets mask entries of row {@code example} to 0 outside the time-step range {@code [start, end)}. */
    private static void maskOutside(INDArray mask, int example, int start, int end, int totalSteps) {
        for (int t = 0; t < start; t++) {
            mask.putScalar(example, t, 0.0d);
        }
        for (int t = end; t < totalSteps; t++) {
            mask.putScalar(example, t, 0.0d);
        }
    }

    /** Not supported by this iterator. */
    public int totalExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Returns the number of feature columns. If no batch has been built yet, one batch is pre-loaded
     * to find out; that batch is then returned by the next call to {@link #next()}.
     */
    public int inputColumns() {
        if (inputColumns != -1) {
            return inputColumns;
        }
        preLoad();
        return inputColumns;
    }

    /**
     * Returns the number of label columns. If no batch has been built yet, one batch is pre-loaded
     * to find out; that batch is then returned by the next call to {@link #next()}.
     */
    public int totalOutcomes() {
        if (totalOutcomes != -1) {
            return totalOutcomes;
        }
        preLoad();
        return totalOutcomes;
    }

    /** Reads one batch ahead so that column counts are known; the batch is kept for next(int). */
    private void preLoad() {
        // Deliberately not pre-processed: next(int) applies the pre-processor when it is returned.
        stored = nextUnprocessed(miniBatchSize);
        useStored = true;
        inputColumns = stored.getFeatureMatrix().size(1);
        totalOutcomes = stored.getLabels().size(1);
    }

    public boolean resetSupported() {
        return true;
    }

    public boolean asyncSupported() {
        return true;
    }

    /** Resets the reader(s) and the cursor, discarding any pre-loaded batch. */
    public void reset() {
        recordReader.reset();
        if (labelsReader != null) {
            labelsReader.reset();
        }
        cursor = 0;
        stored = null;
        useStored = false;
    }

    /** Returns the configured mini-batch size. */
    public int batch() {
        return miniBatchSize;
    }

    /** Returns the number of sequences read since construction or the last {@link #reset()}. */
    public int cursor() {
        return cursor;
    }

    /** Not supported by this iterator. */
    public int numExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    /** Always returns {@code null}: this iterator does not expose label names. */
    public List<String> getLabels() {
        return null;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Remove not supported for this iterator");
    }

    /** Fails fast on an empty sequence instead of producing a null array further down. */
    private static void requireNonEmpty(List<List<Writable>> sequence) {
        if (sequence == null || sequence.isEmpty()) {
            throw new IllegalStateException("Sequence record reader returned an empty sequence");
        }
    }

    /** Number of columns a value occupies: an NDArrayWritable contributes its length, anything else 1. */
    private static int columnCount(Writable value) {
        return value instanceof NDArrayWritable ? ((NDArrayWritable) value).get().length() : 1;
    }

    /**
     * Writes {@code value} into row {@code row} of {@code target} starting at column {@code column}.
     *
     * @return the column following the last one written
     */
    private static int putWritable(INDArray target, int row, int column, Writable value) {
        if (value instanceof NDArrayWritable) {
            INDArray array = ((NDArrayWritable) value).get();
            int length = array.length();
            target.put(new INDArrayIndex[]{NDArrayIndex.point(row), NDArrayIndex.interval(column, column + length)},
                    array);
            return column + length;
        }
        target.putScalar(row, column, value.toDouble());
        return column + 1;
    }

    /** Converts a feature sequence into a {@code [timeSteps, columns]} matrix ('f' order). */
    private INDArray getFeatures(List<List<Writable>> sequence) {
        requireNonEmpty(sequence);
        int numColumns = 0;
        for (Writable value : sequence.get(0)) {
            numColumns += columnCount(value);
        }
        INDArray out = Nd4j.create(new int[]{sequence.size(), numColumns}, 'f');
        int row = 0;
        for (List<Writable> step : sequence) {
            int column = 0;
            for (Writable value : step) {
                column = putWritable(out, row, column, value);
            }
            row++;
        }
        return out;
    }

    /**
     * Converts a label sequence into a {@code [timeSteps, columns]} matrix ('f' order). For
     * classification the first value of each time step is the class index.
     */
    private INDArray getLabels(List<List<Writable>> sequence) {
        requireNonEmpty(sequence);
        int numColumns;
        if (regression) {
            numColumns = 0;
            for (Writable value : sequence.get(0)) {
                numColumns += columnCount(value);
            }
        } else {
            numColumns = numPossibleLabels;
        }
        INDArray out = Nd4j.create(new int[]{sequence.size(), numColumns}, 'f');
        int row = 0;
        for (List<Writable> step : sequence) {
            if (regression) {
                int column = 0;
                for (Writable value : step) {
                    column = putWritable(out, row, column, value);
                }
            } else {
                int classIndex = step.iterator().next().toInt();
                out.getRow(row).assign(FeatureUtil.toOutcomeVector(classIndex, numPossibleLabels));
            }
            row++;
        }
        return out;
    }

    /**
     * Splits a single-reader sequence into {@code [timeSteps, featureColumns]} features and
     * {@code [timeSteps, labelColumns]} labels, taking the label from column {@code labelIndex}.
     */
    private INDArray[] getFeaturesLabelsSingleReader(List<List<Writable>> sequence) {
        requireNonEmpty(sequence);
        List<Writable> firstStep = sequence.get(0);
        int featureColumns = 0;
        int column = 0;
        for (Writable value : firstStep) {
            if (column != labelIndex) {
                featureColumns += columnCount(value);
            }
            column++;
        }
        int labelColumns = regression ? columnCount(firstStep.get(labelIndex)) : numPossibleLabels;
        INDArray features = Nd4j.zeros(sequence.size(), featureColumns);
        INDArray labels = Nd4j.zeros(sequence.size(), labelColumns);

        int row = 0;
        for (List<Writable> step : sequence) {
            int sourceColumn = 0;
            int featureColumn = 0;
            for (Writable value : step) {
                if (sourceColumn == labelIndex) {
                    if (!regression) {
                        labels.putScalar(row, value.toInt(), 1.0d);
                    } else if (value instanceof NDArrayWritable) {
                        labels.putRow(row, ((NDArrayWritable) value).get());
                    } else {
                        labels.putScalar(row, 0, value.toDouble());
                    }
                } else {
                    featureColumn = putWritable(features, row, featureColumn, value);
                }
                sourceColumn++;
            }
            row++;
        }
        return new INDArray[]{features, labels};
    }

    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }
}
