package org.deeplearning4j.earlystopping.scorecalc.base;

import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

import java.util.Objects;

/* loaded from: input_file:org/deeplearning4j/earlystopping/scorecalc/base/BaseIEvaluationScoreCalculator.class */
/**
 * Base {@link ScoreCalculator} that scores a model by running an {@link IEvaluation} over a data iterator.
 *
 * <p>Subclasses supply a fresh evaluation instance via {@link #newEval()} and reduce the populated
 * evaluation to a single number via {@link #finalScore(IEvaluation)}. The data may be supplied either as a
 * {@link DataSetIterator} or as a {@link MultiDataSetIterator}; it is adapted to whichever form the model
 * type requires ({@link MultiLayerNetwork} consumes a {@code DataSetIterator}, {@link ComputationGraph}
 * consumes a {@code MultiDataSetIterator}).
 *
 * @param <T> type of model being scored
 * @param <U> type of evaluation used to compute the score
 */
public abstract class BaseIEvaluationScoreCalculator<T extends Model, U extends IEvaluation> implements ScoreCalculator<T> {
    /** Multi-dataset source; set by the {@link MultiDataSetIterator} constructor, otherwise {@code null}. */
    protected MultiDataSetIterator iterator;
    /** Single-dataset source; set by the {@link DataSetIterator} constructor, otherwise {@code null}. */
    protected DataSetIterator iter;

    /**
     * Creates a calculator that evaluates on the given multi-dataset iterator.
     *
     * @param multiDataSetIterator data to evaluate on
     */
    protected BaseIEvaluationScoreCalculator(MultiDataSetIterator multiDataSetIterator) {
        this.iterator = multiDataSetIterator;
    }

    /**
     * Creates a calculator that evaluates on the given dataset iterator.
     *
     * @param dataSetIterator data to evaluate on
     */
    protected BaseIEvaluationScoreCalculator(DataSetIterator dataSetIterator) {
        this.iter = dataSetIterator;
    }

    /**
     * Evaluates {@code model} on the configured data and reduces the result to a score.
     *
     * @param model the network to score; must be a {@link MultiLayerNetwork} or a {@link ComputationGraph}
     * @return the value produced by {@link #finalScore(IEvaluation)} for the populated evaluation
     * @throws NullPointerException  if {@code model} is {@code null}
     * @throws IllegalStateException if neither a {@code DataSetIterator} nor a {@code MultiDataSetIterator}
     *                               is configured
     * @throws RuntimeException      if {@code model} is of an unsupported type
     */
    @Override
    public double calculateScore(T model) {
        Objects.requireNonNull(model, "model");
        U eval = newEval();
        U result;
        if (model instanceof MultiLayerNetwork) {
            // doEvaluation returns the evaluations it was given; we pass exactly one.
            result = ((MultiLayerNetwork) model).doEvaluation(dataSetSource(), eval)[0];
        } else if (model instanceof ComputationGraph) {
            result = ((ComputationGraph) model).doEvaluation(multiDataSetSource(), eval)[0];
        } else {
            throw new RuntimeException("Unknown model type: " + model.getClass());
        }
        return finalScore(result);
    }

    /** Returns the data as a {@link DataSetIterator}, wrapping the multi-dataset source if necessary. */
    private DataSetIterator dataSetSource() {
        if (iter != null) {
            return iter;
        }
        if (iterator == null) {
            throw noIteratorConfigured();
        }
        return new MultiDataSetWrapperIterator(iterator);
    }

    /** Returns the data as a {@link MultiDataSetIterator}, adapting the single-dataset source if necessary. */
    private MultiDataSetIterator multiDataSetSource() {
        if (iterator != null) {
            return iterator;
        }
        if (iter == null) {
            throw noIteratorConfigured();
        }
        return new MultiDataSetIteratorAdapter(iter);
    }

    /** Builds the error raised when no data source was supplied (e.g. a constructor was given {@code null}). */
    private IllegalStateException noIteratorConfigured() {
        return new IllegalStateException("No DataSetIterator or MultiDataSetIterator configured for "
                + getClass().getName());
    }

    /**
     * Creates the evaluation instance to populate for one call of {@link #calculateScore(Model)}.
     *
     * @return a new evaluation object
     */
    protected abstract U newEval();

    /**
     * Reduces a populated evaluation to the score returned by {@link #calculateScore(Model)}.
     *
     * @param u the evaluation after it has been run over the data
     * @return the score
     */
    protected abstract double finalScore(U u);
}
