package ai.djl.timeseries.distribution.output;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.Distribution;
import ai.djl.timeseries.distribution.NegativeBinomial;
import ai.djl.util.PairList;

/* loaded from: input_file:ai/djl/timeseries/distribution/output/NegativeBinomialOutput.class */
/**
 * A {@link DistributionOutput} whose distribution builder is {@link NegativeBinomial}.
 *
 * <p>Two distribution arguments are registered, {@code total_count} and {@code logits}, each with
 * a dimension of 1.
 */
public final class NegativeBinomialOutput extends DistributionOutput {

    /** Registers the {@code total_count} and {@code logits} arguments, each with dimension 1. */
    public NegativeBinomialOutput() {
        argsDim = new PairList<>(2);
        argsDim.add("total_count", 1);
        argsDim.add("logits", 1);
    }

    /**
     * Maps the raw arguments onto the form expected by the distribution.
     *
     * <p>The first array is passed through softplus (presumably to keep the count positive) and
     * the second is left untransformed; both then have their last axis squeezed away and are
     * named {@code total_count} and {@code logits} respectively.
     *
     * @param nDList the raw arguments; index 0 is read as the total count, index 1 as the logits
     * @return a two-element list holding the named total count followed by the named logits
     */
    @Override
    public NDList domainMap(NDList nDList) {
        // Read both inputs up front so a missing element fails before any computation happens.
        NDArray rawTotalCount = nDList.get(0);
        NDArray rawLogits = nDList.get(1);

        NDArray totalCount = rawTotalCount.getNDArrayInternal().softPlus().squeeze(-1);
        NDArray logits = rawLogits.squeeze(-1);

        totalCount.setName("total_count");
        logits.setName("logits");
        return new NDList(totalCount, logits);
    }

    /**
     * Supplies the builder used to construct the distribution for this output.
     *
     * @return a new {@link NegativeBinomial} builder
     */
    @Override
    public Distribution.DistributionBuilder<?> distributionBuilder() {
        return NegativeBinomial.builder();
    }
}
