package org.deeplearning4j.datasets.datavec;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposable;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.class */
/**
 * {@link DataSetIterator} over sequence data read through DataVec {@link SequenceRecordReader}s.
 *
 * <p>Two configurations are supported:
 * <ul>
 *   <li><b>Single reader</b>: features and labels come from the same reader and {@code labelIndex} selects
 *       the label column (for multi-output regression, the first label column). A negative
 *       {@code labelIndex} combined with {@code numPossibleLabels >= 1} is replaced by the last column.</li>
 *   <li><b>Separate readers</b>: one reader supplies features and a second supplies labels. For
 *       classification, column 0 of the labels reader is one-hot encoded; for regression the labels
 *       reader output is used as-is.</li>
 * </ul>
 *
 * <p>Batching is delegated to a {@link RecordReaderMultiDataSetIterator}, which is built lazily from the
 * first sequence (its column count drives the configuration) on the first call to {@link #hasNext()},
 * {@link #next(int)}, {@link #inputColumns()}, {@link #totalOutcomes()} or {@link #loadFromMetaData(List)}.
 * Its single-input/single-output result is converted into a {@link DataSet}.
 *
 * <p>Instances hold mutable iteration state and perform no synchronization.
 */
public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
    /** Key of the features reader (features + labels in single-reader mode) in the underlying iterator. */
    private static final String READER_KEY = "reader";
    /** Key of the separate labels reader in the underlying iterator (two-reader mode only). */
    private static final String READER_KEY_LABEL = "reader_labels";

    private final SequenceRecordReader recordReader;
    /** Separate labels reader; {@code null} in single-reader mode. */
    private final SequenceRecordReader labelsReader;
    private final int miniBatchSize;
    private final boolean regression;
    /** Label column for single-reader mode; -1 until given or inferred from the first sequence. */
    private int labelIndex = -1;
    private final int numPossibleLabels;
    /** Feature column count observed on the first batch; -1 until a batch has been seen. */
    private int inputColumns = -1;
    /** Label column count observed on the first batch; -1 if unknown or if batches carry no labels. */
    private int totalOutcomes = -1;
    /** True while {@link #stored} holds a batch fetched early (to learn shapes) that has not been returned. */
    private boolean useStored = false;
    /** Early-fetched batch, kept WITHOUT preprocessing so the preprocessor runs exactly once on hand-out. */
    private DataSet stored = null;
    private DataSetPreProcessor preProcessor;
    /** May be {@code null}, in which case the underlying builder's alignment setting is left untouched. */
    private final AlignmentMode alignmentMode;
    private final boolean singleSequenceReaderMode;
    private boolean collectMetaData = false;
    /** Built lazily by {@link #initializeUnderlying(SequenceRecord)}. */
    private RecordReaderMultiDataSetIterator underlying;
    /**
     * True when the label column sits in the middle of a single reader's columns, so the underlying
     * iterator returns the features as two arrays that must be concatenated.
     */
    private boolean underlyingIsDisjoint;

    /**
     * Sequence alignment strategy; each constant is forwarded to the same-named
     * {@link RecordReaderMultiDataSetIterator.AlignmentMode}.
     */
    public enum AlignmentMode {
        EQUAL_LENGTH,
        ALIGN_START,
        ALIGN_END
    }

    /** Two-reader classification iterator with {@link AlignmentMode#EQUAL_LENGTH}. */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                    int miniBatchSize, int numPossibleLabels) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, false);
    }

    /** Two-reader iterator with {@link AlignmentMode#EQUAL_LENGTH}. */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                    int miniBatchSize, int numPossibleLabels, boolean regression) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, regression, AlignmentMode.EQUAL_LENGTH);
    }

    /**
     * Two-reader iterator: features and labels come from separate sequence readers.
     *
     * @param featuresReader    reader supplying the feature sequences
     * @param labels            reader supplying the label sequences
     * @param miniBatchSize     number of sequences per batch
     * @param numPossibleLabels number of classes (classification); used for one-hot encoding column 0
     * @param regression        if true, the labels reader output is used directly instead of one-hot encoding
     * @param alignmentMode     alignment forwarded to the underlying iterator; may be {@code null}
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                    int miniBatchSize, int numPossibleLabels, boolean regression, AlignmentMode alignmentMode) {
        this.recordReader = featuresReader;
        this.labelsReader = labels;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.alignmentMode = alignmentMode;
        this.singleSequenceReaderMode = false;
    }

    /** Single-reader classification iterator. */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize, int numPossibleLabels,
                    int labelIndex) {
        this(reader, miniBatchSize, numPossibleLabels, labelIndex, false);
    }

    /**
     * Single-reader iterator: features and labels are columns of the same sequence reader.
     *
     * @param reader            reader supplying both features and labels
     * @param miniBatchSize     number of sequences per batch
     * @param numPossibleLabels number of classes; negative together with a negative {@code labelIndex}
     *                          means "no labels at all"
     * @param labelIndex        label column; negative means "infer" (last column) when
     *                          {@code numPossibleLabels >= 1}
     * @param regression        if true, label column(s) are used as-is instead of one-hot encoded
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize, int numPossibleLabels,
                    int labelIndex, boolean regression) {
        this.recordReader = reader;
        this.labelsReader = null;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.labelIndex = labelIndex;
        this.regression = regression;
        this.alignmentMode = null;
        this.singleSequenceReaderMode = true;
    }

    /** Builds the underlying iterator from the reader's first sequence, then rewinds it to the start. */
    private void initializeUnderlyingFromReader() {
        initializeUnderlying(recordReader.nextSequence());
        underlying.reset();
    }

    /**
     * Configures and builds the underlying {@link RecordReaderMultiDataSetIterator}.
     *
     * @param firstSequence a sample sequence; only its first step's column count is used
     * @throws ZeroLengthSequenceException if the sample sequence has no steps
     */
    private void initializeUnderlying(SequenceRecord firstSequence) {
        if (firstSequence.getSequenceRecord().isEmpty()) {
            throw new ZeroLengthSequenceException();
        }
        int numColumns = firstSequence.getSequenceRecord().get(0).size();
        inferLabelIndex(numColumns);

        // Reading the sample advanced the reader; rewind before handing it to the underlying iterator.
        recordReader.reset();

        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize);
        builder.addSequenceReader(READER_KEY, recordReader);
        if (labelsReader != null) {
            builder.addSequenceReader(READER_KEY_LABEL, labelsReader);
        }

        if (singleSequenceReaderMode) {
            configureSingleReaderColumns(builder, numColumns);
        } else {
            builder.addInput(READER_KEY);
            underlyingIsDisjoint = false;
            if (regression) {
                builder.addOutput(READER_KEY_LABEL);
            } else {
                builder.addOutputOneHot(READER_KEY_LABEL, 0, numPossibleLabels);
            }
        }

        if (alignmentMode != null) {
            builder.sequenceAlignmentMode(toUnderlyingAlignmentMode(alignmentMode));
        }

        underlying = builder.build();
        if (collectMetaData) {
            underlying.setCollectMetaData(true);
        }
    }

    /** Replaces a negative label index: last column in single-reader mode, column 0 with a labels reader. */
    private void inferLabelIndex(int numColumns) {
        if (numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = singleSequenceReaderMode ? numColumns - 1 : 0;
        }
    }

    /** Registers the input/output column ranges of the single reader with {@code builder}. */
    private void configureSingleReaderColumns(RecordReaderMultiDataSetIterator.Builder builder, int numColumns) {
        boolean hasLabels = labelIndex >= 0 || numPossibleLabels >= 0;
        int lastColumn = numColumns - 1;

        if (!hasLabels) {
            // No labels: every column is a feature.
            builder.addInput(READER_KEY);
            underlyingIsDisjoint = false;
        } else if (labelIndex == 0 || labelIndex == lastColumn) {
            // Label at either edge: the features form one contiguous column range.
            int inputFrom;
            int inputTo;
            if (labelIndex < 0) {
                // Only reachable when a step has zero columns (lastColumn == -1).
                inputFrom = 0;
                inputTo = lastColumn;
            } else if (labelIndex == 0) {
                inputFrom = 1;
                inputTo = lastColumn;
            } else {
                inputFrom = 0;
                inputTo = labelIndex - 1;
            }
            builder.addInput(READER_KEY, inputFrom, inputTo);
            underlyingIsDisjoint = false;
            // NOTE(review): with regression && numPossibleLabels > 1 no output is registered on this path,
            // so batches carry no labels — confirm whether that configuration should be rejected.
        } else if (regression && numPossibleLabels > 1) {
            // Multi-output regression: features before labelIndex, labels from labelIndex to the end.
            builder.addInput(READER_KEY, 0, labelIndex - 1);
            builder.addOutput(READER_KEY, labelIndex, lastColumn);
            underlyingIsDisjoint = false;
        } else {
            // Label in the middle: two feature ranges, concatenated again in mdsToDataSet.
            builder.addInput(READER_KEY, 0, labelIndex - 1);
            builder.addInput(READER_KEY, labelIndex + 1, lastColumn);
            underlyingIsDisjoint = true;
        }

        if (hasLabels) {
            if (regression && numPossibleLabels <= 1) {
                builder.addOutput(READER_KEY, labelIndex, labelIndex);
            } else if (!regression) {
                builder.addOutputOneHot(READER_KEY, labelIndex, numPossibleLabels);
            }
            // Multi-output regression outputs were already registered above.
        }
    }

    /** Maps this class's alignment enum onto the underlying iterator's enum of the same names. */
    private static RecordReaderMultiDataSetIterator.AlignmentMode toUnderlyingAlignmentMode(AlignmentMode mode) {
        switch (mode) {
            case EQUAL_LENGTH:
                return RecordReaderMultiDataSetIterator.AlignmentMode.EQUAL_LENGTH;
            case ALIGN_START:
                return RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START;
            case ALIGN_END:
                return RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END;
            default:
                throw new IllegalStateException("Unknown alignment mode: " + mode);
        }
    }

    /**
     * Converts an underlying {@link MultiDataSet} into a {@link DataSet}, concatenating disjoint feature
     * arrays and translating per-example metadata. Does NOT apply the preprocessor.
     */
    private DataSet mdsToDataSet(MultiDataSet mds) {
        INDArray featuresMask = RecordReaderDataSetIterator.getOrNull(mds.getFeaturesMaskArrays(), 0);
        INDArray features;
        if (underlyingIsDisjoint) {
            // Features arrive as [columns before label] and [columns after label]; join along dimension 1.
            INDArray before = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 0);
            INDArray after = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 1);
            long beforeCols = before.size(1);
            long totalCols = beforeCols + after.size(1);
            features = Nd4j.createUninitialized(new long[] {before.size(0), totalCols, before.size(2)});
            features.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0L, beforeCols),
                            NDArrayIndex.all()}, before);
            features.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(beforeCols, totalCols),
                            NDArrayIndex.all()}, after);
        } else {
            features = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 0);
        }

        DataSet ds = new DataSet(features, RecordReaderDataSetIterator.getOrNull(mds.getLabels(), 0), featuresMask,
                        RecordReaderDataSetIterator.getOrNull(mds.getLabelsMaskArrays(), 0));

        if (collectMetaData) {
            List<? extends Serializable> underlyingMeta = mds.getExampleMetaData();
            List<Serializable> exampleMeta = new ArrayList<>(underlyingMeta.size());
            for (Serializable s : underlyingMeta) {
                RecordMetaDataComposableMap byReader = (RecordMetaDataComposableMap) s;
                if (singleSequenceReaderMode) {
                    exampleMeta.add(byReader.getMeta().get(READER_KEY));
                } else {
                    exampleMeta.add(new RecordMetaDataComposable(new RecordMetaData[] {
                                    byReader.getMeta().get(READER_KEY), byReader.getMeta().get(READER_KEY_LABEL)}));
                }
            }
            ds.setExampleMetaData(exampleMeta);
        }
        return ds;
    }

    /** Applies the configured preprocessor (if any) in place and returns the same instance. */
    private DataSet applyPreProcessor(DataSet ds) {
        if (preProcessor != null) {
            preProcessor.preProcess(ds);
        }
        return ds;
    }

    /** Records feature/label column counts from a batch; labels absent → {@code totalOutcomes = -1}. */
    private void recordShapes(DataSet ds) {
        inputColumns = (int) ds.getFeatures().size(1);
        totalOutcomes = ds.getLabels() == null ? -1 : (int) ds.getLabels().size(1);
    }

    /** Also true while an early-fetched batch (from inputColumns/totalOutcomes) is still pending. */
    @Override
    public boolean hasNext() {
        if (useStored) {
            return true;
        }
        if (underlying == null) {
            initializeUnderlyingFromReader();
        }
        return underlying.hasNext();
    }

    @Override
    public DataSet next() {
        return next(miniBatchSize);
    }

    /**
     * Returns the next batch with the preprocessor applied exactly once.
     *
     * <p>If a batch was fetched early by {@link #inputColumns()}/{@link #totalOutcomes()}, that batch is
     * returned and {@code num} is ignored.
     *
     * @throws NoSuchElementException if no batches remain
     */
    @Override
    public DataSet next(int num) {
        if (useStored) {
            DataSet pending = stored;
            useStored = false;
            stored = null;
            return applyPreProcessor(pending);
        }
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        // hasNext() has initialized the underlying iterator.
        DataSet ds = applyPreProcessor(mdsToDataSet(underlying.next(num)));
        if (inputColumns == -1) {
            recordShapes(ds);
        }
        return ds;
    }

    /** Feature column count; may fetch (and keep for the next {@code next()} call) the first batch. */
    @Override
    public int inputColumns() {
        if (inputColumns == -1) {
            preLoad();
        }
        return inputColumns;
    }

    /**
     * Label column count, or -1 when batches carry no labels; may fetch (and keep for the next
     * {@code next()} call) the first batch.
     */
    @Override
    public int totalOutcomes() {
        // Both counts are set together, so a known inputColumns means totalOutcomes is final (possibly -1).
        if (inputColumns == -1) {
            preLoad();
        }
        return totalOutcomes;
    }

    /** Fetches one raw batch (if none is pending) so column counts can be reported before iteration. */
    private void preLoad() {
        if (!useStored) {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            stored = mdsToDataSet(underlying.next(miniBatchSize));
            useStored = true;
        }
        recordShapes(stored);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /** Rewinds the underlying iterator (if built) and discards any early-fetched batch. */
    @Override
    public void reset() {
        if (underlying != null) {
            underlying.reset();
        }
        stored = null;
        useStored = false;
    }

    @Override
    public int batch() {
        return miniBatchSize;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    /** Always {@code null}: this iterator does not expose label names. */
    @Override
    public List<String> getLabels() {
        return null;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Remove not supported for this iterator");
    }

    /** Loads a single example; see {@link #loadFromMetaData(List)}. */
    public DataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Re-loads the examples described by {@code list} as one {@link DataSet} (preprocessor applied).
     *
     * @param list per-example metadata as produced with {@link #setCollectMetaData(boolean)} enabled: plain
     *             features-reader metadata in single-reader mode, {@link RecordMetaDataComposable}
     *             ([features, labels]) in two-reader mode
     * @throws IllegalArgumentException if {@code list} is empty and the iterator has not been initialized yet
     * @throws IOException              if a reader fails to load a sequence
     */
    public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        if (underlying == null) {
            if (list.isEmpty()) {
                throw new IllegalArgumentException("Cannot initialize iterator from an empty metadata list");
            }
            initializeUnderlying(recordReader.loadSequenceFromMetaData(featuresMetaOf(list.get(0))));
        }
        List<RecordMetaData> perExample = new ArrayList<>(list.size());
        for (RecordMetaData meta : list) {
            if (singleSequenceReaderMode) {
                perExample.add(new RecordMetaDataComposableMap(Collections.singletonMap(READER_KEY, meta)));
            } else {
                RecordMetaDataComposable composable = (RecordMetaDataComposable) meta;
                Map<String, RecordMetaData> byReader = new HashMap<>(2);
                byReader.put(READER_KEY, composable.getMeta()[0]);
                byReader.put(READER_KEY_LABEL, composable.getMeta()[1]);
                perExample.add(new RecordMetaDataComposableMap(byReader));
            }
        }
        return applyPreProcessor(mdsToDataSet(underlying.loadFromMetaData(perExample)));
    }

    /** Extracts the part of an example's metadata that the features reader understands. */
    private RecordMetaData featuresMetaOf(RecordMetaData meta) {
        return singleSequenceReaderMode ? meta : ((RecordMetaDataComposable) meta).getMeta()[0];
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    public boolean isCollectMetaData() {
        return collectMetaData;
    }

    /** Enables/disables per-example metadata; also applied to an already-built underlying iterator. */
    public void setCollectMetaData(boolean collectMetaData) {
        this.collectMetaData = collectMetaData;
        if (underlying != null) {
            underlying.setCollectMetaData(collectMetaData);
        }
    }
}
