package ai.djl.modality.nlp.embedding;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:ai/djl/modality/nlp/embedding/TrainableTextEmbedding.class */
/**
 * A {@link TextEmbedding} that embeds text one word at a time by delegating every operation to a
 * {@link TrainableWordEmbedding}, which is registered as this block's child block.
 *
 * <p>Text is first converted to word indices with {@link #preprocessTextToEmbed(List)}. Those
 * indices are then passed through this block's {@code forward}, which forwards directly to the
 * word embedding. Direct embedding via {@link #embedText(NDArray)} is not supported.
 */
public class TrainableTextEmbedding extends AbstractBlock implements TextEmbedding {

    private static final byte VERSION = 1;

    private final TrainableWordEmbedding trainableWordEmbedding;

    /**
     * Creates a text embedding backed by the given word embedding.
     *
     * @param trainableWordEmbedding the word embedding applied to every word; it is added as the
     *     child block named {@code "trainableWordEmbedding"}
     */
    public TrainableTextEmbedding(TrainableWordEmbedding trainableWordEmbedding) {
        super(VERSION);
        this.trainableWordEmbedding =
                (TrainableWordEmbedding)
                        addChildBlock("trainableWordEmbedding", trainableWordEmbedding);
    }

    /**
     * Converts each word of the text into the index used by the underlying word embedding.
     *
     * @param tokens the words of the text, in order
     * @return one index per word, in the same order as {@code tokens}
     */
    @Override
    public long[] preprocessTextToEmbed(List<String> tokens) {
        long[] indices = new long[tokens.size()];
        int position = 0;
        // Iterate rather than call get(i) so non-random-access lists stay linear, and always
        // step by one: the loop must not depend on the serialization VERSION constant.
        for (String token : tokens) {
            indices[position++] = trainableWordEmbedding.preprocessWordToEmbed(token);
        }
        return indices;
    }

    /**
     * Not supported by this class; use {@link #preprocessTextToEmbed(List)} followed by this
     * block's {@code forward} instead.
     *
     * @param nDArray ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray embedText(NDArray nDArray) throws EmbeddingException {
        throw new UnsupportedOperationException("EmbedText operation is not supported by this class.");
    }

    /**
     * Converts embedded text back into words, one word per entry along axis 0.
     *
     * <p>NOTE(review): presumably {@code textEmbedding} has shape (numWords, embeddingSize);
     * confirm against callers.
     *
     * @param textEmbedding the embedded text
     * @return the unembedded words, in row order; empty if the array has no rows
     */
    @Override
    public List<String> unembedText(NDArray textEmbedding) {
        long numWords = textEmbedding.getShape().get(0);
        if (numWords == 0) {
            // split(0) would request zero sections; there is simply nothing to unembed.
            return new ArrayList<>();
        }
        // Splitting into numWords sections yields slices of size 1 along axis 0.
        NDList rows = textEmbedding.split(numWords);
        List<String> words = new ArrayList<>(rows.size());
        for (NDArray row : rows) {
            // get(0) drops the leading size-1 axis so the word embedding sees a single vector.
            words.add(trainableWordEmbedding.unembedWord(row.get(0)));
        }
        return words;
    }

    /** Forwards the input unchanged to the wrapped word embedding. */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList nDList,
            boolean z,
            PairList<String, Object> pairList) {
        return trainableWordEmbedding.forward(parameterStore, nDList, z, pairList);
    }

    /** Initializes the wrapped word embedding with the given manager, data type and shapes. */
    @Override
    public void initializeChildBlocks(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        trainableWordEmbedding.initialize(nDManager, dataType, shapeArr);
    }

    /** Returns the output shapes reported by the wrapped word embedding for the given inputs. */
    @Override
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        return trainableWordEmbedding.getOutputShapes(shapeArr);
    }
}
