/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer;

import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.canova.api.records.reader.RecordReader;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.canova.RecordReaderFunction;
import org.deeplearning4j.spark.impl.common.Adder;
import org.deeplearning4j.spark.impl.multilayer.IterativeReduceFlatMap;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a {@link MultiLayerNetwork} on Spark by running {@link IterativeReduceFlatMap} on every
 * partition and averaging the resulting parameter vectors on the driver.
 *
 * <p>When the Spark configuration property {@link #AVERAGE_EACH_ITERATION} is {@code true}, each
 * layer is temporarily set to a single iteration. The fit then runs that many
 * map-partitions/average rounds, and each round starts from the previous round's averaged
 * parameters.
 */
public class SparkDl4jMultiLayer implements Serializable {
    private transient SparkContext sparkContext;
    private transient JavaSparkContext sc;
    private MultiLayerConfiguration conf;
    private MultiLayerNetwork network;
    private Broadcast<INDArray> params;
    private boolean averageEachIteration = false;
    public static final String AVERAGE_EACH_ITERATION = "org.deeplearning4j.spark.iteration.average";
    private static final Logger log = LoggerFactory.getLogger(SparkDl4jMultiLayer.class);

    /**
     * Creates a trainer seeded with an existing network. Training starts from the network's current
     * parameters (when they match the configuration's parameter count).
     *
     * @param sparkContext the Spark context used for training
     * @param network      the network whose configuration (cloned) and parameters seed training
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork network) {
        this.sparkContext = sparkContext;
        this.averageEachIteration = sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.network = network;
        this.conf = this.network.getLayerWiseConfigurations().clone();
        this.sc = new JavaSparkContext(this.sparkContext);
        this.params = this.sc.broadcast(network.params());
    }

    /**
     * Creates a trainer from a configuration only. No network exists until {@code fit*} has run,
     * so {@link #predict} requires a prior fit or {@link #setNetwork}.
     *
     * @param sparkContext the Spark context used for training
     * @param conf         the network configuration (cloned)
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration conf) {
        this.sparkContext = sparkContext;
        this.conf = conf.clone();
        this.averageEachIteration = sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.sc = new JavaSparkContext(this.sparkContext);
    }

    /**
     * Convenience constructor taking a {@link JavaSparkContext}.
     *
     * @param sc   the Java Spark context; its underlying {@link SparkContext} is used
     * @param conf the network configuration (cloned)
     */
    public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf) {
        this(sc.sc(), conf);
    }

    /**
     * Trains on a text file. Each line is converted to a {@link DataSet} by
     * {@link RecordReaderFunction}.
     *
     * @param path         the path passed to {@code textFile}
     * @param labelIndex   the column index of the label, passed to {@link RecordReaderFunction}
     * @param recordReader the reader used to parse each line
     * @return the trained network
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public MultiLayerNetwork fit(String path, int labelIndex, RecordReader recordReader) {
        JavaRDD<String> lines = this.sc.textFile(path);
        // Raw Function: the decompiled source does not expose RecordReaderFunction's type arguments.
        JavaRDD points = lines.map((Function) new RecordReaderFunction(recordReader, labelIndex, this.outputLayerNOut()));
        return this.fitDataSet((JavaRDD<DataSet>) points);
    }

    /** @return the most recently trained (or explicitly set) network; may be {@code null} */
    public MultiLayerNetwork getNetwork() {
        return this.network;
    }

    /**
     * Replaces the current network. If its parameter count matches the configuration, the next
     * fit starts from its parameters.
     *
     * @param network the network to use
     */
    public void setNetwork(MultiLayerNetwork network) {
        this.network = network;
    }

    /**
     * Runs the current network on a feature matrix.
     *
     * @param features the input features
     * @return the network output as an MLlib matrix
     * @throws IllegalStateException if no network has been trained or set yet
     */
    public Matrix predict(Matrix features) {
        return MLLibUtil.toMatrix(this.requireNetwork().output(MLLibUtil.toMatrix(features)));
    }

    /**
     * Runs the current network on a single feature vector.
     *
     * @param point the input features
     * @return the network output as an MLlib vector
     * @throws IllegalStateException if no network has been trained or set yet
     */
    public Vector predict(Vector point) {
        return MLLibUtil.toVector(this.requireNetwork().output(MLLibUtil.toVector(point)));
    }

    /**
     * Trains on labeled points, grouped into mini-batches by {@link MLLibUtil#fromLabeledPoint}.
     *
     * @param rdd       the training data
     * @param batchSize the mini-batch size
     * @return the trained network
     */
    public MultiLayerNetwork fit(JavaRDD<LabeledPoint> rdd, int batchSize) {
        return this.fitDataSet(MLLibUtil.fromLabeledPoint(rdd, this.outputLayerNOut(), batchSize));
    }

    /**
     * Trains on labeled points using the given context for the conversion.
     *
     * @param sc  the context passed to {@link MLLibUtil#fromLabeledPoint}
     * @param rdd the training data
     * @return the trained network
     */
    public MultiLayerNetwork fit(JavaSparkContext sc, JavaRDD<LabeledPoint> rdd) {
        return this.fitDataSet(MLLibUtil.fromLabeledPoint(sc, rdd, this.outputLayerNOut()));
    }

    /**
     * Trains on an RDD of data sets.
     *
     * <p>Without per-iteration averaging, a single distributed pass runs, using the configured
     * iteration count on each worker. With averaging, every layer is temporarily set to one
     * iteration. The pass then repeats once per configured iteration of layer 0, and each pass
     * starts from the previous averaged parameters. The original iteration counts are restored
     * afterwards, even on failure.
     *
     * @param rdd the training data
     * @return the trained network
     */
    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> rdd) {
        int iterations = this.conf.getConf(0).getNumIterations();
        log.info("Running distributed training averaging each iteration {} and {} partitions", this.averageEachIteration, rdd.partitions().size());
        if (!this.averageEachIteration) {
            this.runIteration(rdd);
            return this.network;
        }
        int numConfs = this.conf.getConfs().size();
        int[] savedIterations = new int[numConfs];
        for (int i = 0; i < numConfs; ++i) {
            NeuralNetConfiguration layerConf = this.conf.getConf(i);
            savedIterations[i] = layerConf.getNumIterations();
            layerConf.setNumIterations(1);
        }
        try {
            for (int i = 0; i < iterations; ++i) {
                this.runIteration(rdd);
            }
        } finally {
            // Restore the configuration so later fits see the caller's iteration counts.
            for (int i = 0; i < numConfs; ++i) {
                this.conf.getConf(i).setNumIterations(savedIterations[i]);
            }
        }
        return this.network;
    }

    /**
     * Runs one distributed round. It broadcasts the starting parameters, trains on each
     * partition, sums the per-partition results with an {@link Adder} accumulator, divides by the
     * partition count, and stores the resulting network in {@code this.network}.
     *
     * @param rdd the training data
     * @throws IllegalStateException if the parameter vector length disagrees with the network
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void runIteration(JavaRDD<DataSet> rdd) {
        MultiLayerNetwork network = new MultiLayerNetwork(this.conf);
        network.init();
        int paramsLength = network.numParams();
        // Continue from the current network instead of a fresh random init. Without this,
        // per-iteration averaging would discard every previous round's result.
        if (this.network != null) {
            INDArray current = this.network.params();
            if (current.length() == paramsLength) {
                network.setParameters(current.dup());
            } else {
                log.warn("Existing network has {} params but configuration expects {}; starting from a fresh initialization", current.length(), paramsLength);
            }
        }
        INDArray params = network.params();
        if (this.params != null) {
            // The previous round's broadcast is no longer needed on executors.
            this.params.unpersist();
        }
        this.params = this.sc.broadcast(params);
        log.info("Broadcasting initial parameters of length {}", params.length());
        if (params.length() != paramsLength) {
            throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + params.length());
        }
        // No cache(): the results are consumed by exactly one action below.
        JavaRDD results = rdd.mapPartitions((FlatMapFunction) new IterativeReduceFlatMap(this.conf.toJson(), this.params), true);
        log.info("Ran iterative reduce...averaging results now.");
        Adder adder = new Adder(params.length());
        results.foreach((VoidFunction) adder);
        INDArray newParams = (INDArray) adder.getAccumulator().value();
        log.info("Accumulated parameters");
        int numPartitions = rdd.partitions().size();
        newParams.divi(numPartitions);
        log.info("Divided by partitions");
        network.setParameters(newParams);
        log.info("Set parameters");
        this.network = network;
    }

    /** @return {@code nOut} of the last layer, which is assumed to be a {@link FeedForwardLayer} */
    private int outputLayerNOut() {
        FeedForwardLayer outputLayer = (FeedForwardLayer) this.conf.getConf(this.conf.getConfs().size() - 1).getLayer();
        return outputLayer.getNOut();
    }

    /** @return the current network, failing clearly when none has been trained or set */
    private MultiLayerNetwork requireNetwork() {
        if (this.network == null) {
            throw new IllegalStateException("No network available: call fit(...) or setNetwork(...) before predict(...)");
        }
        return this.network;
    }

    /**
     * One-shot helper that builds a trainer on the RDD's context and fits the data.
     *
     * @param data the training data
     * @param conf the network configuration
     * @return the trained network
     */
    public static MultiLayerNetwork train(JavaRDD<LabeledPoint> data, MultiLayerConfiguration conf) {
        SparkDl4jMultiLayer multiLayer = new SparkDl4jMultiLayer(data.context(), conf);
        return multiLayer.fit(multiLayer.sc, data);
    }
}

