package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.metrics.TrainingMetric;
import java.util.stream.IntStream;

/* loaded from: input_file:ai/djl/training/loss/Loss.class */
/**
 * Base class for loss functions that are also tracked as a {@link TrainingMetric}.
 *
 * <p>Subclasses implement {@link #getLoss(NDList, NDList)}. This class keeps a running sum of the
 * loss values and a running count of how many values were summed, so that {@link #getValue()} can
 * report their mean since the last {@link #reset()}.
 *
 * <p>The static factory methods simply forward their arguments to the constructors of the
 * corresponding concrete loss classes.
 */
public abstract class Loss extends TrainingMetric {
    /** Sum of every loss value seen since the last {@link #reset()}. */
    private float totalLoss;

    /**
     * Number of loss values folded into {@link #totalLoss}.
     *
     * <p>Held as a {@code long} so that adding {@code NDArray.size()} never has to be narrowed to
     * {@code int}, which would silently wrap after more than {@code Integer.MAX_VALUE} values.
     */
    private long totalInstances;

    /**
     * Creates a loss and passes {@code str} on to the {@link TrainingMetric} constructor.
     *
     * @param str the value handed to the superclass constructor (presumably the metric's display
     *     name -- confirm against {@code TrainingMetric})
     */
    public Loss(String str) {
        super(str);
    }

    /**
     * Computes the loss for the two given lists.
     *
     * <p>NOTE(review): the argument roles are not visible in this file; they are presumably the
     * labels and the predictions -- confirm against the concrete implementations.
     *
     * @param nDList the first input list
     * @param nDList2 the second input list
     * @return the loss as an {@link NDArray}
     */
    public abstract NDArray getLoss(NDList nDList, NDList nDList2);

    /**
     * Creates an {@link L1Loss} using its no-argument constructor.
     *
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss() {
        return new L1Loss();
    }

    /**
     * Creates an {@link L1Loss}, forwarding both arguments to its constructor.
     *
     * @param f first constructor argument (presumably the loss weight -- confirm in {@code L1Loss})
     * @param i second constructor argument (presumably the batch axis -- confirm in {@code L1Loss})
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(float f, int i) {
        return new L1Loss(f, i);
    }

    /**
     * Creates an {@link L2Loss} using its no-argument constructor.
     *
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss() {
        return new L2Loss();
    }

    /**
     * Creates an {@link L2Loss}, forwarding both arguments to its constructor.
     *
     * @param f first constructor argument (presumably the loss weight -- confirm in {@code L2Loss})
     * @param i second constructor argument (presumably the batch axis -- confirm in {@code L2Loss})
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(float f, int i) {
        return new L2Loss(f, i);
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} using its no-argument constructor.
     *
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss() {
        return new SigmoidBinaryCrossEntropyLoss();
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss}, forwarding all arguments to its constructor.
     *
     * <p>NOTE(review): the meaning of each argument is defined by that constructor and is not
     * visible here.
     *
     * @param f first constructor argument
     * @param i second constructor argument
     * @param z third constructor argument
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(float f, int i, boolean z) {
        return new SigmoidBinaryCrossEntropyLoss(f, i, z);
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} using its no-argument constructor.
     *
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss() {
        return new SoftmaxCrossEntropyLoss();
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss}, forwarding all arguments to its constructor.
     *
     * <p>NOTE(review): the meaning of each argument is defined by that constructor and is not
     * visible here.
     *
     * @param f first constructor argument
     * @param i second constructor argument
     * @param i2 third constructor argument
     * @param z fourth constructor argument
     * @param z2 fifth constructor argument
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(float f, int i, int i2, boolean z, boolean z2) {
        return new SoftmaxCrossEntropyLoss(f, i, i2, z, z2);
    }

    /**
     * Creates a {@link HingeLoss} using its no-argument constructor.
     *
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss() {
        return new HingeLoss();
    }

    /**
     * Creates a {@link HingeLoss}, forwarding all arguments to its constructor.
     *
     * <p>NOTE(review): the meaning of each argument is defined by that constructor and is not
     * visible here.
     *
     * @param i first constructor argument
     * @param f second constructor argument
     * @param i2 third constructor argument
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(int i, float f, int i2) {
        return new HingeLoss(i, f, i2);
    }

    /**
     * Returns a copy of this loss produced by {@code clone()}, so the accumulated totals are copied
     * along with the rest of the object's fields.
     *
     * @return the cloned {@code Loss}
     * @throws AssertionError if cloning is not supported
     */
    @Override // ai.djl.training.metrics.TrainingMetric
    public Loss duplicate() {
        try {
            return (Loss) clone();
        } catch (CloneNotSupportedException e) {
            throw new AssertionError("Clone is not supported", e);
        }
    }

    /**
     * Computes the loss for the given inputs and folds it into the running totals.
     *
     * <p>The sum of the loss array is added to the running loss, and the array's {@code size()} is
     * added to the running count.
     *
     * @param nDList first argument passed through to {@link #getLoss(NDList, NDList)}
     * @param nDList2 second argument passed through to {@link #getLoss(NDList, NDList)}
     */
    @Override // ai.djl.training.metrics.TrainingMetric
    public void update(NDList nDList, NDList nDList2) {
        NDArray loss = getLoss(nDList, nDList2);
        // Empty index array: read the value of the summed array without addressing a position.
        float batchLoss = loss.sum().getFloat(new long[0]);
        this.totalLoss += batchLoss;
        // The counter is a long, so no narrowing cast (and no silent wrap-around) is needed here.
        this.totalInstances += loss.size();
    }

    /** Clears the running loss sum and the running count. */
    @Override // ai.djl.training.metrics.TrainingMetric
    public void reset() {
        this.totalLoss = 0.0f;
        this.totalInstances = 0L;
    }

    /**
     * Returns the mean loss accumulated since the last {@link #reset()}.
     *
     * @return the running loss sum divided by the running count, or {@link Float#NaN} when nothing
     *     has been counted yet
     */
    @Override // ai.djl.training.metrics.TrainingMetric
    public float getValue() {
        if (this.totalInstances == 0L) {
            return Float.NaN;
        }
        return this.totalLoss / this.totalInstances;
    }

    /**
     * Lists every axis index of {@code nDArray} except {@code i}, in ascending order.
     *
     * <p>The indices run from {@code 0} up to (but excluding) the dimension count of the array's
     * shape. If {@code i} is outside that range, all indices are returned.
     *
     * <p>NOTE(review): the decompiler reports that this method was originally {@code protected};
     * it is kept {@code public} here so existing callers of this file are unaffected.
     *
     * @param nDArray the array whose shape supplies the number of axes
     * @param i the axis index to leave out
     * @return the remaining axis indices
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public int[] excludeBatchAxis(NDArray nDArray, int i) {
        int axisCount = nDArray.getShape().dimension();
        return IntStream.range(0, axisCount).filter(axis -> axis != i).toArray();
    }
}
