/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer.scoring;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunctionAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Scores inputs by the reconstruction error of the {@link VariationalAutoencoder} that must be
 * layer 0 of a {@link MultiLayerNetwork} rebuilt from broadcast configuration and parameters.
 *
 * <p>The surrounding scoring flow (how inputs and keys reach {@link #computeScore}) is implemented
 * in {@link BaseVaeScoreWithKeyFunctionAdapter}; this class only supplies the VAE layer and the
 * per-input score.
 *
 * @param <K> key type forwarded to {@link BaseVaeScoreWithKeyFunctionAdapter} (presumably the key
 *     paired with each scored example — confirm in the base class)
 */
class VaeReconstructionErrorWithKeyFunctionAdapter<K>
extends BaseVaeScoreWithKeyFunctionAdapter<K> {

    /**
     * @param params     broadcast parameter array for the network; a copy is applied to each rebuilt
     *                   network
     * @param jsonConfig broadcast {@link MultiLayerConfiguration} serialized as JSON
     * @param batchSize  forwarded unchanged to the base adapter (presumably the scoring minibatch
     *                   size — confirm in {@link BaseVaeScoreWithKeyFunctionAdapter})
     */
    public VaeReconstructionErrorWithKeyFunctionAdapter(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        super(params, jsonConfig, batchSize);
    }

    /**
     * Rebuilds the network from the broadcast JSON configuration and parameters and returns its
     * first layer as a VAE.
     *
     * @return layer 0 of the rebuilt network, cast to {@link VariationalAutoencoder}
     * @throws IllegalStateException if the broadcast parameter count differs from the network's
     *     parameter count, or if layer 0 is not a {@link VariationalAutoencoder}
     */
    @Override
    public VariationalAutoencoder getVaeLayer() {
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) this.jsonConfig.value()));
        network.init();

        // Work on a copy so the network is not wired to the shared broadcast array itself.
        INDArray val = ((INDArray) this.params.value()).unsafeDuplication();
        long expected = network.numParams(false);
        long actual = val.length();
        if (actual != expected) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters"
                    + " (network: " + expected + ", broadcast: " + actual + ")");
        }
        network.setParameters(val);

        Layer l = network.getLayer(0);
        if (!(l instanceof VariationalAutoencoder)) {
            // Guard the message against a null layer so the intended diagnostic is not masked by an NPE.
            Object layerType = (l == null) ? "null" : l.getClass();
            throw new IllegalStateException("Cannot use VaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE layer as layer 0. Layer type: " + layerType);
        }
        return (VariationalAutoencoder) l;
    }

    /**
     * Scores {@code toScore} by VAE reconstruction error.
     *
     * @param vae     the VAE layer to score with
     * @param toScore the inputs to score
     * @return the value of {@link VariationalAutoencoder#reconstructionError(INDArray)} for
     *     {@code toScore}
     */
    @Override
    public INDArray computeScore(VariationalAutoencoder vae, INDArray toScore) {
        return vae.reconstructionError(toScore);
    }
}

