-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.19】Add ContinuousBernoulli and MultivariateNormal API #58004
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
Math Derivation for entropy of the Continuous Bernoulli distribution and kl_divergence of 2 Continuous Bernoulli distributions:
|
Math Derivation for entropy of the Multivariate Normal distribution and kl_divergence of 2 Multivariate Normal distributions:
|
Sorry to inform you that 064e8a9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
Sorry to inform you that 4ce267e's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
|
Sorry to inform you that 8b913d3's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
edce334
to
2c5ce90
Compare
# convert type | ||
if isinstance(probability, (float, int)): | ||
probability = [probability] | ||
probability = paddle.to_tensor(probability, dtype=self.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果probability本身是Tensor,这里会改变probability的数据类型。别入用户传入p是fp32, 默认数据类型是fp64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
code is fine, please add link of rfc in description above |
the design of |
the rfc of |
添加了rfc链接,rfc设计文档需要做一些修改,已提相应pr |
对应中文文档可以提上来 |
已提中文pr 又修改了一下对应的英文文档 |
Args: | ||
probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], | ||
which characterize the shape of the pdf. If the input data type is int or float, the data type of | ||
`probability` will be convert to a 1-D Tensor the paddle global default dtype. | ||
eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region | ||
would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The | ||
default value is 0.02. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Args: | |
probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], | |
which characterize the shape of the pdf. If the input data type is int or float, the data type of | |
`probability` will be convert to a 1-D Tensor the paddle global default dtype. | |
eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region | |
would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The | |
default value is 0.02. | |
Args: | |
probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], | |
which characterize the shape of the pdf. If the input data type is int or float, the data type of | |
`probability` will be convert to a 1-D Tensor the paddle global default dtype. | |
eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region | |
would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The | |
default value is 0.02. |
对于 Args 下的每个参数,同一个参数的描述换行需要加下缩进
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
r"""The Continuous Bernoulli distribution with parameter: `probability` characterizing the shape of the density function. | ||
The Continuous Bernoulli distribution is defined on [0, 1], and it can be viewed as a continuous version of the Bernoulli distribution. | ||
|
||
[1] Loaiza-Ganem, G., & Cunningham, J. P. The continuous Bernoulli: fixing a pervasive error in variational autoencoders. 2019. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
能不能直接贴上论文的连接?引用方式参考 如何让文档相互引用;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The | ||
default value is 0.02. | ||
|
||
Examples: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
In the above equation: | ||
|
||
* :math:\Omega: is the support of the distribution. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* :math:\Omega: is the support of the distribution. | |
* :math:`\Omega` is the support of the distribution. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
python/paddle/distribution/kl.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这篇文档和 continuous_bernoulli.py
有同样的问题,不赘述了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对应处已修改
抱歉,再补充个Comment,ContinuousBernoulli 签名和PyTorch保持一致 |
已修改 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM~
doc-preview CI 中发现新的 system message 错误, @ooooo-create 之后全量的再检查一遍相关错误并修复叭
[0.20103608, 0.07641447]) | ||
""" | ||
|
||
def __init__(self, probs=None, lims=(0.499, 0.501)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用加None,probs是必选参数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There are spelling and capitalization issues in the link of rfc, eg should be |
Links have been updated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…PI (PaddlePaddle#58004) * add api and test * add kl-div registrition for cb and mvn * fix docs annd test * fix test * fix test * fix mvn test coverage * fix docs * update docs * update cb and mvn * fix mvn test * fix test * fix test * fix test * fix test * fix unstable region calculation * fix test * update dtype convertion and tests * fix test * fix test * fix test * refine docs * update docs * update docs * update docs * update cb api * increase cb static test timeout * fix test time * fix test * update cb
PR types
New features
PR changes
APIs
Description
Add ContinuousBernoulli and MultivariateNormal API