/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.broadcast.Broadcast;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class ScoreExamplesFunctionAdapter
implements FlatMapFunctionAdapter<Iterator<DataSet>, Double> {
    protected static Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final boolean addRegularization;
    private final int batchSize;

    public ScoreExamplesFunctionAdapter(Broadcast<INDArray> params, Broadcast<String> jsonConfig, boolean addRegularizationTerms, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.addRegularization = addRegularizationTerms;
        this.batchSize = batchSize;
    }

    public Iterable<Double> call(Iterator<DataSet> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyList();
        }
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)((String)this.jsonConfig.getValue())));
        network.init();
        INDArray val = ((INDArray)this.params.value()).unsafeDuplication();
        if (val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
        }
        network.setParameters(val);
        ArrayList<Double> ret = new ArrayList<Double>();
        ArrayList<DataSet> collect = new ArrayList<DataSet>(this.batchSize);
        int totalCount = 0;
        while (iterator.hasNext()) {
            double[] doubleScores;
            int nExamples;
            int n;
            collect.clear();
            for (nExamples = 0; iterator.hasNext() && nExamples < this.batchSize; nExamples += n) {
                DataSet ds = iterator.next();
                n = ds.numExamples();
                collect.add(ds);
            }
            totalCount += nExamples;
            DataSet data = DataSet.merge(collect);
            INDArray scores = network.scoreExamples(data, this.addRegularization);
            for (double doubleScore : doubleScores = scores.data().asDouble()) {
                ret.add(doubleScore);
            }
        }
        Nd4j.getExecutioner().commit();
        if (log.isDebugEnabled()) {
            log.debug("Scored {} examples ", (Object)totalCount);
        }
        return ret;
    }
}

