package ai.djl.training.metrics;

import ai.djl.ndarray.NDArray;
import java.util.stream.IntStream;

/* loaded from: input_file:ai/djl/training/metrics/TopKAccuracy.class */
/**
 * Accuracy metric that counts a sample as correct when its label is among the {@code topK}
 * highest-ranked predictions.
 *
 * <p>Per-batch correct and total counts are reported through the inherited {@code
 * addCorrectInstances} and {@code addTotalInstances} methods of {@link Accuracy}.
 */
public class TopKAccuracy extends Accuracy {

    /** Configured number of top predictions to consider; always greater than 1. */
    private final int topK;

    /**
     * Creates a top-k accuracy metric.
     *
     * @param name the metric name, forwarded to {@link Accuracy}
     * @param axis forwarded to {@link Accuracy}; presumably sets the inherited {@code axis} that
     *     {@link #update} sorts predictions along — confirm against the parent class
     * @param topK the number of highest-ranked predictions to consider; must be greater than 1
     * @throws IllegalArgumentException if {@code topK <= 1}
     */
    public TopKAccuracy(String name, int axis, int topK) {
        super(name, axis);
        if (topK <= 1) {
            throw new IllegalArgumentException("Please use TopKAccuracy with topK more than 1");
        }
        this.topK = topK;
    }

    /**
     * Creates a top-k accuracy metric named {@code "Top_<topK>_Accuracy"}.
     *
     * @param axis forwarded to {@link Accuracy} (see the three-argument constructor)
     * @param topK the number of highest-ranked predictions to consider; must be greater than 1
     * @throws IllegalArgumentException if {@code topK <= 1}
     */
    public TopKAccuracy(int axis, int topK) {
        this("Top_" + topK + "_Accuracy", axis, topK);
    }

    /**
     * Creates a top-k accuracy metric named {@code "Top_<topK>_Accuracy"} that passes {@code 0} as
     * the axis argument to {@link Accuracy}.
     *
     * @param topK the number of highest-ranked predictions to consider; must be greater than 1
     * @throws IllegalArgumentException if {@code topK <= 1}
     */
    public TopKAccuracy(int topK) {
        this("Top_" + topK + "_Accuracy", 0, topK);
    }

    /**
     * Accumulates the top-k correct count and the sample count for one batch.
     *
     * <p>Predictions are arg-sorted along the inherited {@code axis}. For 2-D predictions, the last
     * {@code k} columns of that result are compared with the labels, where {@code k} is {@code topK}
     * clamped to the size of dimension 1. Taking the last columns as the top-k assumes {@code
     * argSort} orders ascending — confirm. Columns are always taken along dimension 1, regardless
     * of {@code axis}.
     *
     * @param labels the ground-truth labels; the size of dimension 0 is added as the number of
     *     samples (presumably one class index per sample — confirm)
     * @param predictions the prediction scores, with at most 2 dimensions
     * @throws IllegalStateException if {@code predictions} has more than 2 dimensions
     */
    @Override // ai.djl.training.metrics.Accuracy
    public void update(NDArray labels, NDArray predictions) {
        checkLabelShapes(labels, predictions);
        if (predictions.getShape().dimension() > 2) {
            throw new IllegalStateException("Prediction should have no more than 2 dimensions");
        }
        NDArray ranked = predictions.argSort(this.axis);
        int numDims = ranked.getShape().dimension();
        if (numDims == 1) {
            // NOTE(review): for 1-D predictions this compares sort indices directly with labels;
            // behavior preserved from the original — confirm this is the intended semantics.
            addCorrectInstances(
                    ranked.flatten().eq(labels.flatten()).countNonzero().getLong());
        } else if (numDims == 2) {
            int numClasses = (int) ranked.getShape().get(1);
            // Clamp into a local: overwriting the field would permanently shrink k for every
            // later batch once a batch with fewer classes than topK is seen.
            int k = Math.min(topK, numClasses);
            NDArray flatLabels = labels.flatten(); // loop-invariant, computed once
            for (int rank = 0; rank < k; rank++) {
                NDArray column = ranked.get(":, " + (numClasses - rank - 1));
                addCorrectInstances(column.flatten().eq(flatLabels).countNonzero().getLong());
            }
        }
        addTotalInstances((int) labels.getShape().get(0));
    }
}
