[Fluid] move lars_momentum to phi #55798
Conversation
Your PR was submitted successfully. Thank you for contributing to the open-source project!
// {LARS_BLOCK_SIZE << 1}, cuda_ctx);
phi::DenseTensor tmp_buffer_t;
tmp_buffer_t.Resize({LARS_BLOCK_SIZE << 1});
MT* p_buffer = dev_ctx.template Alloc<MT>(tmp_buffer_t);
Could a developer from the team help take a look at this?
Alloc should be passed a pointer.
Hmm, what I meant was: is this how AllocateTmpTensor is supposed to be used under phi?
AllocateTmpTensor is implemented as just Resize + Alloc, so your current usage is fine. Under phi the ctx type is different, so AllocateTmpTensor cannot be called directly.
Take a look at the PR-CI-Windows-OPENBLAS failure, around line 15598 of the full log. It looks like an overload conflict; specify the type explicitly.
It looks that way.
@hitywt could you please review?
@@ -312,7 +312,7 @@ def setUp(self):

     def test_check_output(self):
         paddle.enable_static()
-        self.check_output()
+        self.check_output(check_dygraph=False)
Why was this changed to False here?
This operator has no dynamic-graph mode (I could not find a corresponding API in Python); the C++ operator is only used in static graphs. The lars_momentum unit test originally called paddle.enable_static() so that only the static graph was tested. After this PR migrated the operator to phi, the test framework's dynamic-graph check was triggered, so check_dygraph=False is set in TestLarsMomentumOp to keep the test static-graph only.
LGTM
@gouzil does lars_momentum need an xpu kernel and infershape added as well?
* [Fluid] move lars_momentum to phi
* add sig
* fix optional Output
* off check_dygraph
* fix input
* fix operator[]
* fix
* try fix AllocateTmpTensor
* fix
* fix type
* Update paddle/phi/kernels/gpu/lars_momentum_kernel.cu
* fix type
* rollback
* Add Registration
* try fix win
* try fix win
* try use double
* try use operator *(float,const Derived &)
* try auto
* fix
* fix
* fix
* fix dtype
* fix type
* fix index
PR types
Others
PR changes
Others
Description
Migrate lars_momentum to PHI.
Related issue: