-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No.16】add PoissonNLLLoss API #51117
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
❌ The PR is not created using PR's template. You can refer to this Demo. |
91ed3bc
to
f570567
Compare
f570567
to
53d9a69
Compare
for r in res: | ||
np.allclose(out_ref, r, rtol=1e-5) | ||
|
||
def test_api(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不同case的测试分成多个test class吧,方便后续定位具体是哪个case异常
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经对单元测试进行了拆分
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个单测有点分的太细了,动态图和静态图可以整合到一个class下面的两个方法,然后再不同的case分成不同的class, test error的也可以整合到一个class,不同的方法测试不同的error
See more detail in :ref:`NLLLoss <api_paddle_nn_PoissonNLLLoss>` . | ||
Parameters:: | ||
|
||
input (Tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tensor 类型的输入都说明一下支持的数据类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经添加
self.test_dynamic_case('float64', full=True) | ||
self.test_dynamic_case(log_input=False, full=True) | ||
self.test_static_case(full=True, reduction='none') | ||
self.test_dynamic_case('float64', full=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么不测 float16,int16,in32
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
poisson_nll_loss当前只支持float32和float64。paddle.poisson_nll_loss依赖exp算子和log算子,exp算子当前不支持float16和int 16,log算子不支持int32,在所依赖的算子支持对应数据类型后我们将添加相应支持。
53d9a69
to
1df68bb
Compare
@GGBond8488 对问题进行了修改,辛苦review |
python/paddle/nn/functional/loss.py
Outdated
label, | ||
log_input=True, | ||
full=False, | ||
eps=1e-8, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已根据意见进行修改
python/paddle/nn/functional/loss.py
Outdated
If ``True``, the Stirling approximation term is added. | ||
If ``False``, the Stirling approximation is dropped. | ||
Default: ``False``. | ||
eps (float, optioonal): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文档里面的说明需要同步修改(eps)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同步更新了API文档和doc string。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
以及rfc设计文档也提一个pr修改吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经提交pr。
python/paddle/nn/functional/loss.py
Outdated
) | ||
# check input dtype and dimension | ||
if not in_dygraph_mode(): | ||
check_variable_and_dtype( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check_variable_and_dtype 会自动跳过动态图,所以不用增加这个分支,让代码更加动静统一
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去除了判断分支
python/paddle/nn/functional/loss.py
Outdated
check_variable_and_dtype( | ||
input, | ||
'input', | ||
['int32', 'int64', 'float32', 'float64'], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstr说仅支持float32以及float64,这里为什么还会支持int32以及int64呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修复了数据类型支持的不一致
python/paddle/nn/functional/loss.py
Outdated
check_variable_and_dtype( | ||
label, | ||
'label', | ||
['int32', 'int64', 'float32', 'float64'], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
python/paddle/nn/layer/loss.py
Outdated
If ``True``, the Stirling approximation term is added. | ||
If ``False``, the Stirling approximation is dropped. | ||
Default: ``False``. | ||
eps (float, optioonal): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同eps使用全称
c7c4aa7
to
a8b159d
Compare
@GGBond8488 辛苦review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最后一个小问题,麻烦修改~
python/paddle/nn/functional/loss.py
Outdated
It's data type should be float32, float64. | ||
log_input (bool, optional): | ||
Whether to the treat input tensor as log input. | ||
If ``True`` the loss is computed as,:math:`\exp(\text{input}) - \text{label} * \text{input}`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in our new commit, thanks for your kind review : )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
#51117 (comment) 需要处理下 |
We follow the suggesstions to support the mentioned data dtypes of float16 and blfoat16, but just adding support in What we did?
What we find?Our unittests on these newly supported data types failed both in 733: test_poisson_nll_loss failed Next StepAs we are unfamilar with the implementation of |
Please give the code |
The code is put in another branch, you can see it here |
|
d724acb
to
c49854e
Compare
We have supported these data types in our new commit, but due to our local test is not running on CPU compiled version, these tests can only be checked in CI. |
This pr is ready to be reviewed @luotao1 . |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
hi, @LyndonKong
|
PR types
New features
PR changes
APIs
Describe
Add paddle.nn.PoissonNLLLoss and paddle.nn.functional.poisson_nll_loss
document pr: PaddlePaddle/docs#5675
rfc pr1: PaddlePaddle/community#395
rfc pr2: PaddlePaddle/community#463