Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.10】新增 LogNormal API #46426

Merged
merged 44 commits into from
Oct 10, 2022
Merged

【Hackathon No.10】新增 LogNormal API #46426

merged 44 commits into from
Oct 10, 2022

Conversation

MayYouBeProsperous
Copy link
Contributor

@MayYouBeProsperous MayYouBeProsperous commented Sep 22, 2022

@paddle-bot
Copy link

paddle-bot bot commented Sep 22, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Sep 22, 2022
@MayYouBeProsperous MayYouBeProsperous changed the title 【Hackathon No.10】新增 LogNormal API [WIP]【Hackathon No.10】新增 LogNormal API Sep 23, 2022
@MayYouBeProsperous MayYouBeProsperous changed the title [WIP]【Hackathon No.10】新增 LogNormal API 【Hackathon No.10】新增 LogNormal API Sep 23, 2022


class LogNormal(TransformedDistribution):
r"""The Normal distribution with location `loc` and `scale` parameters.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LogNormal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

In the above equation:

* :math:`loc = \mu`: is the means of the underlying Normal distribution.
* :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LogNormal

Copy link
Contributor Author

@MayYouBeProsperous MayYouBeProsperous Sep 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是指LogNormal的基础分布

Args:
loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of normal distribution.
scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of normal distribution.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


"""
return self.sample(shape, seed)

Copy link
Contributor

@cxxly cxxly Sep 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rsample不能直接调用sample,sample不支持反向,需要通过标准正态分布重参数化实现。如果目前Paddle现有功能确实无法支持实现Normal sample的重参数化,可以暂时raise NotImplementedError

Copy link
Contributor Author

@MayYouBeProsperous MayYouBeProsperous Sep 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新的commit补充了rsample的实现,但是我不确定实现是否正确,想请教一下,应该用什么方法验证rsample支持反向呢?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

param = paddle.rand(...)
param.stop_gradient = False
d = paddle.distribution.xxx(param = param)
y = d.rsample(...)
paddle.grad(y, param)

可以新提交个PR进行验证,补充相应测试用例

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,谢谢您

@@ -63,7 +63,7 @@ def __init__(self, base, transforms):
chain = transform.ChainTransform(transforms)
if len(base.batch_shape + base.event_shape) < chain._domain.event_rank:
raise ValueError(
f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base_shape)}."
f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base.batch_shape + base.event_shape)}."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but got

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

cxxly
cxxly previously approved these changes Sep 30, 2022
Copy link
Contributor

@cxxly cxxly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
TODO:增加rsample测试用例

Ligoml
Ligoml previously approved these changes Oct 8, 2022
Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good job!LGTM for docs

Args:
loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of the underlying Normal distribution.
scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of the underlying Normal distribution.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要加name参数吗?之前好像说是不用的?@cxxly

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用,name参数是为了追踪静态图下运行过程,每个方法也要处理,是一个通用逻辑,后面我会统一处理。此处可以先删除

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯好的,那辛苦把这个参数去掉吧 @MayYouBeProsperous

Copy link
Contributor Author

@MayYouBeProsperous MayYouBeProsperous Oct 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Ligoml @cxxly 已经删除~ 麻烦再次review

@MayYouBeProsperous MayYouBeProsperous dismissed stale reviews from Ligoml and cxxly via 4745c39 October 9, 2022 06:38
Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants