Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multivariate_normal distribution api 设计文档 #320

Merged
merged 5 commits into from
Nov 9, 2022

Conversation

dasenCoding
Copy link
Contributor

No description provided.

@paddle-bot
Copy link

paddle-bot bot commented Oct 29, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请检查PR提交格式和内容是否完备,具体请参考示例模版
Your PR has been submitted. Thanks for your contribution!
Please check its format and content. For this, you can refer to Template and Demo.

@dasenCoding dasenCoding changed the title [Hackathon 3rd No.7] add multivariate_normal distribution api [Hackathon 3rd No.7] multivariate_normal distribution api 设计文档 Oct 29, 2022
@dasenCoding dasenCoding changed the title [Hackathon 3rd No.7] multivariate_normal distribution api 设计文档 multivariate_normal distribution api 设计文档 Nov 2, 2022
# 二、飞桨现状

- 目前 飞桨没有 API `paddle.distribution.MultivariateNormal`
- API `paddle.distribution.Normal`的代码开发风格可以作为`paddle.distribution.MultivariateNormal` 的主要参考。
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normal是一个比较旧的API,代码风格可以参考Multinomial

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的!

- rsample(value):重参数化采样
```python
rsample(value)
```
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充rsample设计思路

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充rsample设计思路

done

```python
1 / 2 * paddle.log(paddle.pow(2 * math.pi * math.e, value.shpe.pop(1)) * paddle.linalg.det(self.convariance_matrix))
```

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在kl.py中注册 MultivariateNormal KL散度计算逻辑

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在kl.py中注册 MultivariateNormal KL散度计算逻辑

done

* 均值、方差、标准差通过Numpy计算相应值,对比MultivariateNormal类中相应property的返回值,若一致即正确;

* 采样方法除验证其返回的数据类型及数据形状是否合法外,还需证明采样结果符合MultivariateNormal分布。验证策略如下:随机采样30000个laplace分布下的样本值,计算采样样本的均值和方差,并比较同分布下`scipy.stats.multivariate_normal`返回的均值与方差,检查是否在合理误差范围内;同时通过Kolmogorov-Smirnov test进一步验证采样是否属于multivariate_normal分布,若计算所得ks值小于0.1,则拒绝不一致假设,两者属于同一分布;

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

laplace -> MultivariateNormal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

laplace -> MultivariateNormal

done

Copy link

@cxxly cxxly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants