/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.layer;

import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.datavec.api.records.reader.RecordReader;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.spark.datavec.RecordReaderFunction;
import org.deeplearning4j.spark.impl.common.Add;
import org.deeplearning4j.spark.impl.layer.IterativeReduceFlatMap;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import parquet.org.slf4j.Logger;
import parquet.org.slf4j.LoggerFactory;

/**
 * Trains a single DL4J {@link Layer} on Spark using parameter averaging.
 *
 * <p>Each training round draws a sample (with replacement) of the input {@link DataSet} RDD, runs
 * {@link IterativeReduceFlatMap} over every partition of that sample starting from broadcast
 * parameters, and averages the parameter arrays it emits on the driver. Two modes exist:
 *
 * <ul>
 *   <li>single pass (default): one round, using the configuration's iteration count;
 *   <li>average each iteration ({@link #setAverageEachIteration(boolean)}): one round per
 *       configured iteration, with the configuration's iteration count temporarily set to 1 and
 *       the averaged parameters re-broadcast as the starting point of the next round.
 * </ul>
 *
 * <p>The Spark contexts are {@code transient}, so a deserialized instance cannot train again.
 */
public class SparkDl4jLayer
implements Serializable {
    /** Fraction of the input sampled (with replacement) for the single-pass mode. */
    private static final double SINGLE_PASS_SAMPLE_FRACTION = 0.4;
    /** Fraction of the input sampled (with replacement) for each round of the per-iteration mode. */
    private static final double PER_ITERATION_SAMPLE_FRACTION = 0.3;
    private static final Logger log = LoggerFactory.getLogger(SparkDl4jLayer.class);

    private transient SparkContext sparkContext;
    private transient JavaSparkContext sc;
    /** Private copy of the caller's configuration. */
    private NeuralNetConfiguration conf;
    /** The trained layer; {@code null} until a fit method completes. */
    private Layer layer;
    /** Parameters most recently broadcast to the executors. */
    private Broadcast<INDArray> params;
    private boolean averageEachIteration = false;

    /**
     * Creates a trainer bound to a Scala {@link SparkContext}.
     *
     * @param sparkContext the Spark context; wrapped in a {@link JavaSparkContext}
     * @param conf         layer configuration; cloned so later caller changes do not affect training
     */
    public SparkDl4jLayer(SparkContext sparkContext, NeuralNetConfiguration conf) {
        this.sparkContext = sparkContext;
        this.conf = conf.clone();
        this.sc = new JavaSparkContext(this.sparkContext);
    }

    /**
     * Creates a trainer bound to a {@link JavaSparkContext}.
     *
     * @param sc   the Java Spark context
     * @param conf layer configuration; cloned so later caller changes do not affect training
     */
    public SparkDl4jLayer(JavaSparkContext sc, NeuralNetConfiguration conf) {
        this.sc = sc;
        this.conf = conf.clone();
    }

    /** @return whether parameters are averaged and re-broadcast after every iteration */
    public boolean isAverageEachIteration() {
        return this.averageEachIteration;
    }

    /**
     * Selects the training mode used by {@link #fitDataSet(JavaRDD)}.
     *
     * @param averageEachIteration {@code true} to average and re-broadcast parameters after every
     *                             iteration, {@code false} (the default) for a single averaging pass
     */
    public void setAverageEachIteration(boolean averageEachIteration) {
        this.averageEachIteration = averageEachIteration;
    }

    /**
     * Trains on text records read from {@code path}.
     *
     * <p>The configured layer must be a {@link FeedForwardLayer}; its {@code nOut} is forwarded to
     * {@link RecordReaderFunction} (presumably as the number of label classes — confirm there).
     *
     * @param path         location read with {@link JavaSparkContext#textFile(String)}
     * @param labelIndex   label index forwarded to {@link RecordReaderFunction}
     * @param recordReader reader used by {@link RecordReaderFunction} to parse each line
     * @return the trained layer
     * @throws ClassCastException if the configured layer is not a {@link FeedForwardLayer}
     */
    // RecordReaderFunction's generic signature is not visible here; keep the raw cast.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Layer fit(String path, int labelIndex, RecordReader recordReader) {
        FeedForwardLayer ffLayer = (FeedForwardLayer) this.conf.getLayer();
        JavaRDD<String> lines = this.sc.textFile(path);
        JavaRDD<DataSet> points =
                lines.map((Function) new RecordReaderFunction(recordReader, labelIndex, ffLayer.getNOut()));
        return this.fitDataSet(points);
    }

    /**
     * Trains on MLlib labeled points.
     *
     * @param sc  context used to convert the points into {@link DataSet}s
     * @param rdd the labeled training points
     * @return the trained layer
     * @throws ClassCastException if the configured layer is not a {@link FeedForwardLayer}
     */
    public Layer fit(JavaSparkContext sc, JavaRDD<LabeledPoint> rdd) {
        FeedForwardLayer ffLayer = (FeedForwardLayer) this.conf.getLayer();
        return this.fitDataSet(MLLibUtil.fromLabeledPoint(sc, rdd, ffLayer.getNOut()));
    }

    /**
     * Trains the layer on {@code rdd} using the mode chosen by
     * {@link #setAverageEachIteration(boolean)}.
     *
     * @param rdd the training data
     * @return the trained layer (also retained for {@code predict})
     * @throws IllegalStateException if the instantiated layer's parameter count does not match the
     *                               configuration's
     */
    public Layer fitDataSet(JavaRDD<DataSet> rdd) {
        int iterations = this.conf.getNumIterations();
        log.info("Running distributed training averaging each iteration {} and {} partitions",
                this.averageEachIteration, rdd.partitions().size());
        if (this.averageEachIteration) {
            fitAveragingEachIteration(rdd, iterations);
        } else {
            fitSinglePass(rdd);
        }
        return this.layer;
    }

    /** One distributed round using the configured iteration count, then a single average. */
    @SuppressWarnings({"unchecked", "rawtypes"}) // IterativeReduceFlatMap's generic signature is not visible here
    private void fitSinglePass(JavaRDD<DataSet> rdd) {
        int numParams = configuredNumParams();
        INDArray initialParams = Nd4j.create(1, numParams);
        Layer newLayer = instantiateLayer(initialParams, numParams);
        this.params = this.sc.broadcast(initialParams);
        log.info("Broadcasting initial parameters of length " + initialParams.length());
        JavaRDD<INDArray> results = rdd.sample(true, SINGLE_PASS_SAMPLE_FRACTION)
                .mapPartitions((FlatMapFunction) new IterativeReduceFlatMap(this.conf.toJson(), this.params));
        log.debug("Ran iterative reduce...averaging results now.");
        newLayer.setParams(averageResults(results));
        this.layer = newLayer;
    }

    /**
     * One distributed round per configured iteration, each running a single local iteration and
     * re-broadcasting the averaged parameters so the next round continues from them.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // IterativeReduceFlatMap's generic signature is not visible here
    private void fitAveragingEachIteration(JavaRDD<DataSet> rdd, int iterations) {
        this.conf.setNumIterations(1);
        try {
            int numParams = configuredNumParams();
            INDArray initialParams = Nd4j.create(1, numParams);
            Layer newLayer = instantiateLayer(initialParams, numParams);
            // conf does not change inside the loop, so serialize it once.
            String confJson = this.conf.toJson();
            this.params = this.sc.broadcast(initialParams);
            for (int i = 0; i < iterations; ++i) {
                JavaRDD<INDArray> results = rdd.sample(true, PER_ITERATION_SAMPLE_FRACTION)
                        .mapPartitions((FlatMapFunction) new IterativeReduceFlatMap(confJson, this.params));
                // Feed this round's average into the next round instead of discarding it.
                this.params = this.sc.broadcast(averageResults(results));
            }
            newLayer.setParams(this.params.value().dup());
            this.layer = newLayer;
        } finally {
            // Do not leave the shared configuration forced to a single iteration.
            this.conf.setNumIterations(iterations);
        }
    }

    /** Number of parameters the configured layer's initializer reports. */
    private int configuredNumParams() {
        return this.conf.getLayer().initializer().numParams(this.conf);
    }

    /**
     * Instantiates the configured layer over {@code paramsView} with a fresh gradient view array.
     *
     * @throws IllegalStateException if the layer's parameter count differs from the view's length
     */
    private Layer instantiateLayer(INDArray paramsView, int numParams) {
        Layer newLayer = this.conf.getLayer().instantiate(this.conf, null, 0, paramsView, true);
        newLayer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
        int paramsLength = newLayer.numParams();
        if (paramsView.length() != paramsLength) {
            throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + paramsView.length());
        }
        return newLayer;
    }

    /**
     * Element-wise mean of the parameter arrays in {@code results}.
     *
     * <p>{@code results} is lazy and is traversed more than once here, so it is cached while it is
     * consumed. Otherwise each traversal would re-run the distributed training that produces it.
     * Fails (via {@code first()}) if no arrays were produced.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // Add's generic signature is not visible here
    private static INDArray averageResults(JavaRDD<INDArray> results) {
        results.cache();
        try {
            INDArray first = results.first();
            long numResults = results.count();
            INDArray sum = (INDArray) results.fold(Nd4j.zeros(first.shape()), (Function2) new Add());
            return sum.divi(numResults);
        } finally {
            results.unpersist();
        }
    }

    /**
     * Runs the trained layer on a matrix of features.
     *
     * @throws IllegalStateException if no fit method has completed yet
     */
    public Matrix predict(Matrix features) {
        return MLLibUtil.toMatrix(requireFittedLayer().activate(MLLibUtil.toMatrix(features)));
    }

    /**
     * Runs the trained layer on a single feature vector.
     *
     * @throws IllegalStateException if no fit method has completed yet
     */
    public Vector predict(Vector point) {
        return MLLibUtil.toVector(requireFittedLayer().activate(MLLibUtil.toVector(point)));
    }

    /** Returns the trained layer, failing with a clear message instead of an NPE before fitting. */
    private Layer requireFittedLayer() {
        if (this.layer == null) {
            throw new IllegalStateException("Layer has not been trained yet; call fit(...) or fitDataSet(...) first");
        }
        return this.layer;
    }

    /**
     * Convenience entry point: trains a layer described by {@code conf} on {@code data}.
     *
     * @param data labeled training points
     * @param conf layer configuration (cloned; not mutated)
     * @return the trained layer
     */
    public static Layer train(JavaRDD<LabeledPoint> data, NeuralNetConfiguration conf) {
        SparkDl4jLayer sparkLayer = new SparkDl4jLayer(data.context(), conf);
        // Reuse the JavaSparkContext the constructor already created instead of wrapping again.
        return sparkLayer.fit(sparkLayer.sc, data);
    }
}

