package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.Pair;

/* loaded from: input_file:ai/djl/training/evaluator/BinaryAccuracy.class */
/**
 * An evaluator that measures accuracy for binary classification.
 *
 * <p>Each prediction is binarized by comparing it against {@code threshold}. Values greater than
 * or equal to the threshold become {@code 1}; all others become {@code 0}. The binarized
 * predictions are then compared element-wise with the labels, after both are cast to {@code
 * INT64}.
 */
public class BinaryAccuracy extends AbstractAccuracy {

    /** Predictions {@code >=} this value are treated as the positive class. */
    float threshold;

    /**
     * Creates a binary accuracy evaluator.
     *
     * @param name the evaluator name, forwarded to {@link AbstractAccuracy}
     * @param threshold predictions at or above this value are treated as positive
     * @param index forwarded to {@link AbstractAccuracy}; presumably stored as {@code index} and
     *     used to pick the array to evaluate from the label/prediction lists
     * @param axis forwarded to {@link AbstractAccuracy} as its third constructor argument
     */
    public BinaryAccuracy(String name, float threshold, int index, int axis) {
        super(name, index, axis);
        this.threshold = threshold;
    }

    /**
     * Creates a binary accuracy evaluator whose final {@link AbstractAccuracy} argument is {@code
     * 1}.
     *
     * @param name the evaluator name
     * @param threshold predictions at or above this value are treated as positive
     * @param index forwarded to {@link AbstractAccuracy}
     */
    public BinaryAccuracy(String name, float threshold, int index) {
        this(name, threshold, index, 1);
    }

    /**
     * Creates a binary accuracy evaluator named {@code "BinaryAccuracy"}, with index {@code 0} and
     * final {@link AbstractAccuracy} argument {@code 1}.
     *
     * @param threshold predictions at or above this value are treated as positive
     */
    public BinaryAccuracy(float threshold) {
        this("BinaryAccuracy", threshold, 0, 1);
    }

    /** Creates a binary accuracy evaluator with the default threshold of {@code 0}. */
    public BinaryAccuracy() {
        this(0f);
    }

    /**
     * Counts how many thresholded predictions match their labels.
     *
     * @param labels the label arrays; must have the same length as {@code predictions}
     * @param predictions the prediction arrays
     * @return a pair of (total number of label elements, array holding the count of matching
     *     elements)
     * @throws IllegalArgumentException if {@code labels} and {@code predictions} differ in length
     */
    @Override // ai.djl.training.evaluator.AbstractAccuracy
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        if (labels.size() != predictions.size()) {
            throw new IllegalArgumentException("labels and prediction length does not match.");
        }
        final NDArray label = labels.get(index);
        final NDArray prediction = predictions.get(index);
        checkLabelShapes(label, prediction, false);

        // Binarize: true wherever the prediction reaches the threshold.
        NDArray positive = prediction.gte(threshold);
        long total = label.size();
        // NOTE(review): the label cast truncates, so this assumes labels are already 0/1 — confirm.
        NDArray matches = asInt64(label).eq(asInt64(positive));
        NDArray correct = matches.countNonzero();
        return new Pair<>(total, correct);
    }

    /** Casts {@code array} to {@code INT64} without copying when the type already matches. */
    private static NDArray asInt64(NDArray array) {
        return array.toType(DataType.INT64, false);
    }
}
