package ai.djl.timeseries.distribution.output;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.Distribution;
import ai.djl.timeseries.distribution.StudentT;
import ai.djl.util.PairList;

/* loaded from: input_file:ai/djl/timeseries/distribution/output/StudentTOutput.class */
/**
 * A {@link DistributionOutput} whose distribution builder is {@link StudentT}.
 *
 * <p>Three arguments are registered, each with dimension 1: {@code mu}, {@code sigma} and
 * {@code nu}.
 */
public class StudentTOutput extends DistributionOutput {

    /** Creates the output and registers the {@code mu}, {@code sigma} and {@code nu} arguments. */
    public StudentTOutput() {
        this.argsDim = new PairList<>(3);
        this.argsDim.add("mu", 1);
        this.argsDim.add("sigma", 1);
        this.argsDim.add("nu", 1);
    }

    /**
     * Maps the raw arguments into the parameters handed to the distribution.
     *
     * <p>The first three entries of the list are read by position as {@code mu}, {@code sigma}
     * and {@code nu}. {@code sigma} is passed through softPlus, {@code nu} is passed through
     * softPlus and then shifted by 2, and every result has its last axis squeezed away.
     *
     * @param nDList the raw arguments; indices 0, 1 and 2 are accessed
     * @return a new list holding the arrays named {@code "mu"}, {@code "sigma"} and
     *     {@code "nu"}, in that order
     */
    @Override
    public NDList domainMap(NDList nDList) {
        NDArray rawMu = nDList.get(0);
        NDArray rawSigma = nDList.get(1);
        NDArray rawNu = nDList.get(2);

        // Drop the trailing axis of every argument.
        NDArray mu = rawMu.squeeze(-1);

        NDArray sigma = rawSigma.getNDArrayInternal().softPlus();
        sigma = sigma.squeeze(-1);

        // NOTE(review): the +2 offset presumably keeps nu above 2 so that the
        // distribution's variance is finite - confirm against StudentT.
        NDArray nu = rawNu.getNDArrayInternal().softPlus();
        nu = nu.add(2.0d);
        nu = nu.squeeze(-1);

        mu.setName("mu");
        sigma.setName("sigma");
        nu.setName("nu");
        return new NDList(mu, sigma, nu);
    }

    /**
     * Returns the builder used to create the distribution for this output.
     *
     * @return a fresh {@link StudentT} builder
     */
    @Override
    public Distribution.DistributionBuilder<?> distributionBuilder() {
        return StudentT.builder();
    }
}
