package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

/* loaded from: input_file:ai/djl/modality/cv/translator/BigGANTranslator.class */
/**
 * A {@link NoBatchifyTranslator} for BigGAN-style image generation.
 *
 * <p>The input is an array of class indices, one per image to generate. The output is one {@link
 * Image} per row of the model's first output array.
 */
public final class BigGANTranslator implements NoBatchifyTranslator<int[], Image[]> {

    /** Depth of the one-hot class vector built from each input index. */
    private static final int NUMBER_OF_CATEGORIES = 1000;

    /** Number of columns in the truncated-normal seed (latent) matrix, one row per input index. */
    private static final int SEED_COLUMN_SIZE = 128;

    /**
     * Multiplier applied to the model output after shifting it by +1. Kept separate from {@link
     * #SEED_COLUMN_SIZE}: the two values only happen to coincide.
     */
    private static final int PIXEL_SCALE = 128;

    /** Upper clip bound for output pixel values (lower bound is 0). */
    private static final int MAX_PIXEL_VALUE = 255;

    /** Scale applied to the seed matrix; also passed to the model as a scalar input. */
    private final float truncation;

    /**
     * Creates a translator.
     *
     * @param truncation multiplier for the truncated-normal seed; also forwarded to the model as
     *     its third input
     */
    public BigGANTranslator(float truncation) {
        this.truncation = truncation;
    }

    /**
     * Converts the model output into images.
     *
     * <p>The first array of {@code list} is copied, shifted by +1, scaled by {@value #PIXEL_SCALE}
     * and clipped to {@code [0, 255]}. NOTE(review): this presumes the generator emits values in
     * roughly {@code [-1, 1]}; confirm against the model.
     *
     * @param ctx the translator context (unused)
     * @param list model outputs; only element 0 is read, and its first dimension is the image count
     * @return one image per entry along the first dimension of the output array
     */
    @Override
    public Image[] processOutput(TranslatorContext ctx, NDList list) {
        // duplicate() so the in-place addi/muli do not mutate the model's own output array.
        NDArray pixels =
                list.get(0).duplicate().addi(1).muli(PIXEL_SCALE).clip(0, MAX_PIXEL_VALUE);

        int count = (int) pixels.getShape().get(0);
        ImageFactory factory = ImageFactory.getInstance();
        Image[] images = new Image[count];
        for (int i = 0; i < count; i++) {
            images[i] = factory.fromNDArray(pixels.get(i));
        }
        return images;
    }

    /**
     * Builds the model inputs for the requested class indices.
     *
     * @param ctx the translator context supplying the {@link NDManager}
     * @param input class indices, one per image to generate
     * @return an {@link NDList} of: a truncated-normal seed of shape {@code (input.length,
     *     SEED_COLUMN_SIZE)} scaled by {@code truncation}; the indices one-hot encoded to depth
     *     {@value #NUMBER_OF_CATEGORIES}; and {@code truncation} as a scalar array
     * @throws Exception as declared by the {@code PreProcessor} contract
     */
    @Override
    public NDList processInput(TranslatorContext ctx, int[] input) throws Exception {
        NDManager manager = ctx.getNDManager();

        NDArray seed =
                manager.truncatedNormal(new Shape(input.length, SEED_COLUMN_SIZE))
                        .muli(truncation);
        NDArray classes = manager.create(input).oneHot(NUMBER_OF_CATEGORIES);

        return new NDList(seed, classes, manager.create(truncation));
    }
}
