package ai.djl.training.metrics;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/* loaded from: input_file:ai/djl/training/metrics/BoundingBoxError.class */
/**
 * Training metric that measures the mean absolute error of SSD bounding-box predictions.
 *
 * <p>For each batch, {@link MultiBoxTarget} computes bounding-box targets (and a second array
 * used to weight the difference) from the labels and predictions. The weighted absolute
 * difference between those targets and the predicted boxes is summed and accumulated.
 * {@link #getValue()} reports the accumulated error divided by the total number of target
 * elements seen since the last {@link #reset()}.
 */
public class BoundingBoxError extends TrainingMetric {
    // Accumulated as double/long rather than float: a float counter stops representing every
    // integer exactly above 2^24, so long-running totals would silently drift.
    private double ssdBoxPredictionError;
    private long numInstances;
    private final MultiBoxTarget multiBoxTarget;

    /**
     * Creates a bounding-box error metric with a default-configured {@link MultiBoxTarget}.
     *
     * @param name the name of this metric
     */
    public BoundingBoxError(String name) {
        super(name);
        this.multiBoxTarget = new MultiBoxTarget.Builder().build();
    }

    /**
     * Accumulates the masked absolute bounding-box error for one batch.
     *
     * <p>NOTE(review): the layout of {@code predictions} is inferred from usage and should be
     * confirmed:
     * <ul>
     *   <li>index 0 is passed first to {@link MultiBoxTarget#target} (presumably anchors);
     *   <li>index 1 is transposed on axes 1 and 2 before being passed (presumably class
     *       predictions);
     *   <li>index 2 is subtracted from the computed box targets (presumably box predictions).
     * </ul>
     * Likewise, element 1 of the target list is assumed to be a mask.
     *
     * @param labels the ground-truth labels; only {@code labels.head()} is used
     * @param predictions the model outputs, read at indices 0, 1 and 2
     */
    @Override
    public void update(NDList labels, NDList predictions) {
        NDArray anchors = predictions.get(0);
        NDArray classPredictions = predictions.get(1);
        NDArray boundingBoxPredictions = predictions.get(2);

        NDList targets =
                multiBoxTarget.target(
                        new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));
        NDArray boundingBoxTargets = targets.get(0);
        NDArray boundingBoxMasks = targets.get(1);

        ssdBoxPredictionError +=
                boundingBoxTargets
                        .sub(boundingBoxPredictions)
                        .mul(boundingBoxMasks)
                        .abs()
                        .sum()
                        .getFloat();
        // The divisor counts every element of the target array, not just the masked ones.
        numInstances += boundingBoxTargets.size();
    }

    /**
     * Creates a new metric with the same name and freshly zeroed state.
     *
     * @return a new {@code BoundingBoxError} instance
     */
    @Override
    public TrainingMetric duplicate() {
        return new BoundingBoxError(getName());
    }

    /** Clears the accumulated error and element count. */
    @Override
    public void reset() {
        this.ssdBoxPredictionError = 0.0;
        this.numInstances = 0L;
    }

    /**
     * Returns the accumulated error divided by the accumulated element count.
     *
     * @return the mean error, or {@link Float#NaN} if nothing has been accumulated
     */
    @Override
    public float getValue() {
        if (numInstances == 0L) {
            // Same result the original produced via 0f / 0f, made explicit here.
            return Float.NaN;
        }
        return (float) (ssdBoxPredictionError / numInstances);
    }
}
