/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc;

import java.util.Objects;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;

/**
 * {@link ScoreCalculator} that scores a network by its loss on every {@link DataSet} produced by a
 * {@link DataSetIterator}.
 *
 * <p>Each batch's score is weighted by the number of examples (rows of the feature matrix) in that
 * batch. The result is either the weighted sum over all batches or, when {@code average} is set,
 * that sum divided by the total number of examples.
 */
public class DataSetLossCalculator implements ScoreCalculator {
    /** Source of the evaluation data; reset at the start of every {@link #calculateScore} call. */
    private final DataSetIterator dataSetIterator;
    /** If {@code true}, the summed loss is divided by the number of examples seen. */
    private final boolean average;

    /**
     * Creates a calculator over the given data.
     *
     * @param dataSetIterator iterator over the evaluation data; must not be {@code null}
     * @param average if {@code true}, {@link #calculateScore} returns the loss per example;
     *     otherwise it returns the example-weighted sum of batch scores
     * @throws NullPointerException if {@code dataSetIterator} is {@code null}
     */
    public DataSetLossCalculator(DataSetIterator dataSetIterator, boolean average) {
        // Fail fast here rather than with an NPE on the first calculateScore call.
        this.dataSetIterator = Objects.requireNonNull(dataSetIterator, "dataSetIterator");
        this.average = average;
    }

    /**
     * Computes the loss of {@code network} over all data in the iterator.
     *
     * <p>The iterator is reset first, so every call scores the complete data set from the start.
     *
     * @param network the network to evaluate
     * @return the sum over batches of {@code network.score(batch) * batchSize}, or that sum divided
     *     by the total number of examples when {@code average} is set; {@link Double#NaN} when
     *     averaging and the iterator yields no examples
     */
    @Override
    public double calculateScore(MultiLayerNetwork network) {
        dataSetIterator.reset();
        double lossSum = 0.0;
        // long: a very large evaluation set split into many batches could overflow an int counter.
        long exCount = 0;
        while (dataSetIterator.hasNext()) {
            DataSet dataSet = (DataSet) dataSetIterator.next();
            int nEx = dataSet.getFeatureMatrix().size(0);
            // NOTE(review): weighting by nEx assumes network.score(DataSet) returns the mean loss
            // over the batch, so score * nEx recovers the batch total -- confirm in MultiLayerNetwork.
            lossSum += network.score(dataSet) * nEx;
            exCount += nEx;
        }
        if (!average) {
            return lossSum;
        }
        // Same result the plain division gives when no examples were seen (0.0 / 0.0 is NaN);
        // spelled out so the empty-data case is visible to readers.
        return exCount == 0 ? Double.NaN : lossSum / exCount;
    }

    @Override
    public String toString() {
        return "DataSetLossCalculator(" + dataSetIterator + ",average=" + average + ")";
    }
}

