package ai.djl.nn.transformer;

import ai.djl.ndarray.NDList;
import ai.djl.training.loss.AbstractCompositeLoss;
import ai.djl.util.Pair;
import java.util.Arrays;

/**
 * Loss used for BERT pretraining.
 *
 * <p>Combines a {@link BertNextSentenceLoss} and a {@link BertMaskedLanguageModelLoss} by
 * registering both as the components of this {@link AbstractCompositeLoss}. Both components
 * receive the same, unmodified label and prediction lists (see {@link #inputForComponent}).
 */
public class BertPretrainingLoss extends AbstractCompositeLoss {

    // Assigned once in the constructor and never replaced, hence final.
    private final BertNextSentenceLoss bertNextSentenceLoss;
    private final BertMaskedLanguageModelLoss bertMaskedLanguageModelLoss;

    /**
     * Creates a loss combining the next-sentence loss and the masked-language-model loss.
     *
     * <p>The loss is named after this class's simple name.
     */
    public BertPretrainingLoss() {
        super(BertPretrainingLoss.class.getSimpleName());
        // NOTE(review): the integer arguments look like positions within the shared label and
        // prediction lists that each component reads from; confirm against the component
        // constructors before changing them.
        this.bertNextSentenceLoss = new BertNextSentenceLoss(0, 0);
        this.bertMaskedLanguageModelLoss = new BertMaskedLanguageModelLoss(1, 2, 1);
        // Order matters: component 0 is the next-sentence loss, component 1 the MLM loss.
        this.components = Arrays.asList(this.bertNextSentenceLoss, this.bertMaskedLanguageModelLoss);
    }

    /**
     * Returns the inputs handed to one component loss.
     *
     * <p>Every component receives the complete, unmodified label and prediction lists. No
     * per-component slicing happens here, so each component presumably selects its own entries
     * using the indices passed to its constructor.
     *
     * @param componentIndex index of the component within {@code components} (ignored, because
     *     all components receive identical inputs)
     * @param labels the labels passed to this composite loss
     * @param predictions the predictions passed to this composite loss
     * @return a pair of the given {@code labels} and {@code predictions}, unchanged
     */
    @Override // ai.djl.training.loss.AbstractCompositeLoss
    protected Pair<NDList, NDList> inputForComponent(
            int componentIndex, NDList labels, NDList predictions) {
        return new Pair<>(labels, predictions);
    }

    /**
     * Returns the next-sentence component of this loss.
     *
     * @return the {@link BertNextSentenceLoss} instance created by the constructor
     */
    public BertNextSentenceLoss getBertNextSentenceLoss() {
        return this.bertNextSentenceLoss;
    }

    /**
     * Returns the masked-language-model component of this loss.
     *
     * @return the {@link BertMaskedLanguageModelLoss} instance created by the constructor
     */
    public BertMaskedLanguageModelLoss getBertMaskedLanguageModelLoss() {
        return this.bertMaskedLanguageModelLoss;
    }
}
