-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No.10】新增 LogNormal API #46426
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
|
||
|
||
class LogNormal(TransformedDistribution): | ||
r"""The Normal distribution with location `loc` and `scale` parameters. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LogNormal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
In the above equation: | ||
|
||
* :math:`loc = \mu`: is the means of the underlying Normal distribution. | ||
* :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LogNormal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是指LogNormal的基础分布
Args: | ||
loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of normal distribution. | ||
scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of normal distribution. | ||
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
""" | ||
return self.sample(shape, seed) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rsample不能直接调用sample,sample不支持反向,需要通过标准正态分布重参数化实现。如果目前Paddle现有功能确实无法支持实现Normal sample的重参数化,可以暂时raise NotImplementedError
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
新的commit补充了rsample的实现,但是我不确定实现是否正确,想请教一下,应该用什么方法验证rsample支持反向呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
param = paddle.rand(...)
param.stop_gradient = False
d = paddle.distribution.xxx(param = param)
y = d.rsample(...)
paddle.grad(y, param)
可以新提交个PR进行验证,补充相应测试用例
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,谢谢您
@@ -63,7 +63,7 @@ def __init__(self, base, transforms): | |||
chain = transform.ChainTransform(transforms) | |||
if len(base.batch_shape + base.event_shape) < chain._domain.event_rank: | |||
raise ValueError( | |||
f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base_shape)}." | |||
f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base.batch_shape + base.event_shape)}." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but got
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
TODO:增加rsample测试用例
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good job!LGTM for docs
Args: | ||
loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of the underlying Normal distribution. | ||
scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of the underlying Normal distribution. | ||
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要加name参数吗?之前好像说是不用的?@cxxly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用,name参数是为了追踪静态图下运行过程,每个方法也要处理,是一个通用逻辑,后面我会统一处理。此处可以先删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯好的,那辛苦把这个参数去掉吧 @MayYouBeProsperous
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for docs
PR types
New features
PR changes
APIs
Describe
新增 LogNormal API