public class CategoricalDistribution extends Object
Constructor and Description |
---|
CategoricalDistribution(org.nd4j.autodiff.samediff.SDVariable logits, org.nd4j.autodiff.samediff.SameDiff sd) Construct the categorical distribution. |
Modifier and Type | Method and Description |
---|---|
org.nd4j.autodiff.samediff.SDVariable | getProb() |
org.nd4j.autodiff.samediff.SDVariable | klDivergence(org.wso2.extension.siddhi.execution.streamingml.bayesian.model.Distribution distribution) Returns the KL divergence with respect to the given distribution. |
org.nd4j.autodiff.samediff.SDVariable | logProbability(org.nd4j.autodiff.samediff.SDVariable values) Computes the categorical log probability; the implementation is based on softmax cross-entropy. |
org.nd4j.autodiff.samediff.SDVariable | sample() Returns a random sample from the distribution. |
org.nd4j.autodiff.samediff.SDVariable | sample(int n) Returns random samples from the distribution. |
public CategoricalDistribution(org.nd4j.autodiff.samediff.SDVariable logits, org.nd4j.autodiff.samediff.SameDiff sd)
logits
- should be 2-dimensional. The dimensions should follow the order (input_size, num_classes).
sd
- SameDiff context.

public org.nd4j.autodiff.samediff.SDVariable logProbability(org.nd4j.autodiff.samediff.SDVariable values)
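$$\log p(y) = \sum_{i=1}^{C} y_i \log\big(\mathrm{softmax}(\mathrm{logits})_i\big)$$
where C is num_classes. However, log(softmax(logits)) can be infinite in some cases (for example, when a softmax entry underflows to 0). Hence, unnecessary log computations are avoided by using the transformed formula
$$\log p(y) = \log\Big(\sum_{i=1}^{C} y_i \, \mathrm{softmax}(\mathrm{logits})_i\Big)$$
The outputs of both formulas are equivalent if y is one-hot encoded: with y_k = 1 for the labelled class k and every other entry 0, each reduces to log(softmax(logits)_k).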
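To illustrate why the transformed formula is safer, here is a minimal plain-Java sketch for a single row of logits, using double arrays instead of ND4J SDVariables. The class CategoricalLogProbSketch and its methods are hypothetical, not part of the WSO2 or ND4J API, and the sketch assumes the transformed formula is a single log of the summed product, log(sum_i y_i * softmax(logits)_i).

```java
/**
 * Hypothetical sketch (not the WSO2/ND4J implementation) comparing the direct
 * and transformed categorical log-probability formulas for one row of logits.
 */
final class CategoricalLogProbSketch {

    private CategoricalLogProbSketch() {}

    /** Numerically stable softmax: subtract the max logit before exponentiating. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }

    /** Direct formula: sum_i y_i * log(softmax_i), one log per class. */
    static double logProbDirect(double[] logits, double[] y) {
        double[] p = softmax(logits);
        double total = 0.0;
        for (int i = 0; i < p.length; i++) {
            // If p[i] underflows to 0.0, log gives -Infinity and 0.0 * -Infinity is NaN.
            total += y[i] * Math.log(p[i]);
        }
        return total;
    }

    /** Transformed formula: log(sum_i y_i * softmax_i), a single log. */
    static double logProbTransformed(double[] logits, double[] y) {
        double[] p = softmax(logits);
        double dot = 0.0;
        for (int i = 0; i < p.length; i++) dot += y[i] * p[i];
        return Math.log(dot);
    }
}
```

For logits {0, -1000} with one-hot labels {1, 0}, the second softmax entry underflows to 0.0, so logProbDirect returns NaN while logProbTransformed returns 0.0; for well-scaled logits such as {1, 2, 3} with labels {0, 0, 1}, both return about -0.4076.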
values
- one-hot encoded labels.

public org.nd4j.autodiff.samediff.SDVariable sample()
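The sample() and sample(int n) methods draw from the categorical distribution defined by softmax(logits). This page does not say how sampling is implemented or how the returned SDVariable encodes a sample, so the following is only a sketch of one standard technique, inverse-CDF sampling, over a single row of logits in plain Java. It returns class indices; the class CategoricalSamplerSketch is hypothetical.

```java
import java.util.Random;

/**
 * Hypothetical sketch: inverse-CDF sampling from softmax(logits) for one row.
 * Not the WSO2/ND4J implementation; returns class indices, not SDVariables.
 */
final class CategoricalSamplerSketch {

    private CategoricalSamplerSketch() {}

    /** Numerically stable softmax over one row of logits. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }

    /** One draw, analogous in spirit to sample(). */
    static int sample(double[] logits, Random rng) {
        return drawIndex(softmax(logits), rng);
    }

    /** n independent draws, analogous in spirit to sample(int n). */
    static int[] sample(double[] logits, int n, Random rng) {
        if (n < 0) throw new IllegalArgumentException("n must be non-negative");
        double[] p = softmax(logits); // computed once and reused for every draw
        int[] draws = new int[n];
        for (int k = 0; k < n; k++) draws[k] = drawIndex(p, rng);
        return draws;
    }

    /** Returns the first class whose cumulative probability exceeds u ~ U[0, 1). */
    private static int drawIndex(double[] p, Random rng) {
        double u = rng.nextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < p.length; i++) {
            cumulative += p[i];
            if (u < cumulative) return i;
        }
        return p.length - 1; // rounding guard: cumulative may end just below 1.0
    }
}
```

Passing a seeded Random (for example new Random(42)) makes the draws reproducible, which is convenient when testing.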
public org.nd4j.autodiff.samediff.SDVariable sample(int n)
n
- number of samples.

public org.nd4j.autodiff.samediff.SDVariable klDivergence(org.wso2.extension.siddhi.execution.streamingml.bayesian.model.Distribution distribution) throws org.wso2.siddhi.core.exception.SiddhiAppCreationException
distribution
- reference distribution p(x).
Throws:
org.wso2.siddhi.core.exception.SiddhiAppCreationException
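The description of klDivergence does not state which direction is computed. Assuming it is KL(q || p), with q this distribution and p the reference distribution p(x), a plain-Java sketch for two rows of logits follows. It uses log-softmax (logits minus log-sum-exp) so that no log(0) is evaluated. The class CategoricalKlSketch is hypothetical, and it throws IllegalArgumentException for mismatched class counts rather than the documented SiddhiAppCreationException.

```java
/**
 * Hypothetical sketch of KL(q || p) between two categorical distributions
 * given by logits. Not the WSO2/ND4J implementation.
 */
final class CategoricalKlSketch {

    private CategoricalKlSketch() {}

    /** log-softmax computed as logits - logSumExp(logits), avoiding log(0). */
    static double[] logSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double sum = 0.0;
        for (double v : logits) sum += Math.exp(v - max);
        double logSumExp = max + Math.log(sum);
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) out[i] = logits[i] - logSumExp;
        return out;
    }

    /** KL(q || p) = sum_i q_i * (log q_i - log p_i), q from qLogits, p from pLogits. */
    static double klDivergence(double[] qLogits, double[] pLogits) {
        if (qLogits.length != pLogits.length) {
            throw new IllegalArgumentException("class counts differ");
        }
        double[] logQ = logSoftmax(qLogits);
        double[] logP = logSoftmax(pLogits);
        double kl = 0.0;
        for (int i = 0; i < logQ.length; i++) {
            double q = Math.exp(logQ[i]);
            // Terms with q_i = 0 contribute 0 by the usual 0 * log 0 = 0 convention.
            if (q > 0.0) kl += q * (logQ[i] - logP[i]);
        }
        return kl;
    }
}
```

For logits {1, 2, 3} against {3, 2, 1} the sketch gives about 1.1504, and for identical logits it gives 0.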
public org.nd4j.autodiff.samediff.SDVariable getProb()
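getProb() has no description on this page. If, as the logProbability formula suggests, the class probabilities are softmax(logits) taken along the num_classes dimension, they can be sketched in plain Java for a 2-dimensional (input_size, num_classes) array as follows. Both this reading of getProb() and the class name RowSoftmaxSketch are assumptions.

```java
/**
 * Hypothetical sketch: row-wise softmax over a 2-dimensional logits array with
 * shape (input_size, num_classes). Not the WSO2/ND4J implementation.
 */
final class RowSoftmaxSketch {

    private RowSoftmaxSketch() {}

    static double[][] rowSoftmax(double[][] logits) {
        if (logits.length == 0) {
            throw new IllegalArgumentException("logits must have at least one row");
        }
        int numClasses = logits[0].length;
        double[][] probs = new double[logits.length][numClasses];
        for (int r = 0; r < logits.length; r++) {
            // Reject ragged input: every row must have num_classes entries.
            if (logits[r].length != numClasses) {
                throw new IllegalArgumentException("row " + r + " has the wrong number of classes");
            }
            double max = Double.NEGATIVE_INFINITY;
            for (double v : logits[r]) max = Math.max(max, v);
            double sum = 0.0;
            for (int c = 0; c < numClasses; c++) {
                probs[r][c] = Math.exp(logits[r][c] - max);
                sum += probs[r][c];
            }
            for (int c = 0; c < numClasses; c++) probs[r][c] /= sum; // each row sums to 1
        }
        return probs;
    }
}
```

Subtracting each row's maximum before exponentiating leaves the result mathematically unchanged but keeps Math.exp from overflowing for large logits.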
Copyright © 2019 WSO2. All rights reserved.