package org.deeplearning4j.arbiter.scoring.impl;

import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCBinary;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Arbiter score function that evaluates a network with one of DL4J's ROC evaluation classes
 * ({@link org.deeplearning4j.eval.ROC}, {@link ROCBinary} or {@link ROCMultiClass}) and reports
 * either the area under the ROC curve ({@link Metric#AUC}) or the area under the precision-recall
 * curve ({@link Metric#AUPRC}).
 *
 * <p>Larger values are better, so {@link #minimize()} returns {@code false}.
 */
public class ROCScoreFunction extends BaseNetScoreFunction {
    /** Which ROC evaluation to run. May be {@code null} only if built via the no-arg constructor or set to null. */
    protected ROCType type;
    /** Which area-under-curve value to report. May be {@code null} under the same conditions as {@link #type}. */
    protected Metric metric;

    /** The area-under-curve value reported as the score. */
    public enum Metric {
        /** Area under the ROC curve. */
        AUC,
        /** Area under the precision-recall curve. */
        AUPRC
    }

    /** Selects which DL4J ROC evaluation class is used to compute the score. */
    public enum ROCType {
        /** Evaluate with {@link org.deeplearning4j.eval.ROC}. */
        ROC,
        /** Evaluate with {@link ROCBinary}; the score is the average across its outputs. */
        BINARY,
        /** Evaluate with {@link ROCMultiClass}; the score is the average across its classes. */
        MULTICLASS
    }

    /**
     * Creates a score function.
     *
     * @param type   which ROC evaluation to run
     * @param metric which area-under-curve value to report
     * @throws NullPointerException if either argument is {@code null}
     */
    public ROCScoreFunction(@NonNull ROCType type, @NonNull Metric metric) {
        if (type == null) {
            throw new NullPointerException("type");
        }
        if (metric == null) {
            throw new NullPointerException("metric");
        }
        this.type = type;
        this.metric = metric;
    }

    @Override
    public String toString() {
        return "ROCScoreFunction(type=" + this.type + ",metric=" + this.metric + ")";
    }

    /**
     * Scores a {@link MultiLayerNetwork} on the given data.
     *
     * @throws IllegalStateException if {@code type} or {@code metric} has not been set
     */
    @Override
    public double score(MultiLayerNetwork multiLayerNetwork, DataSetIterator dataSetIterator) {
        // Validate before evaluating so a misconfiguration fails fast instead of after a full pass.
        checkConfigured();
        switch (this.type) {
            case ROC:
                ROC roc = multiLayerNetwork.evaluateROC(dataSetIterator);
                return rocScore(roc);
            case BINARY:
                ROCBinary rocBinary = multiLayerNetwork.doEvaluation(dataSetIterator, new ROCBinary[] {new ROCBinary()})[0];
                return binaryScore(rocBinary);
            case MULTICLASS:
                ROCMultiClass rocMultiClass = multiLayerNetwork.evaluateROCMultiClass(dataSetIterator);
                return multiClassScore(rocMultiClass);
            default:
                throw new RuntimeException("Unknown type: " + this.type);
        }
    }

    /** Scores a {@link MultiLayerNetwork} by wrapping the multi-dataset iterator as a {@link DataSetIterator}. */
    @Override
    public double score(MultiLayerNetwork multiLayerNetwork, MultiDataSetIterator multiDataSetIterator) {
        return score(multiLayerNetwork, (DataSetIterator) new MultiDataSetWrapperIterator(multiDataSetIterator));
    }

    /** Scores a {@link ComputationGraph} by adapting the dataset iterator to a {@link MultiDataSetIterator}. */
    @Override
    public double score(ComputationGraph computationGraph, DataSetIterator dataSetIterator) {
        return score(computationGraph, (MultiDataSetIterator) new MultiDataSetIteratorAdapter(dataSetIterator));
    }

    /**
     * Scores a {@link ComputationGraph} on the given data.
     *
     * @throws IllegalStateException if {@code type} or {@code metric} has not been set
     */
    @Override
    public double score(ComputationGraph computationGraph, MultiDataSetIterator multiDataSetIterator) {
        checkConfigured();
        switch (this.type) {
            case ROC:
                ROC roc = computationGraph.evaluateROC(multiDataSetIterator);
                return rocScore(roc);
            case BINARY:
                ROCBinary rocBinary = computationGraph.doEvaluation(multiDataSetIterator, new ROCBinary[] {new ROCBinary()})[0];
                return binaryScore(rocBinary);
            case MULTICLASS:
                // NOTE(review): the second argument is presumably the ROC threshold-step count; the
                // MultiLayerNetwork path uses the single-argument overload - confirm they are equivalent.
                ROCMultiClass rocMultiClass = computationGraph.evaluateROCMultiClass(multiDataSetIterator, 0);
                return multiClassScore(rocMultiClass);
            default:
                throw new RuntimeException("Unknown type: " + this.type);
        }
    }

    /** Throws if either configuration field is missing; otherwise a null metric would silently mean AUPRC. */
    private void checkConfigured() {
        if (this.type == null) {
            throw new IllegalStateException("ROCScoreFunction: type has not been set");
        }
        if (this.metric == null) {
            throw new IllegalStateException("ROCScoreFunction: metric has not been set");
        }
    }

    /** Extracts the configured metric from a {@link ROC} evaluation. */
    private double rocScore(ROC roc) {
        return this.metric == Metric.AUC ? roc.calculateAUC() : roc.calculateAUCPR();
    }

    /** Extracts the configured metric, averaged over outputs, from a {@link ROCBinary} evaluation. */
    private double binaryScore(ROCBinary rocBinary) {
        return this.metric == Metric.AUC ? rocBinary.calculateAverageAuc() : rocBinary.calculateAverageAUCPR();
    }

    /** Extracts the configured metric, averaged over classes, from a {@link ROCMultiClass} evaluation. */
    private double multiClassScore(ROCMultiClass rocMultiClass) {
        return this.metric == Metric.AUC
                ? rocMultiClass.calculateAverageAUC()
                : rocMultiClass.calculateAverageAUCPR();
    }

    /**
     * Reports the optimisation direction.
     *
     * @return {@code false}: larger AUC / AUPRC values are better
     */
    public boolean minimize() {
        return false;
    }

    /** @return the configured evaluation type (may be {@code null} if never set) */
    public ROCType getType() {
        return this.type;
    }

    /** @return the configured metric (may be {@code null} if never set) */
    public Metric getMetric() {
        return this.metric;
    }

    /** @param type the evaluation type to use; not validated here, checked when scoring */
    public void setType(ROCType type) {
        this.type = type;
    }

    /** @param metric the metric to report; not validated here, checked when scoring */
    public void setMetric(Metric metric) {
        this.metric = metric;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ROCScoreFunction)) {
            return false;
        }
        ROCScoreFunction other = (ROCScoreFunction) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        // Enum constants are singletons, so == is a null-safe equality check here.
        return getType() == other.getType() && getMetric() == other.getMetric();
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof ROCScoreFunction;
    }

    @Override
    public int hashCode() {
        // Same 59/43 scheme as the lombok-generated original, so hash values are unchanged.
        ROCType t = getType();
        Metric m = getMetric();
        int result = super.hashCode();
        result = (result * 59) + (t == null ? 43 : t.hashCode());
        result = (result * 59) + (m == null ? 43 : m.hashCode());
        return result;
    }

    /** No-arg constructor; presumably for reflection/JSON deserialization - fields start out {@code null}. */
    protected ROCScoreFunction() {
    }
}
