package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/* loaded from: input_file:ai/djl/training/loss/L1Loss.class */
/**
 * {@code L1Loss} computes the weighted mean absolute error between a label and a prediction.
 *
 * <p>The per-element loss is {@code weight * |label - prediction|}. It is then averaged via
 * {@code mean(excludeBatchAxis(...))}, presumably over every axis except {@code batchAxis}
 * (confirm against {@code Loss.excludeBatchAxis}).
 */
public class L1Loss extends Loss {
    /** Scalar multiplier applied to each element's absolute error. */
    private final float weight;
    /** Axis passed to {@code excludeBatchAxis} when reducing the loss. */
    private final int batchAxis;

    /**
     * Creates an L1 loss named {@code "L1Loss"} with the given weight and batch axis.
     *
     * @param weight the scalar weight applied to each element's absolute error
     * @param batchAxis the axis treated as the batch axis when averaging the loss
     */
    public L1Loss(float weight, int batchAxis) {
        super("L1Loss");
        this.weight = weight;
        this.batchAxis = batchAxis;
    }

    /** Creates an L1 loss with weight {@code 1.0f} and batch axis {@code 0}. */
    public L1Loss() {
        this(1.0f, 0);
    }

    /**
     * Computes the (optionally weighted) L1 loss.
     *
     * @param label list holding the label array; it is reshaped to the prediction's shape
     * @param prediction list holding the prediction array
     * @return {@code weight * |label - prediction|}, averaged over the axes selected by
     *     {@code excludeBatchAxis}
     */
    @Override // ai.djl.training.loss.Loss
    public NDArray getLoss(NDList label, NDList prediction) {
        // singletonOrThrow: each list is expected to carry exactly one array.
        NDArray pred = prediction.singletonOrThrow();
        NDArray labelReshaped = label.singletonOrThrow().reshape(pred.getShape());
        NDArray loss = labelReshaped.sub(pred).abs();
        if (weight != 1.0f) {
            // Scale the absolute error itself. The previous code scaled the raw labels
            // (labelReshaped), which discarded the computed error.
            loss = loss.mul(Float.valueOf(weight));
        }
        return loss.mean(excludeBatchAxis(loss, batchAxis));
    }
}
