package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import java.util.Objects;

/* loaded from: input_file:ai/djl/timeseries/distribution/AffineTransformed.class */
/**
 * Distribution of {@code Y = loc + scale * X}, where {@code X} follows a base distribution.
 *
 * <p>Samples and the mean are obtained by pushing the base distribution's values through the
 * affine map {@code x -> x * scale + loc}. Log-probabilities are evaluated with the
 * change-of-variables formula {@code log p_Y(y) = log p_X((y - loc) / scale) - log|scale|}.
 *
 * <p>NOTE(review): {@code scale} is used as a divisor and inside {@code log|scale|}, so it
 * presumably must be non-zero; this is not checked here.
 */
public class AffineTransformed extends Distribution {
    private final Distribution baseDistribution;
    private final NDArray loc;
    private final NDArray scale;

    /**
     * Creates an affinely transformed distribution.
     *
     * @param distribution the base distribution {@code X}; must not be {@code null}
     * @param loc the additive shift; if {@code null}, {@code distribution.mean().zerosLike()} is
     *     used (i.e. no shift)
     * @param scale the multiplicative factor; if {@code null}, {@code
     *     distribution.mean().onesLike()} is used (i.e. no scaling)
     * @throws NullPointerException if {@code distribution} is {@code null}
     */
    public AffineTransformed(Distribution distribution, NDArray loc, NDArray scale) {
        // Fail fast: otherwise a null base is silently stored when loc and scale are both given.
        this.baseDistribution = Objects.requireNonNull(distribution, "distribution");
        this.loc = loc == null ? distribution.mean().zerosLike() : loc;
        this.scale = scale == null ? distribution.mean().onesLike() : scale;
    }

    /**
     * Returns the log-density of {@code target} under the transformed distribution.
     *
     * @param target values of {@code Y}
     * @return {@code baseDistribution.logProb((target - loc) / scale) - log|scale|}
     */
    @Override
    public NDArray logProb(NDArray target) {
        NDArray x = fInv(target);
        // Change of variables: subtract the log-abs-determinant of the Jacobian of f at x.
        return baseDistribution.logProb(x).sub(logAbsDetJac(x));
    }

    /**
     * Draws from the base distribution and maps the draws through {@code x * scale + loc}.
     *
     * @param numSamples forwarded unchanged to the base distribution's {@code sample}
     * @return the transformed samples
     */
    @Override
    public NDArray sample(int numSamples) {
        return f(baseDistribution.sample(numSamples));
    }

    /**
     * Returns {@code baseDistribution.mean() * scale + loc}.
     *
     * @return the mean of the transformed distribution
     */
    @Override
    public NDArray mean() {
        // Expectation is linear, so the mean of an affine map is the affine map of the mean.
        return f(baseDistribution.mean());
    }

    /** Forward affine map: {@code x * scale + loc}. */
    private NDArray f(NDArray x) {
        return x.mul(scale).add(loc);
    }

    /** Inverse affine map: {@code (y - loc) / scale}. */
    private NDArray fInv(NDArray y) {
        return y.sub(loc).div(scale);
    }

    /**
     * Returns {@code log|scale|} broadcast to the shape of {@code x}, so the Jacobian term lines
     * up element-wise with the base log-probability even when {@code scale} has a smaller
     * (broadcastable) shape.
     */
    private NDArray logAbsDetJac(NDArray x) {
        return scale.broadcast(x.getShape()).abs().log();
    }
}
