package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Iterator;

/**
 * A multi-head scaled dot-product attention block.
 *
 * <p>Keys, queries and values each go through their own {@link Linear} projection (with bias) to
 * {@code embeddingSize} units. They are then split into {@code headCount} heads of size {@code
 * embeddingSize / headCount} and combined as {@code softmax(Q K^T / sqrt(headSize) + maskOffset)
 * V}. Dropout is applied to the attention probabilities. The re-merged heads are passed through a
 * final {@link Linear} result projection.
 *
 * <p>The block accepts 1 to 4 inputs:
 *
 * <ul>
 *   <li>1: a single array used as key, query and value (self-attention);
 *   <li>2: as above, plus a mask;
 *   <li>3: key, query and value, in that order;
 *   <li>4: key, query, value and a mask.
 * </ul>
 *
 * <p>For every input, dimension 0 is read as the batch size and dimension 1 as the sequence
 * length. The output has the query's sequence length.
 */
public final class ScaledDotProductAttentionBlock extends AbstractBlock {

    /** Block version passed to the {@link AbstractBlock} constructor. */
    private static final byte VERSION = 1;

    /** Offset added to the attention scores of masked-out positions before the softmax. */
    private static final float MASK_OFFSET = -100000.0f;

    private final int embeddingSize;
    private final int headCount;
    private final Linear keyProjection;
    private final Linear queryProjection;
    private final Linear valueProjection;
    private final Linear resultProjection;
    private final Dropout attentionProbsDropout;

    /**
     * Builder for {@link ScaledDotProductAttentionBlock}.
     *
     * <p>The embedding size and head count are required. The attention dropout probability
     * defaults to {@code 0.1}.
     */
    public static final class Builder {
        private int embeddingSize;
        private int headCount;
        private float attentionProbsDropoutProb = 0.1f;

        private Builder() {}

        /**
         * Sets the embedding size, which is the output size of every projection.
         *
         * @param embeddingSize the embedding size; must be at least 1 and divisible by the head
         *     count
         * @return this builder
         */
        public Builder setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets the number of attention heads.
         *
         * @param headCount the number of heads; must be at least 1
         * @return this builder
         */
        public Builder setHeadCount(int headCount) {
            this.headCount = headCount;
            return this;
        }

        /**
         * Sets the dropout rate applied to the attention probabilities.
         *
         * @param attentionProbsDropoutProb the dropout rate (default {@code 0.1})
         * @return this builder
         */
        public Builder optAttentionProbsDropoutProb(float attentionProbsDropoutProb) {
            this.attentionProbsDropoutProb = attentionProbsDropoutProb;
            return this;
        }

        /**
         * Builds the attention block.
         *
         * @return a new {@link ScaledDotProductAttentionBlock}
         * @throws IllegalStateException if the embedding size or head count is less than 1, or if
         *     the embedding size is not divisible by the head count
         */
        public ScaledDotProductAttentionBlock build() {
            if (embeddingSize < 1) {
                throw new IllegalStateException("Embedding size not initialized.");
            }
            if (headCount < 1) {
                throw new IllegalStateException("Head count not initialized.");
            }
            if (embeddingSize % headCount != 0) {
                throw new IllegalStateException(
                        "Embedding Size ("
                                + embeddingSize
                                + ") is not divisible by head count ("
                                + headCount
                                + ")");
            }
            return new ScaledDotProductAttentionBlock(this);
        }
    }

    private ScaledDotProductAttentionBlock(Builder builder) {
        super(VERSION);
        this.embeddingSize = builder.embeddingSize;
        this.headCount = builder.headCount;
        // buildProjection() reads embeddingSize, so it must be assigned first.
        this.keyProjection = (Linear) addChildBlock("keyProjection", buildProjection());
        this.queryProjection = (Linear) addChildBlock("queryProjection", buildProjection());
        this.valueProjection = (Linear) addChildBlock("valueProjection", buildProjection());
        this.resultProjection = (Linear) addChildBlock("resultProjection", buildProjection());
        this.attentionProbsDropout =
                (Dropout)
                        addChildBlock(
                                "probabilityDropout",
                                Dropout.builder()
                                        .optRate(builder.attentionProbsDropoutProb)
                                        .build());
    }

    /** Creates a biased linear projection with {@code embeddingSize} output units. */
    private Linear buildProjection() {
        return Linear.builder().setUnits(embeddingSize).optBias(true).build();
    }

    /**
     * Returns the projection applied to the key input.
     *
     * @return the key projection
     */
    public Linear getKeyProjection() {
        return keyProjection;
    }

    /**
     * Returns the projection applied to the query input.
     *
     * @return the query projection
     */
    public Linear getQueryProjection() {
        return queryProjection;
    }

    /**
     * Returns the projection applied to the value input.
     *
     * @return the value projection
     */
    public Linear getValueProjection() {
        return valueProjection;
    }

    /**
     * Returns the projection applied to the merged attention heads.
     *
     * @return the result projection
     */
    public Linear getResultProjection() {
        return resultProjection;
    }

    /**
     * Returns the output shape.
     *
     * <p>The output shape equals the first input's shape for 1–2 inputs (self-attention). For
     * 3–4 inputs it equals the second (query) input's shape.
     *
     * @param inputShapes the shapes of 1 to 4 inputs
     * @return a single-element array holding the output shape
     * @throws IllegalArgumentException if there are not between 1 and 4 input shapes
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        if (inputShapes.length == 1 || inputShapes.length == 2) {
            return new Shape[] {inputShapes[0]};
        }
        if (inputShapes.length == 3 || inputShapes.length == 4) {
            return new Shape[] {inputShapes[1]};
        }
        throw new IllegalArgumentException(
                "Invalid number of input shapes: " + inputShapes.length + ", must be 1-4.");
    }

    /**
     * Initializes every child block with the input shape {@code (-1, embeddingSize)}.
     *
     * @param manager the manager used for parameter allocation
     * @param dataType ignored; children are always initialized as {@link DataType#FLOAT32}
     * @param inputShapes ignored
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        // NOTE(review): dataType is ignored on purpose. The mask offset built in forwardInternal
        // is also FLOAT32. Change both together if non-FLOAT32 parameters are ever needed.
        Shape projectionShape = new Shape(-1, embeddingSize);
        for (Block child : children.values()) {
            child.initialize(manager, DataType.FLOAT32, projectionShape);
        }
    }

    /**
     * Splits projected embeddings into attention heads.
     *
     * @param embeddings projected embeddings holding {@code batchSize * sequenceLength * heads *
     *     headEmbedding} elements
     * @param batchSize the batch size
     * @param sequenceLength the sequence length
     * @param heads the number of attention heads
     * @param headEmbedding the embedding size per head
     * @return the embeddings reshaped to {@code (batchSize, sequenceLength, heads, headEmbedding)}
     *     and transposed to {@code (batchSize, heads, sequenceLength, headEmbedding)}
     */
    private static NDArray createAttentionHeadsFromEmbeddings(
            NDArray embeddings, long batchSize, long sequenceLength, long heads, long headEmbedding) {
        return embeddings
                .reshape(batchSize, sequenceLength, heads, headEmbedding)
                .transpose(0, 2, 1, 3);
    }

    /**
     * Converts a mask into an additive offset for the attention scores.
     *
     * <p>A mask with 4 dimensions is returned unchanged, so a caller can reuse an offset that was
     * computed earlier. Otherwise the mask is reshaped to {@code (batchSize, 1, toSequenceLength,
     * fromSequenceLength)}, converted to FLOAT32, and each entry {@code m} is mapped to {@code (1 -
     * m) * MASK_OFFSET}. Entries equal to 1 therefore contribute 0, and entries equal to 0 contribute
     * {@link #MASK_OFFSET}.
     */
    private static NDArray toMaskOffset(
            NDArray mask, long batchSize, long toSequenceLength, long fromSequenceLength) {
        if (mask.getShape().dimension() == 4) {
            return mask;
        }
        NDArray expandedMask = mask.reshape(batchSize, 1, toSequenceLength, fromSequenceLength);
        NDManager manager = expandedMask.getManager();
        return expandedMask
                .toType(DataType.FLOAT32, false)
                .mul(manager.create(-1.0f)) // m -> -m
                .add(manager.create(1.0f)) // -m -> 1 - m
                .mul(manager.create(MASK_OFFSET)); // 1 - m -> (1 - m) * MASK_OFFSET
    }

    /**
     * Runs multi-head scaled dot-product attention.
     *
     * @param parameterStore the parameter store
     * @param inputs 1 to 4 arrays, as described in the class documentation
     * @param training whether dropout and the child blocks should run in training mode
     * @param params extra parameters forwarded to the key/query/value projections
     * @return a single array of shape {@code (batch, querySequenceLength, embeddingSize)}
     * @throws IllegalArgumentException if there are not between 1 and 4 inputs
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        int inputCount = inputs.size();
        // Same contract as getOutputShapes. Without this check, extra inputs would be dropped.
        if (inputCount < 1 || inputCount > 4) {
            throw new IllegalArgumentException(
                    "Invalid number of inputs: " + inputCount + ", must be 1-4.");
        }
        long embedding = embeddingSize;
        long batchSize = inputs.head().getShape().get(0);
        long heads = headCount;
        long headEmbedding = embedding / heads;

        NDList keyInputs;
        NDList queryInputs;
        NDList valueInputs;
        // Sequence length of keys/values ("from") and of queries ("to").
        long fromSequenceLength;
        long toSequenceLength;
        if (inputCount < 3) {
            keyInputs = new NDList(inputs.head());
            queryInputs = keyInputs;
            valueInputs = keyInputs;
            fromSequenceLength = inputs.head().getShape().get(1);
            toSequenceLength = fromSequenceLength;
        } else {
            keyInputs = new NDList(inputs.get(0));
            queryInputs = new NDList(inputs.get(1));
            valueInputs = new NDList(inputs.get(2));
            fromSequenceLength = inputs.get(0).getShape().get(1);
            toSequenceLength = inputs.get(1).getShape().get(1);
        }
        // With 2 or 4 inputs, the last one is the mask.
        NDArray mask = (inputCount == 2 || inputCount == 4) ? inputs.get(inputCount - 1) : null;

        NDArray keys = keyProjection.forward(parameterStore, keyInputs, training, params).head();
        NDArray queries =
                queryProjection.forward(parameterStore, queryInputs, training, params).head();
        NDArray values =
                valueProjection.forward(parameterStore, valueInputs, training, params).head();

        NDArray keyHeads =
                createAttentionHeadsFromEmbeddings(
                        keys, batchSize, fromSequenceLength, heads, headEmbedding);
        NDArray queryHeads =
                createAttentionHeadsFromEmbeddings(
                        queries, batchSize, toSequenceLength, heads, headEmbedding);
        NDArray valueHeads =
                createAttentionHeadsFromEmbeddings(
                        values, batchSize, fromSequenceLength, heads, headEmbedding);

        // Q * K^T over the last two axes, scaled by 1 / sqrt(headEmbedding).
        NDArray scores = queryHeads.matMul(keyHeads.transpose(0, 1, 3, 2));
        scores = scores.mul(scores.getManager().create(1.0f / (float) Math.sqrt(headEmbedding)));
        if (mask != null) {
            scores = scores.add(toMaskOffset(mask, batchSize, toSequenceLength, fromSequenceLength));
        }

        NDArray attentionProbs = scores.softmax(3);
        NDArray droppedProbs =
                attentionProbsDropout
                        .forward(parameterStore, new NDList(attentionProbs), training)
                        .singletonOrThrow();

        // Weighted sum of value heads, then (B, N, T, E/N) -> (B, T, N, E/N) -> (B, T, E).
        NDArray mergedHeads =
                droppedProbs
                        .matMul(valueHeads)
                        .transpose(0, 2, 1, 3)
                        .reshape(batchSize, toSequenceLength, embedding);
        return new NDList(
                resultProjection.forward(parameterStore, new NDList(mergedHeads), training));
    }

    /**
     * Creates a new builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
