package org.deeplearning4j.eval;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.eval.curves.PrecisionRecallCurve;
import org.deeplearning4j.eval.curves.RocCurve;
import org.deeplearning4j.eval.serde.ROCArraySerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/* loaded from: input_file:org/deeplearning4j/eval/ROCMultiClass.class */
/**
 * Multi-class ROC evaluation that keeps one {@link ROC} instance per label column
 * (class index) and evaluates every column independently against the matching
 * prediction column.
 *
 * <p>The per-class {@link ROC} array is created lazily on the first call to
 * {@link #eval(INDArray, INDArray)}; until then, result accessors throw
 * {@link IllegalStateException}.
 */
public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
    /** Number of decimal places used by {@link #stats()}. */
    public static final int DEFAULT_STATS_PRECISION = 4;

    /** Minimum width (in characters) of the label text in {@link #stats(int)} output. */
    private static final int MIN_LABEL_WIDTH = 15;

    /** Extra padding added after the widest label in {@link #stats(int)} output. */
    private static final int LABEL_PADDING = 5;

    /** Passed as the first constructor argument to every underlying {@link ROC}. */
    private int thresholdSteps;

    /** Passed as the second constructor argument to every underlying {@link ROC}. */
    private boolean rocRemoveRedundantPts;

    /** One ROC per class; {@code null} until data has been evaluated (or after {@link #reset()}). */
    @JsonSerialize(using = ROCArraySerializer.class)
    private ROC[] underlying;

    /** Optional display names for the classes, used only by {@link #stats(int)}. */
    private List<String> labels;

    /** Creates an instance with {@code thresholdSteps = 0} and redundant-point removal enabled. */
    public ROCMultiClass() {
        this(0);
    }

    /**
     * Creates an instance with redundant-point removal enabled.
     *
     * @param thresholdSteps value forwarded to each underlying {@link ROC}
     */
    public ROCMultiClass(int thresholdSteps) {
        this(thresholdSteps, true);
    }

    /**
     * @param thresholdSteps        value forwarded to each underlying {@link ROC}
     * @param rocRemoveRedundantPts flag forwarded to each underlying {@link ROC}
     */
    public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts) {
        this.thresholdSteps = thresholdSteps;
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    /** Discards all collected per-class state. Labels and configuration are left untouched. */
    @Override // org.deeplearning4j.eval.IEvaluation
    public void reset() {
        this.underlying = null;
    }

    /**
     * @return the statistics table formatted with {@link #DEFAULT_STATS_PRECISION} decimal places
     */
    @Override // org.deeplearning4j.eval.IEvaluation
    public String stats() {
        return stats(DEFAULT_STATS_PRECISION);
    }

    /**
     * Builds a text table with one row per class (label, AUC, positive count, negative
     * count) followed by the average AUC on its own line.
     *
     * @param printPrecision number of decimal places used for AUC values
     * @return the formatted table, or a header followed by "-- No Data --" if nothing was evaluated
     */
    public String stats(int printPrecision) {
        int labelWidth = MIN_LABEL_WIDTH;
        if (this.labels != null) {
            for (String label : this.labels) {
                // A null entry is rendered as "null" by String.format; do not let it break the width scan.
                if (label != null) {
                    labelWidth = Math.max(label.length(), labelWidth);
                }
            }
        }

        String labelColumn = "%-" + (labelWidth + LABEL_PADDING) + "s";
        String headerPattern = labelColumn + "%-12s%-10s%-10s";
        String rowPattern = labelColumn + "%-12." + printPrecision + "f%-10d%-10d";

        StringBuilder sb = new StringBuilder();
        sb.append(String.format(headerPattern, "Label", "AUC", "# Pos", "# Neg"));

        if (this.underlying == null) {
            sb.append("\n-- No Data --\n");
            return sb.toString();
        }

        for (int classIdx = 0; classIdx < this.underlying.length; classIdx++) {
            // Fall back to the numeric index when no (or too few) labels were supplied.
            boolean hasLabel = this.labels != null && classIdx < this.labels.size();
            String label = hasLabel ? this.labels.get(classIdx) : String.valueOf(classIdx);
            sb.append("\n").append(String.format(rowPattern, label, calculateAUC(classIdx),
                    getCountActualPositive(classIdx), getCountActualNegative(classIdx)));
        }

        // The average goes on its own line; previously it was glued onto the last class row.
        sb.append("\n").append("Average AUC: ")
                .append(String.format("%-12." + printPrecision + "f", calculateAverageAUC()));
        return sb.toString();
    }

    /**
     * Evaluates a batch: column {@code c} of {@code labels} is evaluated against column
     * {@code c} of {@code predictions} using the ROC instance for class {@code c}.
     *
     * <p>Rank-3 input for both arrays is delegated to {@code evalTimeSeries}.
     * NOTE(review): that method is inherited and not visible here; it is presumed to
     * reshape to rank 2 and call back into this method - confirm against BaseEvaluation.
     *
     * @param labels      label array; rank 2 (or rank 3 together with {@code predictions})
     * @param predictions prediction array with the same {@code size(1)} as {@code labels}
     * @throws IllegalArgumentException if the shapes are unsupported, or the number of columns
     *                                  differs from the one seen on previous calls
     */
    @Override // org.deeplearning4j.eval.IEvaluation
    public void eval(INDArray labels, INDArray predictions) {
        if (labels.rank() == 3 && predictions.rank() == 3) {
            evalTimeSeries(labels, predictions);
            // Must stop here: falling through would hit the rank check below and always throw.
            return;
        }
        if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) {
            throw new IllegalArgumentException("Invalid input data shape: labels shape = "
                    + Arrays.toString(labels.shape()) + ", predictions shape = "
                    + Arrays.toString(predictions.shape())
                    + "; require rank 2 arrays with equal size(1) (number of classes)");
        }

        int numClasses = labels.size(1);
        if (this.underlying == null) {
            // First batch: lazily create one ROC per label column.
            this.underlying = new ROC[numClasses];
            for (int classIdx = 0; classIdx < numClasses; classIdx++) {
                this.underlying[classIdx] = new ROC(this.thresholdSteps, this.rocRemoveRedundantPts);
            }
        }
        if (this.underlying.length != numClasses) {
            throw new IllegalArgumentException(
                    "Cannot evaluate data: number of label classes does not match previous call. Got "
                            + numClasses + " labels (from array shape " + Arrays.toString(labels.shape())
                            + ") vs. expected number of label classes = " + this.underlying.length);
        }

        for (int classIdx = 0; classIdx < numClasses; classIdx++) {
            this.underlying[classIdx].eval(labels.getColumn(classIdx), predictions.getColumn(classIdx));
        }
    }

    /**
     * @param classIdx class (column) index
     * @return the ROC curve of the given class
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public RocCurve getRocCurve(int classIdx) {
        assertIndex(classIdx);
        return this.underlying[classIdx].getRocCurve();
    }

    /**
     * @param classIdx class (column) index
     * @return the precision-recall curve of the given class
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx) {
        assertIndex(classIdx);
        return this.underlying[classIdx].getPrecisionRecallCurve();
    }

    /**
     * @param classIdx class (column) index
     * @return the AUC reported by the underlying ROC of the given class
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public double calculateAUC(int classIdx) {
        assertIndex(classIdx);
        return this.underlying[classIdx].calculateAUC();
    }

    /**
     * @param classIdx class (column) index
     * @return the AUCPR reported by the underlying ROC of the given class
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public double calculateAUCPR(int classIdx) {
        assertIndex(classIdx);
        return this.underlying[classIdx].calculateAUCPR();
    }

    /**
     * @return the unweighted mean of {@link #calculateAUC(int)} over all classes
     * @throws IllegalStateException if no data has been collected
     */
    public double calculateAverageAUC() {
        // Validates that data exists (and that there is at least one class) before averaging.
        assertIndex(0);
        double sum = 0.0d;
        for (int classIdx = 0; classIdx < this.underlying.length; classIdx++) {
            sum += calculateAUC(classIdx);
        }
        return sum / this.underlying.length;
    }

    /**
     * @param classIdx class (column) index
     * @return the actual-positive count reported by the underlying ROC of the given class
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public long getCountActualPositive(int classIdx) {
        assertIndex(classIdx);
        return this.underlying[classIdx].getCountActualPositive();
    }

    /**
     * @param classIdx class (column) index
     * @return the actual-negative count reported by the underlying ROC of the given class
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public long getCountActualNegative(int classIdx) {
        assertIndex(classIdx);
        return this.underlying[classIdx].getCountActualNegative();
    }

    /**
     * Merges the per-class state of {@code other} into this instance.
     *
     * <p>If this instance holds no data yet, it adopts {@code other}'s ROC array by
     * reference (no copy is made), so the two instances share state afterwards.
     *
     * @param other instance to merge in; a no-op when it holds no data
     * @throws UnsupportedOperationException if both hold data for a different number of classes
     */
    @Override // org.deeplearning4j.eval.IEvaluation
    public void merge(ROCMultiClass other) {
        if (this.underlying == null) {
            this.underlying = other.underlying;
            return;
        }
        if (other.underlying == null) {
            return;
        }
        if (this.underlying.length != other.underlying.length) {
            throw new UnsupportedOperationException("Cannot merge ROCMultiClass: this expects "
                    + this.underlying.length + " outputs, other expects " + other.underlying.length
                    + " outputs");
        }
        for (int classIdx = 0; classIdx < this.underlying.length; classIdx++) {
            this.underlying[classIdx].merge(other.underlying[classIdx]);
        }
    }

    /**
     * @return the number of classes seen so far, or {@code -1} if no data has been collected
     */
    public int getNumClasses() {
        if (this.underlying == null) {
            return -1;
        }
        return this.underlying.length;
    }

    /**
     * Verifies that data has been collected and that {@code classIdx} addresses an existing class.
     *
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is negative or not below the class count
     */
    private void assertIndex(int classIdx) {
        if (this.underlying == null) {
            throw new IllegalStateException("Cannot get results: no data has been collected");
        }
        if (classIdx < 0 || classIdx >= this.underlying.length) {
            throw new IllegalArgumentException("Invalid class index (" + classIdx
                    + "): must be in range 0 to numClasses = " + this.underlying.length);
        }
    }

    /** @return the threshold-steps value forwarded to each underlying ROC */
    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    /** @return the redundant-point-removal flag forwarded to each underlying ROC */
    public boolean isRocRemoveRedundantPts() {
        return this.rocRemoveRedundantPts;
    }

    /** @return the internal per-class ROC array (not a copy), or {@code null} if no data exists */
    public ROC[] getUnderlying() {
        return this.underlying;
    }

    /** @return the class display names, or {@code null} if none were set */
    public List<String> getLabels() {
        return this.labels;
    }

    /** @param thresholdSteps value used for ROC instances created by later calls to {@code eval} */
    public void setThresholdSteps(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
    }

    /** @param rocRemoveRedundantPts flag used for ROC instances created by later calls to {@code eval} */
    public void setRocRemoveRedundantPts(boolean rocRemoveRedundantPts) {
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    /** @param underlying per-class ROC array to use as-is (stored by reference) */
    public void setUnderlying(ROC[] underlying) {
        this.underlying = underlying;
    }

    /** @param labels class display names used by {@link #stats(int)}; may be {@code null} */
    public void setLabels(List<String> labels) {
        this.labels = labels;
    }

    @Override // org.deeplearning4j.eval.BaseEvaluation
    public String toString() {
        return "ROCMultiClass(thresholdSteps=" + getThresholdSteps()
                + ", rocRemoveRedundantPts=" + isRocRemoveRedundantPts()
                + ", underlying=" + Arrays.deepToString(getUnderlying())
                + ", labels=" + getLabels() + ")";
    }

    /**
     * Field-wise equality over configuration, per-class ROC array and labels, in addition
     * to whatever the superclass {@code equals} compares.
     */
    @Override // org.deeplearning4j.eval.BaseEvaluation
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ROCMultiClass)) {
            return false;
        }
        ROCMultiClass other = (ROCMultiClass) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        if (getThresholdSteps() != other.getThresholdSteps()
                || isRocRemoveRedundantPts() != other.isRocRemoveRedundantPts()) {
            return false;
        }
        if (!Arrays.deepEquals(getUnderlying(), other.getUnderlying())) {
            return false;
        }
        List<String> thisLabels = getLabels();
        List<String> otherLabels = other.getLabels();
        return thisLabels == null ? otherLabels == null : thisLabels.equals(otherLabels);
    }

    /** Lets subclasses opt out of equality with this type so that {@code equals} stays symmetric. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ROCMultiClass;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + super.hashCode();
        result = result * prime + getThresholdSteps();
        result = result * prime + (isRocRemoveRedundantPts() ? 79 : 97);
        result = result * prime + Arrays.deepHashCode(getUnderlying());
        List<String> currentLabels = getLabels();
        result = result * prime + (currentLabels == null ? 43 : currentLabels.hashCode());
        return result;
    }
}
