Skip to content

Commit

Permalink
Register custom kernel for some all_bakcend kernel (#51639)
Browse files Browse the repository at this point in the history
* register some custom kernel

* fix bug
  • Loading branch information
zyfncg authored Mar 20, 2023
1 parent 1d5cad2 commit e8530a3
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 2 deletions.
18 changes: 18 additions & 0 deletions paddle/phi/kernels/cpu/numel_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,21 @@ PD_REGISTER_KERNEL(numel,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(numel,
Custom,
ALL_LAYOUT,
phi::NumelKernel,
uint8_t,
int16_t,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif
15 changes: 15 additions & 0 deletions paddle/phi/kernels/flatten_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,18 @@ PD_REGISTER_KERNEL(flatten_grad,
int64_t) {}

#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(flatten_grad,
Custom,
ALL_LAYOUT,
phi::FlattenGradKernel,
float,
phi::dtype::float16,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
28 changes: 28 additions & 0 deletions paddle/phi/kernels/flatten_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,31 @@ PD_REGISTER_KERNEL(flatten,
int,
int64_t) {}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(flatten_infer,
Custom,
ALL_LAYOUT,
phi::FlattenInferKernel,
float,
phi::dtype::float16,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(flatten,
Custom,
ALL_LAYOUT,
phi::FlattenKernel,
float,
phi::dtype::float16,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
13 changes: 13 additions & 0 deletions paddle/phi/kernels/reshape_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_double_grad,
phi::ReshapeDoubleGradKernel<phi::XPUContext>,
ALL_DTYPE) {}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_GENERAL_KERNEL(reshape_grad,
Custom,
ALL_LAYOUT,
phi::ReshapeGradKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape_double_grad,
Custom,
ALL_LAYOUT,
phi::ReshapeDoubleGradKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
13 changes: 13 additions & 0 deletions paddle/phi/kernels/reshape_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_infer,
PD_REGISTER_GENERAL_KERNEL(
reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
Custom,
ALL_LAYOUT,
phi::ReshapeInferKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape,
Custom,
ALL_LAYOUT,
phi::ReshapeKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
32 changes: 30 additions & 2 deletions paddle/phi/kernels/selected_rows/shape_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ PD_REGISTER_KERNEL(shape_sr,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(shape_sr,
Expand All @@ -60,5 +64,29 @@ PD_REGISTER_KERNEL(shape_sr,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(shape_sr,
Custom,
ALL_LAYOUT,
phi::sr::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif
21 changes: 21 additions & 0 deletions paddle/phi/kernels/shape_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,24 @@ PD_REGISTER_KERNEL(shape,
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(shape,
Custom,
ALL_LAYOUT,
phi::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif

0 comments on commit e8530a3

Please sign in to comment.