/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.util;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.Array;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.LocatedFileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.RemoteIterator;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.data.BatchDataSetsFunction;
import org.deeplearning4j.spark.data.shuffle.SplitDataSetExamplesPairFlatMapFunction;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction2;
import org.deeplearning4j.spark.impl.common.repartition.BalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.HashingBalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.MapTupleToPairFlatMap;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import scala.Tuple2;

public class SparkUtils {
    /**
     * Message used by {@link #checkKryoConfiguration(JavaSparkContext, Logger)} when an INDArray cannot be
     * round-tripped through the configured Spark serializer.
     */
    private static final String KRYO_EXCEPTION_MSG = "Kryo serialization detected without an appropriate registrator for ND4J INDArrays.\nWhen using Kryo, An appropriate Kryo registrator must be used to avoid serialization issues (NullPointerException) with off-heap data in INDArrays.\nUse nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.Nd4jRegistrator\");\nSee https://deeplearning4j.org/spark#kryo for more details";

    /** Private constructor: this class only exposes static helpers and is not meant to be instantiated. */
    private SparkUtils() {
    }

    /**
     * Checks that ND4J arrays can be serialized by Spark when Kryo is configured without the ND4J registrator.
     * <p>
     * If {@code spark.serializer} is not the Kryo serializer, or {@code spark.kryo.registrator} is
     * {@code org.nd4j.Nd4jRegistrator}, nothing is tested. Otherwise a small test array is serialized and
     * deserialized with a serializer instance obtained from the context, and the result is compared with the
     * original.
     *
     * @param javaSparkContext context whose configuration and serializer are inspected
     * @param log              currently unused; kept for interface compatibility
     * @return always {@code true}; every failure is reported by exception
     * @throws RuntimeException if serialization or deserialization throws, if the serializer returns a
     *                          {@code null} buffer, or if the deserialized array differs from the original
     */
    public static boolean checkKryoConfiguration(JavaSparkContext javaSparkContext, Logger log) {
        String serializer = javaSparkContext.getConf().get("spark.serializer", null);
        if (!"org.apache.spark.serializer.KryoSerializer".equals(serializer)) {
            // Kryo is not in use: nothing to verify.
            return true;
        }
        String kryoRegistrator = javaSparkContext.getConf().get("spark.kryo.registrator", null);
        if ("org.nd4j.Nd4jRegistrator".equals(kryoRegistrator)) {
            return true;
        }

        // Kryo without the ND4J registrator: a custom registrator may still cope with INDArrays,
        // so round-trip a test array now and fail early if it does not survive.
        SerializerInstance si;
        ByteBuffer bb;
        try {
            si = javaSparkContext.env().serializer().newInstance();
            bb = si.serialize(Nd4j.linspace(1, 5, 5), null);
        } catch (Exception e) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
        }
        if (bb == null) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG + "\n(Got: null ByteBuffer from Spark SerializerInstance)");
        }

        boolean equals;
        try {
            INDArray deserialized = (INDArray) si.deserialize(bb, null);
            // The comparison stays inside the try block so a failure in equals() is reported the same way.
            equals = Nd4j.linspace(1, 5, 5).equals(deserialized);
        } catch (Exception e) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
        }
        if (!equals) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG
                    + "\n(Error during deserialization: test array was not deserialized successfully)");
        }
        return true;
    }

    /**
     * Writes a string to a file, delegating to {@link #writeStringToFile(String, String, SparkContext)}
     * with the {@link SparkContext} wrapped by the given Java context.
     *
     * @param path    destination path
     * @param toWrite text to write
     * @param sc      Java Spark context
     * @throws IOException if the delegate fails
     */
    public static void writeStringToFile(String path, String toWrite, JavaSparkContext sc) throws IOException {
        SparkContext underlying = sc.sc();
        writeStringToFile(path, toWrite, underlying);
    }

    /**
     * Writes a string, encoded as UTF-8, to a file on the file system obtained from the context's Hadoop
     * configuration.
     *
     * @param path    destination path
     * @param toWrite text to write
     * @param sc      Spark context supplying the Hadoop configuration
     * @throws IOException if the file cannot be created or written
     */
    public static void writeStringToFile(String path, String toWrite, SparkContext sc) throws IOException {
        Configuration hadoopConf = sc.hadoopConfiguration();
        FileSystem fs = FileSystem.get(hadoopConf);
        Path target = new Path(path);
        try (OutputStream out = new BufferedOutputStream(fs.create(target))) {
            byte[] utf8 = toWrite.getBytes("UTF-8");
            out.write(utf8);
        }
    }

    /**
     * Reads a file as a string, delegating to {@link #readStringFromFile(String, SparkContext)} with the
     * {@link SparkContext} wrapped by the given Java context.
     *
     * @param path path of the file to read
     * @param sc   Java Spark context
     * @return the file content
     * @throws IOException if the delegate fails
     */
    public static String readStringFromFile(String path, JavaSparkContext sc) throws IOException {
        SparkContext underlying = sc.sc();
        return readStringFromFile(path, underlying);
    }

    /**
     * Reads the whole file at the given path and decodes it as UTF-8, using the file system obtained from
     * the context's Hadoop configuration.
     *
     * @param path path of the file to read
     * @param sc   Spark context supplying the Hadoop configuration
     * @return the file content as a string
     * @throws IOException if the file cannot be opened or read
     */
    public static String readStringFromFile(String path, SparkContext sc) throws IOException {
        Configuration hadoopConf = sc.hadoopConfiguration();
        FileSystem fs = FileSystem.get(hadoopConf);
        Path source = new Path(path);
        try (InputStream in = new BufferedInputStream(fs.open(source))) {
            return new String(IOUtils.toByteArray(in), "UTF-8");
        }
    }

    /**
     * Serializes an object to a file, delegating to {@link #writeObjectToFile(String, Object, SparkContext)}
     * with the {@link SparkContext} wrapped by the given Java context.
     *
     * @param path    destination path
     * @param toWrite object to serialize
     * @param sc      Java Spark context
     * @throws IOException if the delegate fails
     */
    public static void writeObjectToFile(String path, Object toWrite, JavaSparkContext sc) throws IOException {
        SparkContext underlying = sc.sc();
        writeObjectToFile(path, toWrite, underlying);
    }

    /**
     * Writes an object to a file using Java serialization, on the file system obtained from the context's
     * Hadoop configuration.
     *
     * @param path    destination path
     * @param toWrite object to serialize
     * @param sc      Spark context supplying the Hadoop configuration
     * @throws IOException if the file cannot be created or the object cannot be written
     */
    public static void writeObjectToFile(String path, Object toWrite, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        // The ObjectOutputStream itself is the managed resource, so it is flushed and closed (and with it
        // the buffer and the file stream) rather than relying on its internal buffering having drained.
        try (ObjectOutputStream oos = new ObjectOutputStream(
                new BufferedOutputStream(fileSystem.create(new Path(path))))) {
            oos.writeObject(toWrite);
        }
    }

    /**
     * Reads a serialized object from a file, delegating to
     * {@link #readObjectFromFile(String, Class, SparkContext)} with the {@link SparkContext} wrapped by the
     * given Java context.
     *
     * @param path path of the file to read
     * @param type class token for the expected result type
     * @param sc   Java Spark context
     * @param <T>  expected type of the object
     * @return the object read from the file
     * @throws IOException if the delegate fails
     */
    public static <T> T readObjectFromFile(String path, Class<T> type, JavaSparkContext sc) throws IOException {
        SparkContext underlying = sc.sc();
        return readObjectFromFile(path, type, underlying);
    }

    /**
     * Reads an object written with Java serialization from a file on the file system obtained from the
     * context's Hadoop configuration.
     * <p>
     * NOTE(review): this uses Java native deserialization; only use it on files from a trusted source.
     *
     * @param path path of the file to read
     * @param type class token that binds {@code T} at the call site; it is not consulted at runtime, so a
     *             mismatching object surfaces as a {@link ClassCastException} in the caller
     * @param sc   Spark context supplying the Hadoop configuration
     * @param <T>  expected type of the object
     * @return the deserialized object
     * @throws IOException      if the file cannot be opened or read
     * @throws RuntimeException wrapping a {@link ClassNotFoundException} raised during deserialization
     */
    @SuppressWarnings("unchecked")
    public static <T> T readObjectFromFile(String path, Class<T> type, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        // The raw stream is its own resource: the ObjectInputStream constructor reads the stream header
        // and may throw, in which case the already-opened file stream must still be closed.
        try (InputStream raw = fileSystem.open(new Path(path));
                ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(raw))) {
            return (T) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Repartitions an RDD according to the requested setting and strategy.
     * <p>
     * {@code Repartition.Never} returns the input unchanged. Otherwise the strategy decides: {@code SparkDefault}
     * calls {@code rdd.repartition(numPartitions)} (skipped for {@code NumPartitionsWorkersDiffers} when the
     * partition count already matches), {@code Balanced} delegates to
     * {@link #repartitionBalanceIfRequired(JavaRDD, Repartition, int, int)} and {@code ApproximateBalanced}
     * delegates to {@link #repartitionApproximateBalance(JavaRDD, Repartition, int)}.
     *
     * @param rdd                 RDD to repartition
     * @param repartition         when to repartition
     * @param repartitionStrategy how to repartition
     * @param objectsPerPartition target partition size, used by the {@code Balanced} strategy only
     * @param numPartitions       target number of partitions
     * @param <T>                 element type
     * @return the input RDD or a repartitioned RDD
     * @throws RuntimeException if the strategy is not one of the handled values
     */
    public static <T> JavaRDD<T> repartition(JavaRDD<T> rdd, Repartition repartition, RepartitionStrategy repartitionStrategy, int objectsPerPartition, int numPartitions) {
        if (Repartition.Never == repartition) {
            return rdd;
        }
        switch (repartitionStrategy) {
            case SparkDefault: {
                boolean onlyWhenCountDiffers = Repartition.NumPartitionsWorkersDiffers == repartition;
                if (onlyWhenCountDiffers && rdd.partitions().size() == numPartitions) {
                    return rdd;
                }
                return rdd.repartition(numPartitions);
            }
            case Balanced:
                return repartitionBalanceIfRequired(rdd, repartition, objectsPerPartition, numPartitions);
            case ApproximateBalanced:
                return repartitionApproximateBalance(rdd, repartition, numPartitions);
            default:
                throw new RuntimeException("Unknown repartition strategy: " + repartitionStrategy);
        }
    }

    /**
     * Repartitions an RDD with a {@link HashingBalancedPartitioner} driven by per-partition weights.
     * <p>
     * The elements of every partition are counted (this runs a Spark job and collects the counts). Each
     * partition index present both before and after gets the weight {@code count / (total / numPartitions)};
     * indices that only exist before get {@code -1.0} and indices that only exist after get {@code 0.0}.
     * Elements are keyed with {@code zipWithUniqueId}, partitioned, and then the keys are dropped.
     *
     * @param rdd           RDD to repartition
     * @param repartition   {@code Never} returns the input; {@code NumPartitionsWorkersDiffers} returns the
     *                      input when the partition count already equals {@code numPartitions}
     * @param numPartitions target number of partitions
     * @param <T>           element type
     * @return the input RDD or a repartitioned RDD
     * @throws RuntimeException if {@code repartition} is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionApproximateBalance(JavaRDD<T> rdd, Repartition repartition, int numPartitions) {
        int origNumPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions) {
                    return rdd;
                }
                // Intentional fall-through: the partition count differs, so repartition as for Always.
            case Always:
                break;
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }

        // Count the elements of each partition.
        List<Integer> partitionCounts = rdd.mapPartitionsWithIndex(new Function2<Integer, Iterator<T>, Iterator<Integer>>() {
            @Override
            public Iterator<Integer> call(Integer partitionIndex, Iterator<T> elements) throws Exception {
                int count = 0;
                while (elements.hasNext()) {
                    elements.next();
                    ++count;
                }
                return Collections.singletonList(count).iterator();
            }
        }, true).collect();

        // Accumulate in a long so the sum of the per-partition int counts cannot overflow.
        long totalCount = 0L;
        for (int count : partitionCounts) {
            totalCount += count;
        }
        double ideal = (double) totalCount / numPartitions;

        int sharedPartitions = Math.min(origNumPartitions, numPartitions);
        int maxPartitions = Math.max(origNumPartitions, numPartitions);
        List<Double> partitionWeights = new ArrayList<Double>(maxPartitions);
        for (int i = 0; i < sharedPartitions; ++i) {
            partitionWeights.add(partitionCounts.get(i) / ideal);
        }
        // Indices past the shared range are all >= numPartitions when shrinking (-1.0) and all
        // < numPartitions when growing (0.0).
        // NOTE(review): presumably HashingBalancedPartitioner treats -1 as "always move elements out" and
        // 0 as "empty partition to be filled" - confirm against that class.
        double extraWeight = origNumPartitions > numPartitions ? -1.0 : 0.0;
        for (int i = sharedPartitions; i < maxPartitions; ++i) {
            partitionWeights.add(extraWeight);
        }

        JavaPairRDD<Tuple2<Long, Integer>, T> indexedRDD = rdd.zipWithUniqueId().mapToPair(new PairFunction<Tuple2<T, Long>, Tuple2<Long, Integer>, T>() {
            @Override
            public Tuple2<Tuple2<Long, Integer>, T> call(Tuple2<T, Long> elementAndId) {
                Tuple2<Long, Integer> key = new Tuple2<Long, Integer>(elementAndId._2(), 0);
                return new Tuple2<Tuple2<Long, Integer>, T>(key, elementAndId._1());
            }
        });

        HashingBalancedPartitioner hbp = new HashingBalancedPartitioner(Collections.singletonList(partitionWeights));
        JavaPairRDD<Tuple2<Long, Integer>, T> partitionedRDD = indexedRDD.partitionBy(hbp);

        return partitionedRDD.map(new Function<Tuple2<Tuple2<Long, Integer>, T>, T>() {
            @Override
            public T call(Tuple2<Tuple2<Long, Integer>, T> indexNPayload) {
                return indexNPayload._2();
            }
        });
    }

    /**
     * Repartitions an RDD with a {@link BalancedPartitioner}, unless it already has the requested layout.
     * <p>
     * The elements of every partition are counted (this runs a Spark job and collects the counts). If the
     * RDD already has {@code numPartitions} partitions and each holds exactly {@code objectsPerPartition}
     * elements, the input is returned. Otherwise elements are keyed by their index and partitioned.
     *
     * @param rdd                 RDD to repartition
     * @param repartition         {@code Never} returns the input; {@code NumPartitionsWorkersDiffers}
     *                            returns the input when the partition count already equals
     *                            {@code numPartitions}
     * @param objectsPerPartition desired number of elements per partition
     * @param numPartitions       target number of partitions
     * @param <T>                 element type
     * @return the input RDD or a repartitioned RDD
     * @throws RuntimeException if {@code repartition} is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionBalanceIfRequired(JavaRDD<T> rdd, Repartition repartition, int objectsPerPartition, int numPartitions) {
        int origNumPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions) {
                    return rdd;
                }
                // Intentional fall-through: the partition count differs, so repartition as for Always.
            case Always:
                break;
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }

        // One tuple per partition; only the second component (the element count) is read here.
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<? extends Tuple2<?, Integer>> partitionCounts =
                rdd.mapPartitionsWithIndex(new CountPartitionsFunction(), true).collect();

        int totalObjects = 0;
        boolean allCorrectSize = true;
        for (Tuple2<?, Integer> partitionCount : partitionCounts) {
            int partitionSize = partitionCount._2();
            allCorrectSize &= partitionSize == objectsPerPartition;
            totalObjects += partitionSize;
        }

        if (partitionCounts.size() == numPartitions && allCorrectSize) {
            // Already in the requested layout.
            return rdd;
        }

        // Key every element by its index so the partitioner can place it.
        JavaPairRDD<Integer, T> pairIndexed = indexedRDD(rdd);
        int remainder = (totalObjects - numPartitions * objectsPerPartition) % numPartitions;
        pairIndexed = pairIndexed.partitionBy(new BalancedPartitioner(numPartitions, objectsPerPartition, remainder));
        return pairIndexed.values();
    }

    /**
     * Keys every element of an RDD by its {@code zipWithIndex} position.
     * <p>
     * The {@code Long} index is narrowed with {@code intValue()}, so indices beyond the {@code int} range
     * do not survive intact.
     *
     * @param rdd RDD to index
     * @param <T> element type
     * @return pair RDD of (index, element)
     */
    static <T> JavaPairRDD<Integer, T> indexedRDD(JavaRDD<T> rdd) {
        JavaPairRDD<T, Long> withPositions = rdd.zipWithIndex();
        return withPositions.mapToPair(new PairFunction<Tuple2<T, Long>, Integer, T>() {
            @Override
            public Tuple2<Integer, T> call(Tuple2<T, Long> elementAndPosition) {
                Long position = elementAndPosition._2();
                T element = elementAndPosition._1();
                return new Tuple2<Integer, T>(position.intValue(), element);
            }
        });
    }

    /**
     * Splits an RDD into random subsets using a freshly generated random seed; see
     * {@link #balancedRandomSplit(int, int, JavaRDD, long)}.
     *
     * @param totalObjectCount   total number of elements in {@code data}
     * @param numObjectsPerSplit desired number of elements per split
     * @param data               RDD to split
     * @param <T>                element type
     * @return array of splits
     */
    public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD<T> data) {
        long seed = new Random().nextLong();
        return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, seed);
    }

    /**
     * Splits an RDD into {@code totalObjectCount / numObjectsPerSplit} (rounded down) subsets, each produced
     * by a {@link SplitPartitionsFunction} over the partitions of {@code data}.
     * <p>
     * When {@code totalObjectCount <= numObjectsPerSplit} the result is a single-element array holding
     * {@code data} itself.
     *
     * @param totalObjectCount   total number of elements in {@code data}
     * @param numObjectsPerSplit desired number of elements per split
     * @param data               RDD to split
     * @param rngSeed            seed passed to every {@link SplitPartitionsFunction}
     * @param <T>                element type
     * @return array of splits
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD<T> data, long rngSeed) {
        if (totalObjectCount <= numObjectsPerSplit) {
            // Everything fits in one split: hand back the input as the only element.
            JavaRDD<T>[] single = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, 1);
            single[0] = data;
            return single;
        }

        int numSplits = totalObjectCount / numObjectsPerSplit;
        JavaRDD<T>[] result = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, numSplits);
        int splitIdx = 0;
        while (splitIdx < numSplits) {
            result[splitIdx] = data.mapPartitionsWithIndex(new SplitPartitionsFunction(splitIdx, numSplits, rngSeed), true);
            splitIdx++;
        }
        return result;
    }

    /**
     * Splits a pair RDD into random subsets using a freshly generated random seed; see
     * {@link #balancedRandomSplit(int, int, JavaPairRDD, long)}.
     *
     * @param totalObjectCount   total number of elements in {@code data}
     * @param numObjectsPerSplit desired number of elements per split
     * @param data               pair RDD to split
     * @param <T>                key type
     * @param <U>                value type
     * @return array of splits
     */
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaPairRDD<T, U> data) {
        long seed = new Random().nextLong();
        return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, seed);
    }

    /**
     * Splits a pair RDD into {@code totalObjectCount / numObjectsPerSplit} (rounded down) subsets. Each
     * subset is produced by a {@link SplitPartitionsFunction2} over the partitions of {@code data} and then
     * converted back to a pair RDD with {@link MapTupleToPairFlatMap}.
     * <p>
     * When {@code totalObjectCount <= numObjectsPerSplit} the result is a single-element array holding
     * {@code data} itself.
     *
     * @param totalObjectCount   total number of elements in {@code data}
     * @param numObjectsPerSplit desired number of elements per split
     * @param data               pair RDD to split
     * @param rngSeed            seed passed to every {@link SplitPartitionsFunction2}
     * @param <T>                key type
     * @param <U>                value type
     * @return array of splits
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaPairRDD<T, U> data, long rngSeed) {
        if (totalObjectCount <= numObjectsPerSplit) {
            // Everything fits in one split: hand back the input as the only element.
            JavaPairRDD<T, U>[] single = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, 1);
            single[0] = data;
            return single;
        }

        int numSplits = totalObjectCount / numObjectsPerSplit;
        JavaPairRDD<T, U>[] result = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, numSplits);
        int splitIdx = 0;
        while (splitIdx < numSplits) {
            JavaRDD tuples = data.mapPartitionsWithIndex(new SplitPartitionsFunction2(splitIdx, numSplits, rngSeed), true);
            result[splitIdx] = tuples.mapPartitionsToPair(new MapTupleToPairFlatMap(), true);
            splitIdx++;
        }
        return result;
    }

    /**
     * Lists the files directly under a path (non-recursive) and returns their paths as an RDD of strings.
     * <p>
     * The file system is resolved from the URI of {@code path} together with the Hadoop configuration of
     * the given Spark context, consistent with the other file helpers in this class, so file-system
     * settings made on the context are honored.
     *
     * @param sc   Spark context used for the Hadoop configuration and to parallelize the result
     * @param path directory to list; also parsed as a URI to select the file system
     * @return RDD with one string per file found
     * @throws IOException if the file system cannot be obtained or the listing fails
     */
    public static JavaRDD<String> listPaths(JavaSparkContext sc, String path) throws IOException {
        List<String> paths = new ArrayList<String>();
        FileSystem fileSystem = FileSystem.get(URI.create(path), sc.hadoopConfiguration());
        RemoteIterator<LocatedFileStatus> fileIter = fileSystem.listFiles(new Path(path), false);
        while (fileIter.hasNext()) {
            paths.add(fileIter.next().getPath().toString());
        }
        return sc.parallelize(paths);
    }

    /**
     * Redistributes the examples of an RDD of DataSets in three steps: a
     * {@link SplitDataSetExamplesPairFlatMapFunction} turns the input into a pair RDD, the pairs are
     * partitioned with a {@link HashPartitioner}, and a {@link BatchDataSetsFunction} is applied to the
     * values of each partition.
     *
     * @param rdd           DataSets to shuffle
     * @param newBatchSize  value passed to {@link BatchDataSetsFunction}
     * @param numPartitions number of partitions, passed to both the split function and the partitioner
     * @return RDD of DataSets produced by the batching function
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static JavaRDD<DataSet> shuffleExamples(JavaRDD<DataSet> rdd, int newBatchSize, int numPartitions) {
        PairFlatMapFunction splitter = (PairFlatMapFunction) new SplitDataSetExamplesPairFlatMapFunction(numPartitions);
        Partitioner partitioner = new HashPartitioner(numPartitions);
        FlatMapFunction batcher = (FlatMapFunction) new BatchDataSetsFunction(newBatchSize);

        JavaPairRDD keyedExamples = rdd.flatMapToPair(splitter);
        JavaPairRDD shuffled = keyedExamples.partitionBy(partitioner);
        return shuffled.values().mapPartitions(batcher);
    }
}

