package io.siddhi.extension.execution.streamingml.bayesian.model;

import io.siddhi.core.exception.SiddhiAppCreationException;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:io/siddhi/extension/execution/streamingml/bayesian/model/NormalDistribution.class */
/**
 * Normal (Gaussian) distribution whose location ({@code loc}) and scale ({@code scale}) are
 * SameDiff graph variables.
 *
 * <p>Every operation builds new nodes on the {@link SameDiff} graph supplied at construction
 * time (stored in the inherited {@code sd} field); nothing is evaluated eagerly here.
 */
public class NormalDistribution extends Distribution {
    private static final long serialVersionUID = 6666442004853927088L;

    /** sqrt(2 * pi), the constant factor in the normal density's normaliser. */
    private static final double SQRT_TWO_PI = Math.sqrt(2.0d * Math.PI);

    private final SDVariable loc;
    private final SDVariable scale;

    /**
     * Creates a normal distribution from existing graph variables.
     *
     * @param loc      location (mean) variable
     * @param scale    scale (standard deviation) variable
     * @param sameDiff graph on which all derived operations are built
     */
    public NormalDistribution(SDVariable loc, SDVariable scale, SameDiff sameDiff) {
        this.loc = loc;
        this.scale = scale;
        this.sd = sameDiff;
    }

    /**
     * Creates a normal distribution whose scale is registered as a new graph variable
     * initialised from {@code scale}.
     *
     * @param loc      location (mean) variable
     * @param scale    initial value for the scale (standard deviation) variable
     * @param sameDiff graph on which the scale variable and all derived operations are created
     */
    NormalDistribution(SDVariable loc, INDArray scale, SameDiff sameDiff) {
        this.loc = loc;
        // NOTE(review): the variable name "scale" is fixed; creating a second instance through
        // this constructor on the same SameDiff graph may clash with the first - confirm.
        this.scale = sameDiff.var("scale", scale);
        this.sd = sameDiff;
    }

    /**
     * Builds the log-density of this distribution evaluated at {@code value}:
     * {@code -log(scale * sqrt(2*pi)) - 0.5 * ((value - loc) / scale)^2}.
     *
     * @param value point(s) at which to evaluate the log-density
     * @return graph variable holding the log-density
     */
    @Override // io.siddhi.extension.execution.streamingml.bayesian.model.Distribution
    public SDVariable logProbability(SDVariable value) {
        SDVariable negLogNormaliser = this.sd.neg(this.sd.log(this.scale.mul(SQRT_TWO_PI)));
        SDVariable standardised = value.sub(this.loc).div(this.scale);
        // The exponent of the Gaussian density is (x - mu)^2 / (2 sigma^2) = 0.5 * z^2.
        return negLogNormaliser.sub(this.sd.square(standardised).mul(0.5d));
    }

    /**
     * Draws one sample as {@code loc + scale * eps}, where {@code eps} is standard-normal noise
     * with the same shape as {@code scale}.
     *
     * @return graph variable holding the sample
     */
    @Override // io.siddhi.extension.execution.streamingml.bayesian.model.Distribution
    public SDVariable sample() {
        SDVariable noise = this.sd.randomNormal(0.0d, 1.0d, this.scale.getShape());
        return this.loc.add(this.scale.mul(noise));
    }

    /**
     * Draws {@code i} samples stacked along a new trailing axis.
     *
     * <p>{@code loc} and {@code scale} are reshaped to {@code [d0, d1, 1]} and combined with
     * standard-normal noise of shape {@code [d0, d1, i]}, where {@code d0, d1} are the first two
     * dimensions of {@code scale}. NOTE(review): this assumes {@code loc} and {@code scale} are
     * rank-2 with known shapes - confirm against callers.
     *
     * @param i number of samples; must be at least 1
     * @return graph variable holding the samples
     * @throws IllegalArgumentException if {@code i < 1}
     */
    @Override // io.siddhi.extension.execution.streamingml.bayesian.model.Distribution
    public SDVariable sample(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("number of samples must be at least 1, got " + i);
        }
        long[] locShape = this.loc.getShape();
        long[] scaleShape = this.scale.getShape();
        SDVariable loc3d = this.sd.reshape(this.loc, new long[]{locShape[0], locShape[1], 1});
        SDVariable scale3d = this.sd.reshape(this.scale, new long[]{scaleShape[0], scaleShape[1], 1});
        SDVariable noise = this.sd.randomNormal(0.0d, 1.0d, new long[]{scaleShape[0], scaleShape[1], i});
        return loc3d.add(scale3d.mul(noise));
    }

    /**
     * Builds KL(this || other) for two normal distributions:
     * {@code log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2}, where {@code (m1, s1)}
     * belong to this distribution and {@code (m2, s2)} to {@code distribution}.
     *
     * @param distribution the reference distribution; must be a {@code NormalDistribution}
     * @return graph variable holding the KL divergence
     * @throws SiddhiAppCreationException if {@code distribution} is not a {@code NormalDistribution}
     */
    @Override // io.siddhi.extension.execution.streamingml.bayesian.model.Distribution
    public SDVariable klDivergence(Distribution distribution) throws SiddhiAppCreationException {
        if (!(distribution instanceof NormalDistribution)) {
            throw new SiddhiAppCreationException("kl-divergence with normal and other distributions are not supported");
        }
        NormalDistribution other = (NormalDistribution) distribution;
        SDVariable otherLoc = other.loc;
        SDVariable otherScale = other.scale;

        SDVariable logScaleRatio = this.sd.log(otherScale.div(this.scale));
        SDVariable numerator = this.sd.square(this.scale).add(this.sd.square(this.loc.sub(otherLoc)));
        SDVariable denominator = this.sd.square(otherScale).mul(2.0d);
        return logScaleRatio.add(numerator.div(denominator)).sub(0.5d);
    }

    /** @return the location (mean) variable */
    public SDVariable getLoc() {
        return this.loc;
    }

    /** @return the scale (standard deviation) variable */
    public SDVariable getScale() {
        return this.scale;
    }
}
