/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.layer;

import java.util.Objects;
import org.apache.spark.api.java.function.Function;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.LayerFactory;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.nn.layers.factory.LayerFactories;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;

public class DL4jWorker
implements Function<org.nd4j.linalg.dataset.DataSet, INDArray> {
    private final Model network;

    public DL4jWorker(String json, INDArray params) {
        NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson((String)json);
        LayerFactory layerFactory = LayerFactories.getFactory((org.deeplearning4j.nn.conf.layers.Layer)conf.getLayer());
        if (layerFactory == null) {
            throw new IllegalStateException("Please specify a layer factory");
        }
        this.network = layerFactory.create(conf);
        int numParams = this.network.numParams();
        if (numParams != params.length()) {
            throw new IllegalStateException("Number of params for configured network was " + numParams + " while the specified parameter vector length was " + params.length());
        }
        Layer network = (Layer)this.network;
        network.setParams(params);
    }

    public INDArray call(org.nd4j.linalg.dataset.DataSet v1) throws Exception {
        try {
            Layer network = (Layer)this.network;
            if (network instanceof OutputLayer) {
                OutputLayer o = (OutputLayer)network;
                o.fit((DataSet)v1);
            } else {
                network.fit(v1.getFeatureMatrix());
            }
            return network.params();
        }
        catch (Exception e) {
            System.err.println("Error with dataset " + v1.numExamples());
            throw e;
        }
    }
}

