Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen] Polish kernel register marco design #38078

Merged

Conversation

chenwhql
Copy link
Contributor

@chenwhql chenwhql commented Dec 13, 2021

PR types

Function optimization

PR changes

Others

Describe

[PTen] Polish kernel register marco design

整理了原先的注册宏设计,经过内部讨论,整体上提供的宏体系组织如下:

  1. 对于一个无模板的函数式Kernel,采用PT_REGISTER_KERNEL注册

用法示例如下:

  • PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CPU, ALL_LAYOUT, pten::Reshape, float)
  • PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE)
  • PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, ALL_BACKEND, ALL_LAYOUT, pten::Reshape, ALL_DTYPE)
  1. 对于有一个模板参数T(data type)的函数式Kernel,使用PT_REGISTER_KERNEL注册(这个宏在迁移期间使用频次最高,使用最标准的命名)

用法示例如下:

PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
                   CPU,
                   ALL_LAYOUT,
                   pten::Scale,
                   float,
                   double,
                   paddle::platform::bfloat16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}
PT_REGISTER_KERNEL(add,
                   CPU,
                   ALL_LAYOUT,
                   pten::ElementwiseAdd,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}

通过这样的方式,减少注册代码编写时的函数名重复问题,使注册写法更简洁

  1. 对于有两个模板参数T(data type),ContextT(backend)的函数式kernel,使用PT_REGISTER_CTX_KERNEL注册(后续PR添加)

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -30,6 +30,8 @@ using bfloat16 = ::paddle::platform::bfloat16;

enum class DataType {
UNDEFINED = 0,
// See Note [ Why we need ALL in baisc kernel key member? ]
ALL_DTYPE = UNDEFINED,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ALL_DTYPE = UNDEFINED会不会在判断相等时出现问题?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

比如说会出现什么问题

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

比如某个API接口有DataType参数,如果传进来的是ALL_DTYPE可能会当成UNDEFINED去处理

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个没有问题,ALL_DTYPE当成UNDEFINED处理就正常,ALL_DTYPE使注册时名字更清晰,但语义仍然是UNDEFINED

@@ -20,7 +20,10 @@ namespace experimental {

enum class DataLayout {
UNDEFINED = 0,
ANY,
// See Note [ Why we need ALL in baisc kernel key member? ]
ALL_LAYOUT = UNDEFINED,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

pten::ReshapeFromVectorValWithXShape,
ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_host, CPU, ALL_LAYOUT, pten::ReshapeFromDT, ALL_DTYPE) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NO_TEMPLATE_KERNEL可不可以把ALL_LAYOUT和ALL_DTYPE直接置为默认值,不用手写?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那需要另外一个宏,PT_REGISTER_NO_TEMPLATE_KERNEL这里并不一定是ALL,例如PT_REGISTER_NO_TEMPLATE_KERNEL(reshape_host, CPU, NCHW, pten::ReshapeFromDT, FLOAT))

ANY,
pten::ReshapeFromVectorDT) {
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_mulhost, CPU, ALL_LAYOUT, pten::ReshapeFromVectorDT, ALL_DTYPE) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

YuanRisheng
YuanRisheng previously approved these changes Dec 14, 2021
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
#define _PT_REGISTER_1TA_KERNEL( \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否需要加注释解释一下"1TA"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

@chenwhql chenwhql merged commit c9da845 into PaddlePaddle:develop Dec 14, 2021
Caozhou1995 pushed a commit to Caozhou1995/Paddle that referenced this pull request Dec 29, 2021
* polish register marco

* resolve compile failed

* revert needless change

* revert eager related change

* revert eager related change

* change register marco name

* polish deetails
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants