package org.deeplearning4j.earlystopping.scorecalc.base;

import lombok.NonNull;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Base class for {@link ScoreCalculator} implementations that score a model by iterating over either
 * a {@link DataSetIterator} or a {@link MultiDataSetIterator}.
 *
 * <p>For every minibatch, {@code calculateScore} obtains the model output via {@code output(...)},
 * scores it via {@code scoreMinibatch(...)}, and accumulates the per-minibatch scores, the number of
 * minibatches and the number of examples. These totals are then passed to
 * {@link #finalScore(double, int, int)} to produce the returned score.
 *
 * @param <T> the model type being scored
 */
public abstract class BaseScoreCalculator<T extends Model> implements ScoreCalculator<T> {
    /** Multi-input/output data source; set by the {@link MultiDataSetIterator} constructor. */
    protected MultiDataSetIterator mdsIterator;
    /** Single-input/output data source; set by the {@link DataSetIterator} constructor. */
    protected DataSetIterator iterator;
    /** Sum of the values returned by {@code scoreMinibatch} during the current calculation. */
    protected double scoreSum;
    /** Number of minibatches scored during the current calculation. */
    protected int minibatchCount;
    /** Number of examples (size of dimension 0 of the first features array) seen so far. */
    protected int exampleCount;

    /**
     * Creates a calculator that draws its data from a {@link DataSetIterator}.
     *
     * @param dataSetIterator the data to score on; must not be {@code null}
     * @throws NullPointerException if {@code dataSetIterator} is {@code null}
     */
    public BaseScoreCalculator(@NonNull DataSetIterator dataSetIterator) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator");
        }
        this.iterator = dataSetIterator;
    }

    /**
     * Creates a calculator that draws its data from a {@link MultiDataSetIterator}.
     *
     * @param multiDataSetIterator the data to score on; must not be {@code null}
     * @throws NullPointerException if {@code multiDataSetIterator} is {@code null}
     */
    public BaseScoreCalculator(@NonNull MultiDataSetIterator multiDataSetIterator) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("iterator");
        }
        this.mdsIterator = multiDataSetIterator;
    }

    /**
     * Scores {@code network} over one full pass of the configured iterator.
     *
     * <p>If the iterator is already exhausted it is reset first, so repeated calls each see the data.
     *
     * @param network the model to score
     * @return the value produced by {@link #finalScore(double, int, int)} for the accumulated totals
     */
    @Override
    public double calculateScore(T network) {
        // The base class owns these accumulators, so it clears them itself; a subclass whose reset()
        // forgets one of them would otherwise carry totals over from a previous call. reset() runs
        // afterwards so subclasses can still reinitialize (or override) any state they need.
        scoreSum = 0.0;
        minibatchCount = 0;
        exampleCount = 0;
        reset();

        if (iterator != null) {
            scoreDataSets(network);
        } else {
            scoreMultiDataSets(network);
        }
        return finalScore(scoreSum, minibatchCount, exampleCount);
    }

    /** Scores every minibatch of {@link #iterator}, rewinding it first if it is exhausted. */
    private void scoreDataSets(T network) {
        if (!iterator.hasNext()) {
            iterator.reset();
        }
        while (iterator.hasNext()) {
            DataSet ds = iterator.next();
            INDArray features = ds.getFeatures();
            INDArray featuresMask = ds.getFeaturesMaskArray();
            INDArray labelsMask = ds.getLabelsMaskArray();
            INDArray out = output(network, features, featuresMask, labelsMask);
            scoreSum += scoreMinibatch(network, features, ds.getLabels(), featuresMask, labelsMask, out);
            minibatchCount++;
            exampleCount += features.size(0);
        }
    }

    /** Scores every minibatch of {@link #mdsIterator}, rewinding it first if it is exhausted. */
    private void scoreMultiDataSets(T network) {
        if (!mdsIterator.hasNext()) {
            mdsIterator.reset();
        }
        while (mdsIterator.hasNext()) {
            MultiDataSet mds = mdsIterator.next();
            INDArray[] features = mds.getFeatures();
            INDArray[] featuresMasks = mds.getFeaturesMaskArrays();
            INDArray[] labelsMasks = mds.getLabelsMaskArrays();
            INDArray[] out = output(network, features, featuresMasks, labelsMasks);
            scoreSum += scoreMinibatch(network, features, mds.getLabels(), featuresMasks, labelsMasks, out);
            minibatchCount++;
            // The example count is taken from the first features array only.
            exampleCount += mds.getFeatures(0).size(0);
        }
    }

    /**
     * Hook invoked at the start of every {@link #calculateScore} call, after the base accumulators
     * have been zeroed. Subclasses clear any additional per-calculation state here.
     */
    protected abstract void reset();

    /**
     * Computes the model output for a single-input minibatch.
     *
     * @param network the model being scored
     * @param input the minibatch features
     * @param fMask the features mask (may be {@code null})
     * @param lMask the labels mask (may be {@code null})
     * @return the model output, which is passed to {@code scoreMinibatch}
     */
    protected abstract INDArray output(T network, INDArray input, INDArray fMask, INDArray lMask);

    /**
     * Computes the model outputs for a multi-input minibatch.
     *
     * @param network the model being scored
     * @param input the minibatch features arrays
     * @param fMask the features masks (may be {@code null})
     * @param lMask the labels masks (may be {@code null})
     * @return the model outputs, which are passed to {@code scoreMinibatch}
     */
    protected abstract INDArray[] output(T network, INDArray[] input, INDArray[] fMask, INDArray[] lMask);

    /**
     * Scores a single-input minibatch by wrapping every argument with {@link #arr(INDArray)} and
     * delegating to the array overload. {@code null} arguments stay {@code null}.
     *
     * @return the minibatch score, added to {@link #scoreSum}
     */
    public double scoreMinibatch(T network, INDArray features, INDArray labels, INDArray fMask,
            INDArray lMask, INDArray output) {
        return scoreMinibatch(network, arr(features), arr(labels), arr(fMask), arr(lMask), arr(output));
    }

    /**
     * Scores a minibatch given its features, labels, masks and the model output.
     *
     * @return the minibatch score, added to {@link #scoreSum}
     */
    protected abstract double scoreMinibatch(T network, INDArray[] features, INDArray[] labels,
            INDArray[] fMask, INDArray[] lMask, INDArray[] output);

    /**
     * Converts the accumulated totals into the final score returned by {@link #calculateScore}.
     *
     * @param scoreSum sum of all minibatch scores
     * @param minibatchCount number of minibatches scored (may be 0 if the iterator yielded nothing)
     * @param exampleCount number of examples scored
     * @return the final score
     */
    protected abstract double finalScore(double scoreSum, int minibatchCount, int exampleCount);

    /**
     * Wraps a single array in a length-1 array.
     *
     * @param in the array to wrap (may be {@code null})
     * @return {@code null} if {@code in} is {@code null}, otherwise {@code new INDArray[]{in}}
     */
    public static INDArray[] arr(INDArray in) {
        if (in == null) {
            return null;
        }
        return new INDArray[]{in};
    }

    /**
     * Unwraps a length-1 array; the inverse of {@link #arr(INDArray)}.
     *
     * @param in the array to unwrap (may be {@code null})
     * @return {@code null} if {@code in} is {@code null}, otherwise {@code in[0]}
     * @throws IllegalStateException if {@code in} is non-null and its length is not exactly 1
     */
    public static INDArray get0(INDArray[] in) {
        if (in == null) {
            return null;
        }
        if (in.length != 1) {
            throw new IllegalStateException("Expected length 1 array here: got length " + in.length);
        }
        return in[0];
    }
}
