-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PTen] Polish kernel register marco design #38078
[PTen] Polish kernel register marco design #38078
Conversation
Thanks for your contribution! |
paddle/pten/common/data_type.h
Outdated
@@ -30,6 +30,8 @@ using bfloat16 = ::paddle::platform::bfloat16; | |||
|
|||
enum class DataType { | |||
UNDEFINED = 0, | |||
// See Note [ Why we need ALL in baisc kernel key member? ] | |||
ALL_DTYPE = UNDEFINED, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ALL_DTYPE = UNDEFINED
会不会在判断相等时出现问题?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
比如说会出现什么问题
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
比如某个API接口有DataType参数,如果传进来的是ALL_DTYPE可能会当成UNDEFINED去处理
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个没有问题,ALL_DTYPE当成UNDEFINED处理就正常,ALL_DTYPE使注册时名字更清晰,但语义仍然是UNDEFINED
paddle/pten/common/layout.h
Outdated
@@ -20,7 +20,10 @@ namespace experimental { | |||
|
|||
enum class DataLayout { | |||
UNDEFINED = 0, | |||
ANY, | |||
// See Note [ Why we need ALL in baisc kernel key member? ] | |||
ALL_LAYOUT = UNDEFINED, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
pten::ReshapeFromVectorValWithXShape, | ||
ALL_DTYPE) {} | ||
PT_REGISTER_NO_TEMPLATE_KERNEL( | ||
reshape_host, CPU, ALL_LAYOUT, pten::ReshapeFromDT, ALL_DTYPE) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NO_TEMPLATE_KERNEL可不可以把ALL_LAYOUT和ALL_DTYPE直接置为默认值,不用手写?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那需要另外一个宏,PT_REGISTER_NO_TEMPLATE_KERNEL这里并不一定是ALL,例如PT_REGISTER_NO_TEMPLATE_KERNEL(reshape_host, CPU, NCHW, pten::ReshapeFromDT, FLOAT))
ANY, | ||
pten::ReshapeFromVectorDT) { | ||
PT_REGISTER_NO_TEMPLATE_KERNEL( | ||
reshape_mulhost, CPU, ALL_LAYOUT, pten::ReshapeFromVectorDT, ALL_DTYPE) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
cpp_dtype, \ | ||
__VA_ARGS__); \ | ||
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel) | ||
#define _PT_REGISTER_1TA_KERNEL( \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否需要加注释解释一下"1TA"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
d2bdad9
* polish register marco * resolve compile failed * revert needless change * revert eager related change * revert eager related change * change register marco name * polish deetails
PR types
Function optimization
PR changes
Others
Describe
[PTen] Polish kernel register marco design
整理了原先的注册宏设计,经过内部讨论,整体上提供的宏体系组织如下:
PT_REGISTER_KERNEL
注册用法示例如下:
PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CPU, ALL_LAYOUT, pten::Reshape, float)
PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE)
PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, ALL_BACKEND, ALL_LAYOUT, pten::Reshape, ALL_DTYPE)
PT_REGISTER_KERNEL
注册(这个宏在迁移期间使用频次最高,使用最标准的命名)用法示例如下:
PT_REGISTER_CTX_KERNEL
注册(后续PR添加)