-
Notifications
You must be signed in to change notification settings - Fork 272
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
multivariate_normal distribution api 设计文档 #320
Conversation
# 二、飞桨现状 | ||
|
||
- 目前 飞桨没有 API `paddle.distribution.MultivariateNormal` | ||
- API `paddle.distribution.Normal`的代码开发风格可以作为`paddle.distribution.MultivariateNormal` 的主要参考。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normal
是一个比较旧的API,代码风格可以参考Multinomial
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的!
- rsample(value):重参数化采样 | ||
```python | ||
rsample(value) | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
补充rsample设计思路
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
补充rsample设计思路
done
```python | ||
1 / 2 * paddle.log(paddle.pow(2 * math.pi * math.e, value.shpe.pop(1)) * paddle.linalg.det(self.convariance_matrix)) | ||
``` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在kl.py中注册 MultivariateNormal KL散度计算逻辑
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在kl.py中注册 MultivariateNormal KL散度计算逻辑
done
* 均值、方差、标准差通过Numpy计算相应值,对比MultivariateNormal类中相应property的返回值,若一致即正确; | ||
|
||
* 采样方法除验证其返回的数据类型及数据形状是否合法外,还需证明采样结果符合MultivariateNormal分布。验证策略如下:随机采样30000个laplace分布下的样本值,计算采样样本的均值和方差,并比较同分布下`scipy.stats.multivariate_normal`返回的均值与方差,检查是否在合理误差范围内;同时通过Kolmogorov-Smirnov test进一步验证采样是否属于multivariate_normal分布,若计算所得ks值小于0.1,则拒绝不一致假设,两者属于同一分布; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
laplace -> MultivariateNormal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
laplace -> MultivariateNormal
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
No description provided.