/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.ml.reconstruction;

import org.apache.spark.SparkContext;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.ml.UnsupervisedLearner;
import org.deeplearning4j.spark.ml.UnsupervisedLearnerParams$class;
import org.deeplearning4j.spark.ml.nn.ParameterAveragingTrainingStrategy;
import org.deeplearning4j.spark.ml.param.shared.HasEpochs$class;
import org.deeplearning4j.spark.ml.param.shared.HasLayerIndex$class;
import org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration$class;
import org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol$class;
import org.deeplearning4j.spark.ml.reconstruction.NeuralNetworkReconstructionModel;
import org.deeplearning4j.spark.ml.reconstruction.NeuralNetworkReconstructionParams;
import org.deeplearning4j.spark.ml.reconstruction.NeuralNetworkReconstructionParams$class;
import org.deeplearning4j.spark.ml.util.Identifiable$;
import org.deeplearning4j.spark.util.package$conversions$;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Spark ML estimator that fits a DL4J {@link MultiLayerNetwork} on a column of {@link Vector}
 * features (no label column is read) and returns a {@link NeuralNetworkReconstructionModel}.
 *
 * <p>This file was produced by the CFR decompiler from Scala bytecode (see the file header), so
 * several constructs are decompiler artifacts rather than hand-written Java: the
 * {@code org$deeplearning4j$...} methods and the {@code *$class} helpers are Scala trait
 * plumbing, and the {@code final} fields below are assigned from trait setter methods rather
 * than from the constructor, which javac would reject. NOTE(review): treat this as a reference
 * view of the original Scala class, not as compilable source.
 *
 * <p>Defaults set in the constructor: {@code epochs} = 1, {@code layerIndex} = 1,
 * {@code reconstructionCol} = "reconstruction". No default for {@code conf} is set here;
 * NOTE(review): presumably it must be supplied via {@code setConf} before fitting — confirm
 * whether a trait initializer provides one.
 */
@DeveloperApi
@ScalaSignature(bytes="\u0006\u000114A!\u0001\u0002\u0001\u001b\tYb*Z;sC2tU\r^<pe.\u0014VmY8ogR\u0014Xo\u0019;j_:T!a\u0001\u0003\u0002\u001dI,7m\u001c8tiJ,8\r^5p]*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011A\u00043fKBdW-\u0019:oS:<GG\u001b\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M\u0019\u0001A\u0004\u0012\u0011\u000b=\u0001\"#H\u0010\u000e\u0003\u0011I!!\u0005\u0003\u0003'Us7/\u001e9feZL7/\u001a3MK\u0006\u0014h.\u001a:\u0011\u0005MYR\"\u0001\u000b\u000b\u0005U1\u0012A\u00027j]\u0006dwM\u0003\u0002\u00181\u0005)Q\u000e\u001c7jE*\u0011q!\u0007\u0006\u00035)\ta!\u00199bG\",\u0017B\u0001\u000f\u0015\u0005\u00191Vm\u0019;peB\u0011a\u0004A\u0007\u0002\u0005A\u0011a\u0004I\u0005\u0003C\t\u0011\u0001ET3ve\u0006dg*\u001a;x_J\\'+Z2p]N$(/^2uS>tWj\u001c3fYB\u0011adI\u0005\u0003I\t\u0011\u0011ET3ve\u0006dg*\u001a;x_J\\'+Z2p]N$(/^2uS>t\u0007+\u0019:b[ND\u0001B\n\u0001\u0003\u0006\u0004%\teJ\u0001\u0004k&$W#\u0001\u0015\u0011\u0005%zcB\u0001\u0016.\u001b\u0005Y#\"\u0001\u0017\u0002\u000bM\u001c\u0017\r\\1\n\u00059Z\u0013A\u0002)sK\u0012,g-\u0003\u00021c\t11\u000b\u001e:j]\u001eT!AL\u0016\t\u0011M\u0002!\u0011!Q\u0001\n!\nA!^5eA!)Q\u0007\u0001C\u0001m\u00051A(\u001b8jiz\"\"!H\u001c\t\u000b\u0019\"\u0004\u0019\u0001\u0015\t\u000bU\u0002A\u0011A\u001d\u0015\u0003uAQa\u000f\u0001\u0005\u0002q\nqa]3u\u0007>tg\r\u0006\u0002>}5\t\u0001\u0001C\u0003@u\u0001\u0007\u0001&A\u0003wC2,X\rC\u0003<\u0001\u0011\u0005\u0011\t\u0006\u0002>\u0005\")q\b\u0011a\u0001\u0007B\u0011A)S\u0007\u0002\u000b*\u0011aiR\u0001\u0005G>tgM\u0003\u0002I\u0011\u0005\u0011aN\\\u0005\u0003\u0015\u0016\u0013q#T;mi&d\u0015-_3s\u0007>tg-[4ve\u0006$\u0018n\u001c8\t\u000b1\u0003A\u0011A'\u0002\u0013M,G/\u00129pG\"\u001cHCA\u001fO\u0011\u0015y4\n1\u0001P!\tQ\u0003+\u0003\u0002RW\t\u0019\u0011J\u001c;\t\u000bM\u0003A\u0011\u0001+\u0002\u001bM,G\u000fT1zKJLe\u000eZ3y)\tiT\u000bC\u0003@%\u0002\u0007q\nC\u0003X\u0001\u0011\u0005\u0001,\u0001\u000btKR\u0014VmY8ogR\u0014Xo\u0019;j_:
\u001cu\u000e\u001c\u000b\u0003{eCQa\u0010,A\u0002!BQa\u0017\u0001\u0005Rq\u000bQ\u0001\\3be:$\"aH/\t\u000byS\u0006\u0019A0\u0002\u000f\u0011\fG/Y:fiB\u0011\u0001mY\u0007\u0002C*\u0011!\rG\u0001\u0004gFd\u0017B\u00013b\u0005%!\u0015\r^1Ge\u0006lW\r\u000b\u0002\u0001MB\u0011qM[\u0007\u0002Q*\u0011\u0011\u000eG\u0001\u000bC:tw\u000e^1uS>t\u0017BA6i\u00051!UM^3m_B,'/\u00119j\u0001")
public class NeuralNetworkReconstruction
extends UnsupervisedLearner<Vector, NeuralNetworkReconstruction, NeuralNetworkReconstructionModel>
implements NeuralNetworkReconstructionParams {
    /** Unique identifier of this estimator; passed on unchanged to the fitted model. */
    private final String uid;
    // Param fields declared by the mixed-in Scala traits. They are populated by the
    // *_setter_* methods below, presumably invoked from the trait $init$ calls in the constructor.
    private final Param<String> reconstructionCol;
    private final IntParam layerIndex;
    private final IntParam epochs;
    /** Network configuration, stored as a MultiLayerConfiguration JSON string. */
    private final Param<String> conf;

    /**
     * Scala "super" accessor: lets the NeuralNetworkReconstructionParams trait call the
     * UnsupervisedLearnerParams implementation of validateAndTransformSchema.
     */
    @Override
    public StructType org$deeplearning4j$spark$ml$reconstruction$NeuralNetworkReconstructionParams$$super$validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) {
        return UnsupervisedLearnerParams$class.validateAndTransformSchema(this, schema, fitting, featuresDataType);
    }

    /** Delegates schema validation/transformation to the NeuralNetworkReconstructionParams trait implementation. */
    @Override
    public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) {
        return NeuralNetworkReconstructionParams$class.validateAndTransformSchema(this, schema, fitting, featuresDataType);
    }

    /** @return the param naming the output (reconstruction) column. */
    @Override
    public Param<String> reconstructionCol() {
        return this.reconstructionCol;
    }

    /** Scala trait field setter used to initialize {@code reconstructionCol} (decompiler artifact). */
    @Override
    public void org$deeplearning4j$spark$ml$param$shared$HasReconstructionCol$_setter_$reconstructionCol_$eq(Param x$1) {
        this.reconstructionCol = x$1;
    }

    /** @return the current value of {@code reconstructionCol}, via the HasReconstructionCol trait. */
    @Override
    public String getReconstructionCol() {
        return HasReconstructionCol$class.getReconstructionCol(this);
    }

    /** @return the layer-index param (default 1). */
    @Override
    public IntParam layerIndex() {
        return this.layerIndex;
    }

    /** Scala trait field setter used to initialize {@code layerIndex} (decompiler artifact). */
    @Override
    public void org$deeplearning4j$spark$ml$param$shared$HasLayerIndex$_setter_$layerIndex_$eq(IntParam x$1) {
        this.layerIndex = x$1;
    }

    /** @return the current value of {@code layerIndex}, via the HasLayerIndex trait. */
    @Override
    public int getLayerIndex() {
        return HasLayerIndex$class.getLayerIndex(this);
    }

    /** @return the epochs param (default 1), passed to the training strategy in {@code learn}. */
    @Override
    public IntParam epochs() {
        return this.epochs;
    }

    /** Scala trait field setter used to initialize {@code epochs} (decompiler artifact). */
    @Override
    public void org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(IntParam x$1) {
        this.epochs = x$1;
    }

    /** @return the current value of {@code epochs}, via the HasEpochs trait. */
    @Override
    public int getEpochs() {
        return HasEpochs$class.getEpochs(this);
    }

    /** @return the param holding the network configuration as a JSON string. */
    @Override
    public Param<String> conf() {
        return this.conf;
    }

    /** Scala trait field setter used to initialize {@code conf} (decompiler artifact). */
    @Override
    public void org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(Param x$1) {
        this.conf = x$1;
    }

    /** @return the current configuration JSON, via the HasMultiLayerConfiguration trait. */
    @Override
    public String getConf() {
        return HasMultiLayerConfiguration$class.getConf(this);
    }

    /** @return this estimator's unique identifier. */
    public String uid() {
        return this.uid;
    }

    /**
     * Sets the network configuration from a JSON string.
     *
     * @param value MultiLayerConfiguration JSON; it is parsed in {@code learn} via {@code fromJson}
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setConf(String value) {
        return (NeuralNetworkReconstruction)this.set(this.conf(), value);
    }

    /**
     * Sets the network configuration; it is stored in the param as JSON via {@code toJson()}.
     *
     * @param value configuration to serialize
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setConf(MultiLayerConfiguration value) {
        return (NeuralNetworkReconstruction)this.set(this.conf(), value.toJson());
    }

    /**
     * @param value number of epochs passed to ParameterAveragingTrainingStrategy
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setEpochs(int value) {
        return (NeuralNetworkReconstruction)this.set((Param)this.epochs(), BoxesRunTime.boxToInteger((int)value));
    }

    /**
     * Sets the layer index param. It is not read by {@code learn}; NOTE(review): presumably
     * consumed by the Params trait or the fitted model — confirm.
     *
     * @param value layer index
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setLayerIndex(int value) {
        return (NeuralNetworkReconstruction)this.set((Param)this.layerIndex(), BoxesRunTime.boxToInteger((int)value));
    }

    /**
     * @param value name of the output (reconstruction) column
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setReconstructionCol(String value) {
        return (NeuralNetworkReconstruction)this.set(this.reconstructionCol(), value);
    }

    /**
     * Trains the network on the features column of {@code dataset} and wraps the resulting
     * parameters in a model.
     *
     * <p>Steps:
     * <ol>
     *   <li>parse {@code conf} back into a MultiLayerConfiguration;</li>
     *   <li>select only the features column;</li>
     *   <li>persist it (MEMORY_AND_DISK) if the input's storage level is NONE;</li>
     *   <li>run ParameterAveragingTrainingStrategy for {@code epochs} epochs, where each partition
     *       stacks its feature vectors into one matrix and calls {@code network.fit};</li>
     *   <li>unpersist;</li>
     *   <li>broadcast the trained parameter array.</li>
     * </ol>
     *
     * @param dataset input DataFrame containing the features column
     * @return a model sharing this estimator's uid and holding the broadcast parameters
     */
    @Override
    public NeuralNetworkReconstructionModel learn(DataFrame dataset) {
        SQLContext sqlContext = dataset.sqlContext();
        SparkContext sc = sqlContext.sparkContext();
        MultiLayerConfiguration c = MultiLayerConfiguration.fromJson((String)((String)this.$(this.conf())));
        // Keep only the features column, so each Row below has it at index 0.
        DataFrame prepared = dataset.select((String)this.$(this.featuresCol()), (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0]));
        // handlePersistence == (input storage level equals StorageLevel.NONE), null-safe.
        // NOTE(review): DataFrame.rdd() may report NONE even when the DataFrame itself is cached,
        // which would cause a second copy to be persisted — confirm for the Spark version in use.
        StorageLevel storageLevel = dataset.rdd().getStorageLevel();
        StorageLevel storageLevel2 = StorageLevel$.MODULE$.NONE();
        boolean handlePersistence = !(storageLevel != null ? !storageLevel.equals(storageLevel2) : storageLevel2 != null);
        // Result unused: decompiler rendering of a Scala `if` evaluated only for its side effect.
        Object object = handlePersistence ? prepared.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()) : BoxedUnit.UNIT;
        ParameterAveragingTrainingStrategy trainingStrategy = new ParameterAveragingTrainingStrategy(c, BoxesRunTime.unboxToInt((Object)this.$((Param)this.epochs())));
        // The anonymous classes below are the decompiler's view of Scala closures.
        INDArray networkParams = trainingStrategy.train(prepared.rdd(), new Serializable(this){
            public static final long serialVersionUID = 0L;

            /** Fits {@code network} on all feature rows supplied by {@code rows}. */
            public final void apply(MultiLayerNetwork network, Iterator<Row> rows) {
                // Convert each row's column-0 Vector into an INDArray.
                INDArray[] featureArrays = (INDArray[])rows.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final INDArray apply(Row row) {
                        return package$conversions$.MODULE$.toINDArray((Vector)row.getAs(0));
                    }
                }).toArray(ClassTag$.MODULE$.apply(INDArray.class));
                // Skip empty input so vstack/fit are never called with zero rows.
                if (featureArrays.length >= 1) {
                    INDArray featureMatrix = Nd4j.vstack((INDArray[])featureArrays);
                    network.fit(featureMatrix);
                }
            }
        });
        // NOTE(review): not in a finally block — if train(...) throws, `prepared` stays persisted.
        Object object2 = handlePersistence ? prepared.unpersist() : BoxedUnit.UNIT;
        return new NeuralNetworkReconstructionModel(this.uid(), (Broadcast<INDArray>)sc.broadcast((Object)networkParams, ClassTag$.MODULE$.apply(INDArray.class)));
    }

    /**
     * Creates the estimator with an explicit uid. Runs the Scala trait initializers, which
     * presumably populate the param fields through the *_setter_* methods above, then sets defaults:
     * epochs = 1, layerIndex = 1, reconstructionCol = "reconstruction".
     *
     * @param uid unique identifier for this estimator
     */
    public NeuralNetworkReconstruction(String uid) {
        this.uid = uid;
        HasMultiLayerConfiguration$class.$init$(this);
        HasEpochs$class.$init$(this);
        HasLayerIndex$class.$init$(this);
        HasReconstructionCol$class.$init$(this);
        NeuralNetworkReconstructionParams$class.$init$(this);
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.epochs().$minus$greater((Object)BoxesRunTime.boxToInteger((int)1))}));
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.layerIndex().$minus$greater((Object)BoxesRunTime.boxToInteger((int)1))}));
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.reconstructionCol().$minus$greater((Object)"reconstruction")}));
    }

    /** Creates the estimator with a random uid prefixed "nnReconstruction". */
    public NeuralNetworkReconstruction() {
        this(Identifiable$.MODULE$.randomUID("nnReconstruction"));
    }
}

