package org.wso2.extension.siddhi.execution.streamingml.bayesian.model;

import io.siddhi.core.exception.SiddhiAppCreationException;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

/* loaded from: input_file:org/wso2/extension/siddhi/execution/streamingml/bayesian/model/BernoulliDistribution.class */
/**
 * Bernoulli distribution expressed as a SameDiff graph.
 *
 * <p>The success probability is obtained by applying a sigmoid to the variable
 * handed to the constructor. Only {@link #logProbability(SDVariable)} is
 * implemented; both {@code sample} overloads return {@code null} and
 * {@link #klDivergence(Distribution)} always throws.
 */
public class BernoulliDistribution extends Distribution {
    private static final long serialVersionUID = 5823219877164613491L;

    /** Success probability: the sigmoid of the variable passed to the constructor. */
    private final SDVariable prob;

    /**
     * Builds the distribution from an unconstrained variable.
     *
     * @param sDVariable variable that is passed through a sigmoid to obtain the success probability
     * @param sameDiff   graph used to create the sigmoid op; also stored in the inherited {@code sd} field
     */
    public BernoulliDistribution(SDVariable sDVariable, SameDiff sameDiff) {
        this.prob = sameDiff.sigmoid(sDVariable);
        this.sd = sameDiff;
    }

    /**
     * Builds the log-probability of the given observations under this distribution.
     *
     * <p>A rank-1 input is first reshaped to {@code (-1, 1)}. The result is
     * {@code log(p * x + (p - 1) * (x - 1))}, which is algebraically
     * {@code log(p * x + (1 - p) * (1 - x))}: it evaluates to {@code log(p)} for
     * {@code x = 1} and to {@code log(1 - p)} for {@code x = 0}.
     *
     * @param sDVariable observed values
     * @return variable holding the element-wise log-probability
     */
    @Override
    public SDVariable logProbability(SDVariable sDVariable) {
        SDVariable observed = sDVariable;
        if (observed.getShape().length == 1) {
            // Turn a flat vector into a single column before combining it with prob.
            observed = this.sd.reshape(observed, -1, 1);
        }
        // p * x: contributes when x == 1.
        SDVariable successTerm = this.prob.mul(observed);
        // (p - 1) * (x - 1) == (1 - p) * (1 - x): contributes when x == 0.
        SDVariable failureTerm = this.prob.sub(1.0d).mul(observed.sub(1.0d));
        return this.sd.log(successTerm.add(failureTerm));
    }

    /**
     * Sampling is not implemented for this distribution.
     *
     * @return always {@code null}; callers must not dereference the result
     */
    @Override
    public SDVariable sample() {
        return null;
    }

    /**
     * Sampling is not implemented for this distribution.
     *
     * @param i ignored
     * @return always {@code null}; callers must not dereference the result
     */
    @Override
    public SDVariable sample(int i) {
        return null;
    }

    /**
     * KL-divergence is not supported for this distribution.
     *
     * @param distribution ignored
     * @return never returns normally
     * @throws SiddhiAppCreationException always
     */
    @Override
    public SDVariable klDivergence(Distribution distribution) throws SiddhiAppCreationException {
        // The message previously named the categorical distribution, which was misleading here.
        throw new SiddhiAppCreationException("kl-divergence is not implemented for bernoulli distribution");
    }

    /**
     * Returns the success probability variable.
     *
     * @return the sigmoid output created in the constructor
     */
    public SDVariable getProb() {
        return this.prob;
    }
}
