package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import java.util.Arrays;
import java.util.Objects;

/* loaded from: input_file:ai/djl/training/loss/CompositeLoss.class */
/**
 * A {@link Loss} built from several component losses.
 *
 * <p>{@link #getLoss} adds up the loss arrays produced by every component. {@link #update} and
 * {@link #reset} are forwarded to every component. {@link #getValue} reports the sum of the
 * components' values.
 */
public class CompositeLoss extends Loss {
    private final Loss[] components;

    /**
     * Creates a composite of the given losses.
     *
     * @param lossArr the component losses; the array is copied, so later changes to it by the
     *     caller do not affect this instance
     * @throws NullPointerException if {@code lossArr} is null
     */
    public CompositeLoss(Loss... lossArr) {
        super("CompositeLoss");
        // Defensive copy: a caller-owned varargs array must not alias our internal state.
        this.components = Objects.requireNonNull(lossArr, "lossArr").clone();
    }

    /**
     * Computes each component's loss on the same inputs and returns their sum.
     *
     * @param nDList first input list, forwarded unchanged to every component
     * @param nDList2 second input list, forwarded unchanged to every component
     * @return the sum of all component losses
     */
    @Override
    public NDArray getLoss(NDList nDList, NDList nDList2) {
        NDArray[] losses =
                Arrays.stream(components)
                        .map(loss -> loss.getLoss(nDList, nDList2))
                        .toArray(NDArray[]::new);
        if (losses.length == 1) {
            // Nothing to add up; avoids relying on NDArrays.add accepting a single input.
            // NOTE(review): confirm NDArrays.add's minimum-input contract for the DJL version in use.
            return losses[0];
        }
        return NDArrays.add(losses);
    }

    /**
     * Returns a new {@code CompositeLoss} whose components are duplicates of this one's.
     *
     * @return an independent copy of this composite loss
     */
    @Override
    public Loss duplicate() {
        return new CompositeLoss(
                Arrays.stream(components).map(Loss::duplicate).toArray(Loss[]::new));
    }

    /**
     * Forwards the update to every component loss, in declaration order.
     *
     * @param nDList first input list, forwarded unchanged
     * @param nDList2 second input list, forwarded unchanged
     */
    @Override
    public void update(NDList nDList, NDList nDList2) {
        for (Loss loss : components) {
            loss.update(nDList, nDList2);
        }
    }

    /** Resets every component loss. */
    @Override
    public void reset() {
        for (Loss loss : components) {
            loss.reset();
        }
    }

    /**
     * Returns the sum of the components' current values.
     *
     * @return the summed component values, accumulated in {@code double} and narrowed to
     *     {@code float}
     */
    @Override
    public float getValue() {
        return (float) Arrays.stream(components).mapToDouble(Loss::getValue).sum();
    }
}
