package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.timeseries.distribution.Distribution;
import ai.djl.util.Preconditions;

/* loaded from: input_file:ai/djl/timeseries/distribution/NegativeBinomial.class */
/**
 * Negative binomial distribution parameterized by {@code total_count} ({@code r}) and {@code
 * logits}.
 *
 * <p>With {@code p = sigmoid(logits)}, {@link #logProb} evaluates
 *
 * <pre>
 * r * log(sigmoid(-logits)) + k * log(sigmoid(logits))
 *     + lgamma(r + k) - lgamma(k + 1) - lgamma(r)
 * </pre>
 *
 * <p>{@link #mean()} returns {@code r * exp(logits)}.
 */
public final class NegativeBinomial extends Distribution {

    private final NDArray totalCount;
    private final NDArray logits;

    /** Builder for {@link NegativeBinomial}; requires {@code total_count} and {@code logits}. */
    public static final class Builder extends Distribution.DistributionBuilder<Builder> {

        /**
         * Builds the distribution.
         *
         * <p>Negative binomial values are integers, so an affine transform cannot be applied to
         * the samples. When a scale is set, the parameters are scaled instead: {@code log(scale)}
         * is added to the logits, which multiplies {@link NegativeBinomial#mean()} by {@code
         * scale}.
         *
         * @return the negative binomial distribution
         * @throws IllegalArgumentException if {@code total_count} or {@code logits} is missing
         */
        @Override
        public Distribution build() {
            Preconditions.checkArgument(
                    this.distrArgs.contains("total_count"),
                    "NegativeBinomial's args must contain total_count.");
            Preconditions.checkArgument(
                    this.distrArgs.contains("logits"),
                    "NegativeBinomial's args must contain logits.");
            if (this.scale != null) {
                // NDArray.add is not in-place: keep the returned array, otherwise the scale is lost.
                NDArray scaledLogits = this.distrArgs.get("logits").add(this.scale.log());
                scaledLogits.setName("logits");
                this.distrArgs.remove("logits");
                this.distrArgs.add(scaledLogits);
            }
            return new NegativeBinomial(this);
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }
    }

    NegativeBinomial(Builder builder) {
        this.totalCount = builder.distrArgs.get("total_count");
        this.logits = builder.distrArgs.get("logits");
    }

    /**
     * Returns the log-probability of {@code target} under this distribution.
     *
     * @param target the observed counts
     * @return the element-wise log-probability
     */
    @Override
    public NDArray logProb(NDArray target) {
        // r * log(1 - p) + k * log(p), with 1 - p = sigmoid(-logits) and p = sigmoid(logits)
        NDArray logUnnormalizedProb =
                totalCount
                        .mul(logSigmoid(logits.mul(-1)))
                        .add(target.mul(logSigmoid(logits)));

        // -log C(k + r - 1, k) = -lgamma(r + k) + lgamma(k + 1) + lgamma(r)
        NDArray logNormalization =
                totalCount
                        .add(target)
                        .gammaln()
                        .mul(-1)
                        .add(target.add(1).gammaln())
                        .add(totalCount.gammaln());

        return logUnnormalizedProb.sub(logNormalization);
    }

    /**
     * Draws samples as a Gamma-Poisson mixture: Gamma draws with shape {@code total_count} and
     * second parameter {@code exp(logits)}, which are then used as Poisson rates.
     *
     * @param numSamples number of samples; when {@code > 0}, a leading sample axis of this size
     *     is added, otherwise the parameters are used with their own shape
     * @return the sampled counts
     */
    @Override
    public NDArray sample(int numSamples) {
        NDManager manager = totalCount.getManager();
        NDArray expandedTotalCount =
                numSamples > 0 ? totalCount.expandDims(0).repeat(0, numSamples) : totalCount;
        NDArray expandedLogits =
                numSamples > 0 ? logits.expandDims(0).repeat(0, numSamples) : logits;
        NDArray rates = manager.sampleGamma(expandedTotalCount, expandedLogits.exp());
        return manager.samplePoisson(rates);
    }

    /**
     * Returns the mean, {@code total_count * exp(logits)}.
     *
     * @return the mean of the distribution
     */
    @Override
    public NDArray mean() {
        return totalCount.mul(logits.exp());
    }

    /**
     * Numerically stable {@code log(sigmoid(x))}, computed as {@code min(x, 0) - log(1 +
     * exp(-|x|))}.
     *
     * <p>The naive form {@code log(1 / (1 + exp(-x)))} overflows {@code exp(-x)} for large
     * negative {@code x} and returns {@code -inf}. Here the exponent is always non-positive.
     */
    private static NDArray logSigmoid(NDArray x) {
        return x.minimum(0).sub(x.abs().neg().exp().add(1).log());
    }

    /**
     * Creates a builder for {@link NegativeBinomial}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
