/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.util;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.InputSplit;
import org.canova.api.split.InputStreamInputSplit;
import org.canova.api.writable.Writable;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import scala.Tuple2;

/**
 * Conversion helpers between Spark MLlib types ({@link Vector}, {@link Matrix},
 * {@link LabeledPoint}) and ND4J / DL4J types ({@link INDArray}, {@link DataSet}).
 */
public class MLLibUtil {

    /**
     * Turns a vector of per-class scores into a class prediction.
     *
     * @param vector per-class scores
     * @return index of the largest entry, returned as a double to match MLlib's label
     *         convention; ties resolve to the lowest index, NaN entries are skipped,
     *         and an empty vector yields 0
     */
    public static double toClassifierPrediction(Vector vector) {
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < vector.size(); i++) {
            double curr = vector.apply(i);
            if (curr > max) {
                maxIndex = i;
                max = curr;
            }
        }
        return maxIndex;
    }

    /**
     * Converts an MLlib matrix into an ND4J matrix with the same shape and entries.
     *
     * <p>MLlib stores matrix values column-major ({@code Matrix.toArray()}), whereas
     * ND4J interprets a flat {@code double[]} using its configured default ordering.
     * Copying entry by entry makes the result correct regardless of either layout.
     *
     * @param arr the MLlib matrix
     * @return a new ND4J matrix of shape {@code [numRows, numCols]}
     */
    public static INDArray toMatrix(Matrix arr) {
        int rows = arr.numRows();
        int cols = arr.numCols();
        INDArray result = Nd4j.create(rows, cols);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result.putScalar(new int[]{i, j}, arr.apply(i, j));
            }
        }
        return result;
    }

    /**
     * Converts an MLlib vector into an ND4J array backed by a copy of its values.
     *
     * @param arr the MLlib vector
     * @return a new ND4J array holding {@code arr.toArray()}
     */
    public static INDArray toVector(Vector arr) {
        return Nd4j.create((DataBuffer) Nd4j.createBuffer(arr.toArray()));
    }

    /**
     * Converts an ND4J matrix into a dense MLlib matrix.
     *
     * <p>{@code Matrices.dense} expects column-major values. The values are read
     * through {@code getDouble(i, j)} rather than the raw backing buffer so that
     * row-major arrays and views (non-zero offset / non-unit stride) convert correctly.
     *
     * @param arr an ND4J matrix
     * @return a new dense MLlib matrix with the same shape and entries
     * @throws IllegalArgumentException if {@code arr} is not a matrix
     */
    public static Matrix toMatrix(INDArray arr) {
        if (!arr.isMatrix()) {
            throw new IllegalArgumentException("passed in array must be a matrix");
        }
        int rows = (int) arr.rows();
        int cols = (int) arr.columns();
        double[] columnMajor = new double[rows * cols];
        for (int j = 0; j < cols; j++) {
            for (int i = 0; i < rows; i++) {
                columnMajor[j * rows + i] = arr.getDouble(i, j);
            }
        }
        return Matrices.dense(rows, cols, columnMajor);
    }

    /**
     * Converts an ND4J vector into a dense MLlib vector (element-wise copy).
     *
     * @param arr an ND4J row or column vector
     * @return a new dense MLlib vector
     * @throws IllegalArgumentException if {@code arr} is not a vector
     */
    public static Vector toVector(INDArray arr) {
        if (!arr.isVector()) {
            throw new IllegalArgumentException("passed in array must be a vector");
        }
        double[] ret = new double[arr.length()];
        for (int i = 0; i < arr.length(); i++) {
            ret[i] = arr.getDouble(i);
        }
        return Vectors.dense(ret);
    }

    /**
     * Reads one record from each binary file and turns it into a labeled point.
     *
     * <p>For every file the shared {@code reader} is re-initialized on that file's
     * stream and only the FIRST record ({@code reader.next()}) is used; that record is
     * converted with {@link #pointOf(Collection)}. The file stream is closed once the
     * record has been read.
     *
     * @param binaryFiles (path, stream) pairs, e.g. from {@code JavaSparkContext.binaryFiles}
     * @param reader      record reader used to parse each file; it is captured in the
     *                    Spark closure, so it must be serializable
     * @return one labeled point per input file
     */
    public static JavaRDD<LabeledPoint> fromBinary(JavaPairRDD<String, PortableDataStream> binaryFiles,
                                                   final RecordReader reader) {
        JavaRDD<Collection<Writable>> records = binaryFiles.map(
                new Function<Tuple2<String, PortableDataStream>, Collection<Writable>>() {
                    @Override
                    public Collection<Writable> call(Tuple2<String, PortableDataStream> file) throws Exception {
                        InputStream in = file._2().open();
                        try {
                            reader.initialize(new InputStreamInputSplit(in, file._1()));
                            return reader.next();
                        } finally {
                            // Previously leaked: one open stream per processed file.
                            in.close();
                        }
                    }
                });
        return records.map(new Function<Collection<Writable>, LabeledPoint>() {
            @Override
            public LabeledPoint call(Collection<Writable> writables) throws Exception {
                return MLLibUtil.pointOf(writables);
            }
        });
    }

    /**
     * Builds a labeled point from a record whose LAST value is the label and whose
     * preceding values are the features.
     *
     * @param writables record values, each parseable as a number via {@code toString()}
     * @return a labeled point with {@code size - 1} features
     * @throws IllegalArgumentException if the record is empty
     * @throws NumberFormatException    if a value is not numeric
     * @throws IllegalStateException    if the label is negative or NaN
     */
    public static LabeledPoint pointOf(Collection<Writable> writables) {
        if (writables.isEmpty()) {
            throw new IllegalArgumentException("Record must contain at least a label value");
        }
        int numFeatures = writables.size() - 1;
        double[] features = new double[numFeatures];
        double target = 0.0;
        int index = 0;
        for (Writable w : writables) {
            // Parse as double: the former Float.parseFloat silently dropped precision.
            double value = Double.parseDouble(w.toString());
            if (index < numFeatures) {
                features[index] = value;
            } else {
                target = value;
            }
            index++;
        }
        // Written as !(>= 0) so that a NaN label is rejected as well.
        if (!(target >= 0.0)) {
            throw new IllegalStateException("Target must be >= 0");
        }
        return new LabeledPoint(target, Vectors.dense(features));
    }

    /**
     * Converts labeled points into single-example DataSets with one-hot labels,
     * redistributed into roughly {@code batchSize} examples per partition.
     *
     * <p>The previous implementation keyed every point by its unique index and then
     * called {@code reduceByKey}, which never merged anything, and requested
     * {@code count / batchSize} partitions, which is 0 (an invalid shuffle) whenever
     * there are fewer points than {@code batchSize}. The output elements are the same
     * single-example DataSets; the partition count is now clamped to at least 1.
     *
     * @param data              labeled points whose labels are class indices
     * @param numPossibleLabels number of classes (length of each one-hot label)
     * @param batchSize         approximate number of examples per partition; must be positive
     * @return one DataSet per input point
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final int numPossibleLabels,
                                                    int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        JavaRDD<DataSet> mapped = data.map(new Function<LabeledPoint, DataSet>() {
            @Override
            public DataSet call(LabeledPoint point) throws Exception {
                return MLLibUtil.fromLabeledPoint(point, numPossibleLabels);
            }
        });
        long numPartitions = Math.max(1L, data.count() / batchSize);
        return mapped.repartition((int) Math.min(Integer.MAX_VALUE, numPartitions));
    }

    /**
     * Converts labeled points into single-example DataSets with one-hot labels.
     *
     * <p>Note: this collects the whole RDD to the driver and re-parallelizes it, so
     * it is only suitable for data that fits in driver memory.
     *
     * @param sc                context used to re-parallelize the converted data
     * @param data              labeled points whose labels are class indices
     * @param numPossibleLabels number of classes (length of each one-hot label)
     * @return one DataSet per input point
     */
    public static JavaRDD<DataSet> fromLabeledPoint(JavaSparkContext sc, JavaRDD<LabeledPoint> data,
                                                    int numPossibleLabels) {
        List<DataSet> list = MLLibUtil.fromLabeledPoint(data.collect(), numPossibleLabels);
        return sc.parallelize(list);
    }

    /**
     * Converts single-example DataSets back into labeled points.
     *
     * <p>Note: this collects the whole RDD to the driver and re-parallelizes it, so
     * it is only suitable for data that fits in driver memory.
     *
     * @param sc   context used to re-parallelize the converted data
     * @param data DataSets whose feature matrix is a vector
     * @return one labeled point per DataSet
     */
    public static JavaRDD<LabeledPoint> fromDataSet(JavaSparkContext sc, JavaRDD<DataSet> data) {
        List<LabeledPoint> list = MLLibUtil.toLabeledPoint(data.collect());
        return sc.parallelize(list);
    }

    /** Converts each DataSet in the list with {@link #toLabeledPoint(DataSet)}. */
    private static List<LabeledPoint> toLabeledPoint(List<DataSet> labeledPoints) {
        List<LabeledPoint> ret = new ArrayList<LabeledPoint>(labeledPoints.size());
        for (DataSet point : labeledPoints) {
            ret.add(MLLibUtil.toLabeledPoint(point));
        }
        return ret;
    }

    /**
     * Converts a single-example DataSet into a labeled point. The label is the index
     * of the largest-magnitude entry of the label array (iamax), which is the class
     * index for one-hot labels.
     *
     * @throws IllegalArgumentException if the feature matrix is not a vector
     */
    private static LabeledPoint toLabeledPoint(DataSet point) {
        if (!point.getFeatureMatrix().isVector()) {
            throw new IllegalArgumentException("Feature matrix must be a vector");
        }
        // toVector copies element-wise, so no defensive dup() is needed here.
        Vector features = MLLibUtil.toVector(point.getFeatureMatrix());
        double label = Nd4j.getBlasWrapper().iamax(point.getLabels());
        return new LabeledPoint(label, features);
    }

    /** Converts each labeled point with {@link #fromLabeledPoint(LabeledPoint, int)}. */
    private static List<DataSet> fromLabeledPoint(List<LabeledPoint> labeledPoints, int numPossibleLabels) {
        List<DataSet> ret = new ArrayList<DataSet>(labeledPoints.size());
        for (LabeledPoint point : labeledPoints) {
            ret.add(MLLibUtil.fromLabeledPoint(point, numPossibleLabels));
        }
        return ret;
    }

    /**
     * Converts one labeled point into a single-example DataSet. The label is truncated
     * to an int class index and one-hot encoded over {@code numPossibleLabels} classes.
     */
    private static DataSet fromLabeledPoint(LabeledPoint point, int numPossibleLabels) {
        Vector features = point.features();
        int classIndex = (int) point.label();
        return new DataSet(Nd4j.create(features.toArray()),
                FeatureUtil.toOutcomeVector(classIndex, numPossibleLabels));
    }
}

