package ai.djl.nn.transformer;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Collections;

/* loaded from: input_file:ai/djl/nn/transformer/BertNextSentenceBlock.class */
public class BertNextSentenceBlock extends AbstractBlock {

    // Width of the classifier output. Shared by the Linear layer and getOutputShapes so the
    // declared output shape cannot drift from the layer's actual number of units.
    private static final int NUM_CLASSES = 2;

    // The only child block; assigned once in the constructor.
    private final Linear binaryClassifier;

    /**
     * Creates the block and registers its single child, a biased {@link Linear} layer named
     * {@code "binaryClassifier"} with {@value #NUM_CLASSES} output units.
     */
    public BertNextSentenceBlock() {
        binaryClassifier =
                (Linear)
                        addChildBlock(
                                "binaryClassifier",
                                Linear.builder().setUnits(NUM_CLASSES).optBias(true).build());
    }

    /**
     * Names the block's single input {@code "pooledOutput"} and initializes the classifier with the
     * given input shapes.
     *
     * @param manager the manager passed through to the child's initialization
     * @param dataType the data type passed through to the child's initialization
     * @param inputShapes the input shapes, forwarded unchanged to the classifier
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        this.inputNames = Collections.singletonList("pooledOutput");
        this.binaryClassifier.initialize(manager, dataType, inputShapes);
    }

    /**
     * Runs the linear classifier on the inputs and returns log-softmax scores over axis 1.
     *
     * <p>NOTE(review): this assumes the first input is 2-D {@code (batch, hidden)}, so that axis 1
     * is the class axis after the linear layer. Confirm against callers.
     *
     * @param parameterStore the parameter store passed to the classifier
     * @param inputs the inputs passed to the classifier
     * @param training the training flag passed to the classifier
     * @param params additional parameters (not used by this block)
     * @return a single-element list holding the log-softmax of the classifier's first output
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return new NDList(
                binaryClassifier.forward(parameterStore, inputs, training).head().logSoftmax(1));
    }

    /**
     * Returns the output shape {@code (inputShapes[0].get(0), NUM_CLASSES)}.
     *
     * @param inputShapes the input shapes; only the first dimension of the first shape is used
     * @return a one-element array with the output shape
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[] {new Shape(inputShapes[0].get(0), NUM_CLASSES)};
    }
}
