diff --git a/paddle/pten/common/backend.h b/paddle/pten/common/backend.h index 622fb7d0258282..95bbc88681a965 100644 --- a/paddle/pten/common/backend.h +++ b/paddle/pten/common/backend.h @@ -39,6 +39,21 @@ namespace experimental { enum class Backend : uint8_t { UNDEFINED = 0, + // basic kernel backend + CPU, + + // various acceleration devices' backends + CUDA, + XPU, // XPU currently does not exist at the same time as CUDA + NPU, // NPU currently does not exist at the same time as CUDA + + // the third library backend + MKLDNN, + CUDNN, + + // end of backend types + NUM_BACKENDS, + /** * [ Why we need ALL in baisc kernel key member? ] * @@ -53,12 +68,12 @@ enum class Backend : uint8_t { * so if we provide the ALL field with Register the kernel in this statement. * * Of course, we have also considered solving this problem through different - * named macros, for example, we define + * named macros, for example, if we define * * PT_REGISTER_KERNEL_FOR_ALL_BACKEND * - * However, dtype and layout also have the same requirements, we need to - * define a series of macros + * Based on this design pattern, the dtype and layout also have the same + * requirements, this cause we need to define a series of macros * * PT_REGISTER_KERNEL_FOR_ALL_DTYPE * PT_REGISTER_KERNEL_FOR_ALL_LAYOUT @@ -67,28 +82,13 @@ enum class Backend : uint8_t { * PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE * PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE * - * This makes the system of registering macros more complicated, we think + * It makes the system of registering macros more complicated, we think * this is not a simple design, so we still adopt the design of providing * the ALL field. * * Note: ALL_BACKEND only used for Kernel registration and selection */ ALL_BACKEND = UNDEFINED, - - // basic kernel backend - CPU, - - // various acceleration devices' backends - CUDA, - XPU, // XPU currently does not exist at the same time as CUDA - NPU, // NPU currently does not exist at the same time as CUDA - - // the third library backend - MKLDNN, - CUDNN, - - // end of backend types - NUM_BACKENDS, }; inline std::ostream& operator<<(std::ostream& os, Backend backend) { diff --git a/paddle/pten/common/data_type.h b/paddle/pten/common/data_type.h index 8868d8475a55c5..a00d68c535415d 100644 --- a/paddle/pten/common/data_type.h +++ b/paddle/pten/common/data_type.h @@ -30,8 +30,6 @@ using bfloat16 = ::paddle::platform::bfloat16; enum class DataType { UNDEFINED = 0, - // See Note [ Why we need ALL in baisc kernel key member? ] - ALL_DTYPE = UNDEFINED, BOOL, INT8, // Char UINT8, // BYte @@ -47,7 +45,9 @@ enum class DataType { FLOAT64, COMPLEX64, COMPLEX128, - NUM_DATA_TYPES + NUM_DATA_TYPES, + // See Note [ Why we need ALL in baisc kernel key member? ] + ALL_DTYPE = UNDEFINED, }; inline size_t SizeOf(DataType data_type) { diff --git a/paddle/pten/common/layout.h b/paddle/pten/common/layout.h index 99ad88a23233b0..b7c151e7e6a7c8 100644 --- a/paddle/pten/common/layout.h +++ b/paddle/pten/common/layout.h @@ -20,14 +20,14 @@ namespace experimental { enum class DataLayout { UNDEFINED = 0, - // See Note [ Why we need ALL in baisc kernel key member? ] - ALL_LAYOUT = UNDEFINED, // TODO(chenweihang): keep ANY for compatibility, remove it later ANY = UNDEFINED, NHWC, NCHW, MKLDNN, NUM_DATA_LAYOUTS, + // See Note [ Why we need ALL in baisc kernel key member? ] + ALL_LAYOUT = UNDEFINED, }; inline std::ostream& operator<<(std::ostream& os, DataLayout layout) { diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index 9f198b7c1fb3fd..8a58ba060d1a5e 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -209,7 +209,7 @@ struct KernelRegistrar { * pointer of the corresponding data type is automatically instantiated * during registration. * - * Note: If needed, add more marco to support 2 template arguments deduce + * Note: `1TA` means `1 template argument` */ #define PT_REGISTER_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \