/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.datavec;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.datavec.common.data.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;

public class SequenceRecordReaderDataSetIterator
implements DataSetIterator {
    private SequenceRecordReader recordReader;
    private SequenceRecordReader labelsReader;
    private int miniBatchSize = 10;
    private final boolean regression;
    private int labelIndex = -1;
    private final int numPossibleLabels;
    private int cursor = 0;
    private int inputColumns = -1;
    private int totalOutcomes = -1;
    private boolean useStored = false;
    private org.nd4j.linalg.dataset.DataSet stored = null;
    private DataSetPreProcessor preProcessor;
    private AlignmentMode alignmentMode;
    private final boolean singleSequenceReaderMode;

    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels, int miniBatchSize, int numPossibleLabels, boolean regression) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, regression, AlignmentMode.EQUAL_LENGTH);
    }

    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels, int miniBatchSize, int numPossibleLabels, boolean regression, AlignmentMode alignmentMode) {
        this.recordReader = featuresReader;
        this.labelsReader = labels;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.alignmentMode = alignmentMode;
        this.singleSequenceReaderMode = false;
    }

    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize, int numPossibleLabels, int labelIndex, boolean regression) {
        this.recordReader = reader;
        this.labelsReader = null;
        this.miniBatchSize = miniBatchSize;
        this.regression = regression;
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.singleSequenceReaderMode = true;
    }

    public boolean hasNext() {
        return this.recordReader.hasNext();
    }

    public org.nd4j.linalg.dataset.DataSet next() {
        return this.next(this.miniBatchSize);
    }

    public org.nd4j.linalg.dataset.DataSet next(int num) {
        if (this.useStored) {
            this.useStored = false;
            org.nd4j.linalg.dataset.DataSet temp = this.stored;
            this.stored = null;
            if (this.preProcessor != null) {
                this.preProcessor.preProcess((DataSet)temp);
            }
            return temp;
        }
        if (!this.hasNext()) {
            throw new NoSuchElementException();
        }
        if (this.singleSequenceReaderMode) {
            return this.nextSingleSequenceReader(num);
        }
        return this.nextMultipleSequenceReaders(num);
    }

    private org.nd4j.linalg.dataset.DataSet nextSingleSequenceReader(int num) {
        int i;
        ArrayList<INDArray> listFeatures = new ArrayList<INDArray>(num);
        ArrayList<INDArray> listLabels = new ArrayList<INDArray>(num);
        int minLength = 0;
        int maxLength = 0;
        for (int i2 = 0; i2 < num && this.hasNext(); ++i2) {
            Collection sequence = this.recordReader.sequenceRecord();
            INDArray[] fl = this.getFeaturesLabelsSingleReader(sequence);
            if (i2 == 0) {
                maxLength = minLength = fl[0].size(0);
            } else {
                minLength = Math.min(minLength, fl[0].size(0));
                maxLength = Math.max(maxLength, fl[0].size(0));
            }
            listFeatures.add(fl[0]);
            listLabels.add(fl[1]);
        }
        INDArray featuresOut = Nd4j.create((int[])new int[]{listFeatures.size(), ((INDArray)listFeatures.get(0)).size(1), maxLength}, (char)'f');
        INDArray labelsOut = Nd4j.create((int[])new int[]{listLabels.size(), ((INDArray)listLabels.get(0)).size(1), maxLength}, (char)'f');
        INDArray featuresMask = null;
        INDArray labelsMask = null;
        if (minLength == maxLength) {
            for (i = 0; i < listFeatures.size(); ++i) {
                featuresOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).assign((INDArray)listFeatures.get(i));
                labelsOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).assign((INDArray)listLabels.get(i));
            }
        } else {
            featuresMask = Nd4j.ones((int)listFeatures.size(), (int)maxLength);
            labelsMask = Nd4j.ones((int)listLabels.size(), (int)maxLength);
            for (i = 0; i < listFeatures.size(); ++i) {
                INDArray f = (INDArray)listFeatures.get(i);
                int tsLength = f.size(0);
                featuresOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)tsLength), NDArrayIndex.all()}, f);
                labelsOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)tsLength), NDArrayIndex.all()}, (INDArray)listLabels.get(i));
                for (int j = tsLength; j < maxLength; ++j) {
                    featuresMask.put(i, j, (Number)0.0);
                    labelsMask.put(i, j, (Number)0.0);
                }
            }
        }
        this.cursor += listFeatures.size();
        if (this.inputColumns == -1) {
            this.inputColumns = featuresOut.size(1);
        }
        if (this.totalOutcomes == -1) {
            this.totalOutcomes = labelsOut.size(1);
        }
        org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(featuresOut, labelsOut, featuresMask, labelsMask);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)ds);
        }
        return ds;
    }

    private org.nd4j.linalg.dataset.DataSet nextMultipleSequenceReaders(int num) {
        INDArray labelsOut;
        INDArray featuresOut;
        ArrayList<INDArray> featureList = new ArrayList<INDArray>(num);
        ArrayList<INDArray> labelList = new ArrayList<INDArray>(num);
        for (int i = 0; i < num && this.hasNext(); ++i) {
            Collection featureSequence = this.recordReader.sequenceRecord();
            Collection labelSequence = this.labelsReader.sequenceRecord();
            INDArray features = this.getFeatures(featureSequence);
            INDArray labels = this.getLabels(labelSequence);
            featureList.add(features);
            labelList.add(labels);
        }
        INDArray featuresMask = null;
        INDArray labelsMask = null;
        if (this.alignmentMode == AlignmentMode.EQUAL_LENGTH) {
            int[] featureShape = new int[]{featureList.size(), ((INDArray)featureList.get(0)).size(1), ((INDArray)featureList.get(0)).size(0)};
            int[] labelShape = new int[]{labelList.size(), ((INDArray)labelList.get(0)).size(1), ((INDArray)labelList.get(0)).size(0)};
            featuresOut = Nd4j.create((int[])featureShape, (char)'f');
            labelsOut = Nd4j.create((int[])labelShape, (char)'f');
            for (int i = 0; i < featureList.size(); ++i) {
                featuresOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).assign((INDArray)featureList.get(i));
                labelsOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).assign((INDArray)labelList.get(i));
            }
        } else if (this.alignmentMode == AlignmentMode.ALIGN_START) {
            int longestTimeSeries = 0;
            for (INDArray features : featureList) {
                longestTimeSeries = Math.max(features.size(0), longestTimeSeries);
            }
            for (INDArray labels : labelList) {
                longestTimeSeries = Math.max(labels.size(0), longestTimeSeries);
            }
            int[] featuresShape = new int[]{featureList.size(), ((INDArray)featureList.get(0)).size(1), longestTimeSeries};
            int[] labelsShape = new int[]{labelList.size(), ((INDArray)labelList.get(0)).size(1), longestTimeSeries};
            featuresOut = Nd4j.create((int[])featuresShape, (char)'f');
            labelsOut = Nd4j.create((int[])labelsShape, (char)'f');
            featuresMask = Nd4j.ones((int)featureList.size(), (int)longestTimeSeries);
            labelsMask = Nd4j.ones((int)labelList.size(), (int)longestTimeSeries);
            for (int i = 0; i < featureList.size(); ++i) {
                int j;
                INDArray f = (INDArray)featureList.get(i);
                INDArray l = (INDArray)labelList.get(i);
                featuresOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)f.size(0)), NDArrayIndex.all()}, f);
                labelsOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)l.size(0)), NDArrayIndex.all()}, l);
                for (j = f.size(0); j < longestTimeSeries; ++j) {
                    featuresMask.putScalar(i, j, 0.0);
                }
                for (j = l.size(0); j < longestTimeSeries; ++j) {
                    labelsMask.putScalar(i, j, 0.0);
                }
            }
        } else if (this.alignmentMode == AlignmentMode.ALIGN_END) {
            int longestTimeSeries = 0;
            for (INDArray features : featureList) {
                longestTimeSeries = Math.max(features.size(0), longestTimeSeries);
            }
            for (INDArray labels : labelList) {
                longestTimeSeries = Math.max(labels.size(0), longestTimeSeries);
            }
            Object featuresShape = new int[]{featureList.size(), ((INDArray)featureList.get(0)).size(1), longestTimeSeries};
            int[] labelsShape = new int[]{labelList.size(), ((INDArray)labelList.get(0)).size(1), longestTimeSeries};
            featuresOut = Nd4j.create((int[])featuresShape, (char)'f');
            labelsOut = Nd4j.create((int[])labelsShape, (char)'f');
            featuresMask = Nd4j.ones((int)featureList.size(), (int)longestTimeSeries);
            labelsMask = Nd4j.ones((int)labelList.size(), (int)longestTimeSeries);
            for (int i = 0; i < featureList.size(); ++i) {
                int j;
                int lLen;
                INDArray f = (INDArray)featureList.get(i);
                INDArray l = (INDArray)labelList.get(i);
                int fLen = f.size(0);
                if (fLen >= (lLen = l.size(0))) {
                    featuresOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)fLen), NDArrayIndex.all()}, f);
                    labelsOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(fLen - lLen), (int)fLen), NDArrayIndex.all()}, l);
                    for (j = fLen; j < longestTimeSeries; ++j) {
                        featuresMask.putScalar(i, j, 0.0);
                    }
                    for (j = 0; j < fLen - lLen; ++j) {
                        labelsMask.putScalar(i, j, 0.0);
                    }
                    for (j = fLen; j < longestTimeSeries; ++j) {
                        labelsMask.putScalar(i, j, 0.0);
                    }
                    continue;
                }
                featuresOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(lLen - fLen), (int)lLen), NDArrayIndex.all()}, f);
                labelsOut.tensorAlongDimension(i, new int[]{1, 2}).permutei(new int[]{1, 0}).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)lLen), NDArrayIndex.all()}, l);
                for (j = 0; j < lLen - fLen; ++j) {
                    featuresMask.putScalar(i, j, 0.0);
                }
                for (j = lLen; j < longestTimeSeries; ++j) {
                    featuresMask.putScalar(i, j, 0.0);
                }
                for (j = lLen; j < longestTimeSeries; ++j) {
                    labelsMask.putScalar(i, j, 0.0);
                }
            }
        } else {
            throw new UnsupportedOperationException("Unknown alignment mode: " + (Object)((Object)this.alignmentMode));
        }
        this.cursor += featureList.size();
        if (this.inputColumns == -1) {
            this.inputColumns = featuresOut.size(1);
        }
        if (this.totalOutcomes == -1) {
            this.totalOutcomes = labelsOut.size(1);
        }
        org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(featuresOut, labelsOut, featuresMask, labelsMask);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)ds);
        }
        return ds;
    }

    public int totalExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    public int inputColumns() {
        if (this.inputColumns != -1) {
            return this.inputColumns;
        }
        this.preLoad();
        return this.inputColumns;
    }

    public int totalOutcomes() {
        if (this.totalOutcomes != -1) {
            return this.totalOutcomes;
        }
        this.preLoad();
        return this.totalOutcomes;
    }

    private void preLoad() {
        this.stored = this.next();
        this.useStored = true;
        this.inputColumns = this.stored.getFeatureMatrix().size(1);
        this.totalOutcomes = this.stored.getLabels().size(1);
    }

    public void reset() {
        this.recordReader.reset();
        if (this.labelsReader != null) {
            this.labelsReader.reset();
        }
        this.cursor = 0;
        this.stored = null;
        this.useStored = false;
    }

    public int batch() {
        return this.miniBatchSize;
    }

    public int cursor() {
        return this.cursor;
    }

    public int numExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    public List<String> getLabels() {
        return null;
    }

    public void remove() {
        throw new UnsupportedOperationException("Remove not supported for this iterator");
    }

    private INDArray getFeatures(Collection<Collection<Writable>> features) {
        int[] shape = new int[2];
        shape[0] = features.size();
        Iterator<Collection<Writable>> iter = features.iterator();
        int i = 0;
        INDArray out = null;
        while (iter.hasNext()) {
            Collection<Writable> step = iter.next();
            if (i == 0) {
                for (Writable w : step) {
                    if (w instanceof NDArrayWritable) {
                        shape[1] = shape[1] + ((NDArrayWritable)w).get().length();
                        continue;
                    }
                    shape[1] = shape[1] + 1;
                }
                out = Nd4j.create((int[])shape, (char)'f');
            }
            Iterator<Writable> timeStepIter = step.iterator();
            int f = 0;
            while (timeStepIter.hasNext()) {
                Writable current = timeStepIter.next();
                if (current instanceof NDArrayWritable) {
                    INDArray arr = ((NDArrayWritable)current).get();
                    out.put(new INDArrayIndex[]{NDArrayIndex.point((int)i), NDArrayIndex.interval((int)f, (int)(f + arr.length()))}, arr);
                    f += arr.length();
                    continue;
                }
                out.put(i, f++, (Number)current.toDouble());
            }
            ++i;
        }
        return out;
    }

    /*
     * WARNING - void declaration
     */
    private INDArray getLabels(Collection<Collection<Writable>> labels) {
        int[] shape = new int[2];
        shape[0] = labels.size();
        Iterator<Collection<Writable>> iter = labels.iterator();
        int i = 0;
        INDArray out = null;
        while (iter.hasNext()) {
            List<Object> step;
            Collection<Writable> stepCollection = iter.next();
            List<Object> list = step = stepCollection instanceof List ? (List<Object>)stepCollection : new ArrayList<Writable>(stepCollection);
            if (i == 0) {
                if (this.regression) {
                    for (Writable writable : step) {
                        if (writable instanceof NDArrayWritable) {
                            shape[1] = shape[1] + ((NDArrayWritable)writable).get().length();
                            continue;
                        }
                        shape[1] = shape[1] + 1;
                    }
                } else {
                    shape[1] = this.numPossibleLabels;
                }
                out = Nd4j.create((int[])shape, (char)'f');
            }
            Iterator timeStepIter = step.iterator();
            boolean bl = false;
            if (this.regression) {
                while (timeStepIter.hasNext()) {
                    void var9_9;
                    Writable current = (Writable)timeStepIter.next();
                    if (current instanceof NDArrayWritable) {
                        INDArray w = ((NDArrayWritable)current).get();
                        out.put(new INDArrayIndex[]{NDArrayIndex.point((int)i), NDArrayIndex.interval((int)var9_9, (int)(var9_9 + w.length()))}, w);
                        var9_9 += w.length();
                        continue;
                    }
                    out.put(i, (int)var9_9++, (Number)current.toDouble());
                }
            } else {
                Writable value = (Writable)timeStepIter.next();
                int idx = value.toInt();
                INDArray line = FeatureUtil.toOutcomeVector((int)idx, (int)this.numPossibleLabels);
                out.getRow(i).assign(line);
            }
            ++i;
        }
        return out;
    }

    /*
     * WARNING - void declaration
     */
    private INDArray[] getFeaturesLabelsSingleReader(Collection<Collection<Writable>> input) {
        Iterator<Collection<Writable>> iter = input.iterator();
        int i = 0;
        INDArray features = null;
        INDArray labels = null;
        int featureSize = 0;
        while (iter.hasNext()) {
            List<Object> step;
            Collection<Writable> stepCollection = iter.next();
            List<Object> list = step = stepCollection instanceof List ? (List<Object>)stepCollection : new ArrayList<Writable>(stepCollection);
            if (i == 0) {
                int j = 0;
                for (Writable writable : step) {
                    if (j++ == this.labelIndex) continue;
                    if (writable instanceof NDArrayWritable) {
                        featureSize += ((NDArrayWritable)writable).get().length();
                        continue;
                    }
                    ++featureSize;
                }
                features = Nd4j.zeros((int)input.size(), (int)featureSize);
                int labelSize = this.regression ? (step.get(this.labelIndex) instanceof NDArrayWritable ? ((NDArrayWritable)step.get(this.labelIndex)).get().length() : 1) : this.numPossibleLabels;
                labels = Nd4j.zeros((int)input.size(), (int)labelSize);
            }
            Iterator timeStepIter = step.iterator();
            int countIn = 0;
            boolean bl = false;
            while (timeStepIter.hasNext()) {
                void var11_13;
                Writable current = (Writable)timeStepIter.next();
                if (countIn++ == this.labelIndex) {
                    if (this.regression) {
                        if (current instanceof NDArrayWritable) {
                            labels.putRow(i, ((NDArrayWritable)current).get());
                            continue;
                        }
                        labels.put(i, 0, (Number)current.toDouble());
                        continue;
                    }
                    labels.putScalar(i, current.toInt(), 1.0);
                    continue;
                }
                if (current instanceof NDArrayWritable) {
                    INDArray w = ((NDArrayWritable)current).get();
                    int length = w.length();
                    features.put(new INDArrayIndex[]{NDArrayIndex.point((int)i), NDArrayIndex.interval((int)var11_13, (int)(var11_13 + length))}, w);
                    var11_13 += length;
                    continue;
                }
                features.put(i, (int)var11_13++, (Number)current.toDouble());
            }
            ++i;
        }
        return new INDArray[]{features, labels};
    }

    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    public static enum AlignmentMode {
        EQUAL_LENGTH,
        ALIGN_START,
        ALIGN_END;

    }
}

