public class NormalDistribution extends Object
Sampling uses the reparameterization trick [1] to reduce the variance of the gradient estimates. Under the reparameterization trick, a sample z ~ q(z) is computed as z = loc + scale * epsilon, where epsilon ~ N(0, 1).
[1] https://arxiv.org/abs/1312.6114
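The reparameterization trick described above can be sketched on plain scalars. This is a minimal illustration, assuming a hypothetical class `ReparameterizedNormal` that works on `double` values instead of SameDiff `SDVariable`s; it is not the WSO2 implementation. All randomness comes from a parameter-free noise source `epsilon ~ N(0, 1)` (`java.util.Random.nextGaussian()`), and the sample is a deterministic function of `loc` and `scale`, which is what lets gradients flow through it.

```java
import java.util.Random;

/**
 * Hypothetical illustration (not the WSO2 class): a scalar normal distribution
 * N(loc, scale^2) whose samples are drawn with the reparameterization trick.
 */
final class ReparameterizedNormal {
    private final double loc;
    private final double scale;

    ReparameterizedNormal(double loc, double scale) {
        if (!(scale > 0.0)) {
            throw new IllegalArgumentException("scale must be positive");
        }
        this.loc = loc;
        this.scale = scale;
    }

    /** Deterministic part of the trick: z = loc + scale * epsilon, so dz/dloc = 1 and dz/dscale = epsilon. */
    double transform(double epsilon) {
        return loc + scale * epsilon;
    }

    /** One sample; the only randomness is the parameter-free noise epsilon ~ N(0, 1). */
    double sample(Random rng) {
        return transform(rng.nextGaussian());
    }

    /** n independent samples (a scalar analogue of sample(int n)). */
    double[] sample(int n, Random rng) {
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = sample(rng);
        }
        return out;
    }

    public static void main(String[] args) {
        ReparameterizedNormal q = new ReparameterizedNormal(2.0, 0.5);
        double[] z = q.sample(100_000, new Random(42));
        double mean = 0.0;
        for (double v : z) mean += v;
        mean /= z.length;
        double var = 0.0;
        for (double v : z) var += (v - mean) * (v - mean);
        var /= z.length;
        // The empirical mean and standard deviation should be close to loc = 2.0 and scale = 0.5.
        System.out.printf("mean=%.3f std=%.3f%n", mean, Math.sqrt(var));
    }
}
```

Splitting `transform` from `sample` mirrors the idea of the trick: for a fixed `epsilon`, the sample is linear in `loc` and `scale`, so an autodiff framework can differentiate through it.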
Constructor | Description |
---|---|
NormalDistribution(org.nd4j.autodiff.samediff.SDVariable loc, org.nd4j.autodiff.samediff.SDVariable scale, org.nd4j.autodiff.samediff.SameDiff sd) | Constructs a normal distribution. |
Modifier and Type | Method | Description |
---|---|---|
org.nd4j.autodiff.samediff.SDVariable | getLoc() | |
org.nd4j.autodiff.samediff.SDVariable | getScale() | |
org.nd4j.autodiff.samediff.SDVariable | klDivergence(org.wso2.extension.siddhi.execution.streamingml.bayesian.model.Distribution distribution) | Kullback-Leibler divergence between two normal densities. |
org.nd4j.autodiff.samediff.SDVariable | logProbability(org.nd4j.autodiff.samediff.SDVariable value) | Computes the log-probability of the normal distribution. |
org.nd4j.autodiff.samediff.SDVariable | sample() | Returns a random sample from the distribution. |
org.nd4j.autodiff.samediff.SDVariable | sample(int n) | Returns random samples from the distribution. |
public NormalDistribution(org.nd4j.autodiff.samediff.SDVariable loc, org.nd4j.autodiff.samediff.SDVariable scale, org.nd4j.autodiff.samediff.SameDiff sd)
Constructs a normal distribution.
Parameters:
loc - location (mean). Expected to be a 2-dimensional variable of shape (input_size, output_size).
scale - standard deviation.
sd - SameDiff context.
public org.nd4j.autodiff.samediff.SDVariable logProbability(org.nd4j.autodiff.samediff.SDVariable value)
Computes the log-probability of the normal distribution:
log(p(x)) = -log(sqrt(2*PI)*scale) - (x-loc)^2/(2*scale^2)
Parameters:
value - x values.
public org.nd4j.autodiff.samediff.SDVariable sample()
Returns a random sample from the distribution.
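The Gaussian log-density log(p(x)) = -log(sqrt(2*PI)*scale) - (x-loc)^2/(2*scale^2) can be checked numerically. The sketch below assumes a hypothetical helper class `GaussianLogProbability` operating on `double` scalars rather than `SDVariable`s, so it only illustrates the formula, not how the library evaluates it. It also integrates exp(log p(x)) over loc ± 10 * scale with the midpoint rule; a result close to 1 confirms the density is correctly normalized.

```java
/** Hypothetical helper (not the library API): Gaussian log-density on plain doubles. */
final class GaussianLogProbability {

    /** log p(x) = -log(sqrt(2*PI) * scale) - (x - loc)^2 / (2 * scale^2). */
    static double logProbability(double x, double loc, double scale) {
        double standardized = (x - loc) / scale;
        return -Math.log(Math.sqrt(2.0 * Math.PI) * scale) - 0.5 * standardized * standardized;
    }

    /** Midpoint-rule integral of exp(log p(x)) over loc +/- 10 * scale; should be close to 1. */
    static double totalMass(double loc, double scale) {
        int steps = 200_000;
        double lo = loc - 10.0 * scale;
        double hi = loc + 10.0 * scale;
        double h = (hi - lo) / steps;
        double sum = 0.0;
        for (int i = 0; i < steps; i++) {
            double x = lo + (i + 0.5) * h;
            sum += Math.exp(logProbability(x, loc, scale)) * h;
        }
        return sum;
    }

    public static void main(String[] args) {
        // At x = loc with scale = 1 the value is approximately -0.5 * log(2 * PI).
        System.out.println(logProbability(0.0, 0.0, 1.0));
        System.out.println(totalMass(1.0, 2.0));
    }
}
```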
public org.nd4j.autodiff.samediff.SDVariable sample(int n)
Returns random samples from the distribution.
Parameters:
n - number of samples.
public org.nd4j.autodiff.samediff.SDVariable klDivergence(org.wso2.extension.siddhi.execution.streamingml.bayesian.model.Distribution distribution) throws io.siddhi.core.exception.SiddhiAppCreationException
Computes the Kullback-Leibler divergence between two normal densities:
KL(q(loc_1, scale_1) || p(loc_2, scale_2)) = log(scale_2) - log(scale_1) + (scale_1^2 + (loc_1-loc_2)^2)/(2*scale_2^2) - 0.5
Parameters:
distribution - reference distribution p(x).
Throws:
io.siddhi.core.exception.SiddhiAppCreationException - if distribution is not Gaussian.
public org.nd4j.autodiff.samediff.SDVariable getLoc()
public org.nd4j.autodiff.samediff.SDVariable getScale()
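The closed-form Kullback-Leibler divergence between two normal densities, KL(q || p) = log(scale_2) - log(scale_1) + (scale_1^2 + (loc_1-loc_2)^2)/(2*scale_2^2) - 0.5, can be sanity-checked against a Monte Carlo estimate of E_q[log q(z) - log p(z)]. This is a sketch under assumptions: `GaussianKl` is a hypothetical scalar helper, not the library's `klDivergence`, and it draws z from q with the same reparameterization z = loc_1 + scale_1 * epsilon. Two identical distributions give exactly 0, which is a quick way to confirm the constant term is -0.5.

```java
import java.util.Random;

/** Hypothetical scalar helper (not the library API) for KL(q || p) between two normals. */
final class GaussianKl {

    /** Closed form: log(scale2) - log(scale1) + (scale1^2 + (loc1 - loc2)^2) / (2 * scale2^2) - 0.5. */
    static double klDivergence(double loc1, double scale1, double loc2, double scale2) {
        double diff = loc1 - loc2;
        return Math.log(scale2) - Math.log(scale1)
                + (scale1 * scale1 + diff * diff) / (2.0 * scale2 * scale2) - 0.5;
    }

    /** Gaussian log-density used by the Monte Carlo estimate. */
    static double logProbability(double x, double loc, double scale) {
        double standardized = (x - loc) / scale;
        return -Math.log(Math.sqrt(2.0 * Math.PI) * scale) - 0.5 * standardized * standardized;
    }

    /** Monte Carlo estimate of E_q[log q(z) - log p(z)], with z = loc1 + scale1 * epsilon. */
    static double monteCarloKl(double loc1, double scale1, double loc2, double scale2, int n, Random rng) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double z = loc1 + scale1 * rng.nextGaussian();
            sum += logProbability(z, loc1, scale1) - logProbability(z, loc2, scale2);
        }
        return sum / n;
    }

    public static void main(String[] args) {
        double exact = klDivergence(0.5, 0.8, 0.0, 1.2);
        double estimate = monteCarloKl(0.5, 0.8, 0.0, 1.2, 200_000, new Random(7));
        System.out.printf("closed form=%.4f monte carlo=%.4f%n", exact, estimate);
    }
}
```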
Copyright © 2019 WSO2. All rights reserved.