/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.datavec;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;

public class DataVecSequenceDataSetFunction
implements Function<List<List<Writable>>, org.nd4j.linalg.dataset.DataSet>,
Serializable {
    private final boolean regression;
    private final int labelIndex;
    private final int numPossibleLabels;
    private final DataSetPreProcessor preProcessor;
    private final WritableConverter converter;

    public DataVecSequenceDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression) {
        this(labelIndex, numPossibleLabels, regression, null, null);
    }

    public DataVecSequenceDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression, DataSetPreProcessor preProcessor, WritableConverter converter) {
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.preProcessor = preProcessor;
        this.converter = converter;
    }

    public org.nd4j.linalg.dataset.DataSet call(List<List<Writable>> input) throws Exception {
        Iterator<List<Writable>> iter = input.iterator();
        INDArray features = null;
        INDArray labels = Nd4j.zeros((int[])new int[]{1, this.regression ? 1 : this.numPossibleLabels, input.size()});
        int[] fIdx = new int[3];
        int[] lIdx = new int[3];
        int i = 0;
        while (iter.hasNext()) {
            List<Writable> step = iter.next();
            if (i == 0) {
                features = Nd4j.zeros((int[])new int[]{1, step.size() - 1, input.size()});
            }
            Iterator<Writable> timeStepIter = step.iterator();
            int countIn = 0;
            int countFeatures = 0;
            while (timeStepIter.hasNext()) {
                Writable current = timeStepIter.next();
                if (this.converter != null) {
                    current = this.converter.convert(current);
                }
                if (countIn++ == this.labelIndex) {
                    if (this.regression) {
                        lIdx[2] = i;
                        labels.putScalar(lIdx, current.toDouble());
                        continue;
                    }
                    INDArray line = FeatureUtil.toOutcomeVector((int)current.toInt(), (int)this.numPossibleLabels);
                    labels.tensorAlongDimension(i, new int[]{1}).assign(line);
                    continue;
                }
                fIdx[1] = countFeatures++;
                fIdx[2] = i;
                try {
                    features.putScalar(fIdx, current.toDouble());
                }
                catch (UnsupportedOperationException e) {
                    if (current instanceof NDArrayWritable) {
                        features.get(new INDArrayIndex[]{NDArrayIndex.point((int)fIdx[0]), NDArrayIndex.all(), NDArrayIndex.point((int)fIdx[2])}).putRow(0, ((NDArrayWritable)current).get());
                        continue;
                    }
                    throw e;
                }
            }
            ++i;
        }
        org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(features, labels);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)ds);
        }
        return ds;
    }
}

