From 1ee344f584bbc412a3feefb98cc4fcaf42ad0b66 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 25 Jul 2022 13:57:18 +0000 Subject: [PATCH 1/2] add complex for einsum grad kernel --- paddle/phi/kernels/cpu/einsum_grad_kernel.cc | 10 ++++++++-- paddle/phi/kernels/cpu/tile_kernel.cc | 14 +++++++++++--- paddle/phi/kernels/gpu/einsum_grad_kernel.cu | 4 +++- paddle/phi/kernels/gpu/tile_kernel.cu | 4 +++- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/cpu/einsum_grad_kernel.cc b/paddle/phi/kernels/cpu/einsum_grad_kernel.cc index 2cfc2f92204fc..0e583f25edfbb 100644 --- a/paddle/phi/kernels/cpu/einsum_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_grad_kernel.cc @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_grad_impl.h" -PD_REGISTER_KERNEL( - einsum_grad, CPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {} +PD_REGISTER_KERNEL(einsum_grad, + CPU, + ALL_LAYOUT, + phi::EinsumGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/tile_kernel.cc b/paddle/phi/kernels/cpu/tile_kernel.cc index 3b590ed475aa2..2320c30310a64 100644 --- a/paddle/phi/kernels/cpu/tile_kernel.cc +++ b/paddle/phi/kernels/cpu/tile_kernel.cc @@ -18,6 +18,14 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/tile_kernel_impl.h" -PD_REGISTER_KERNEL( - tile, CPU, ALL_LAYOUT, phi::TileKernel, bool, float, double, int, int64_t) { -} +PD_REGISTER_KERNEL(tile, + CPU, + ALL_LAYOUT, + phi::TileKernel, + bool, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu index a8464be3bb3c6..0be139721d464 100644 --- a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu @@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(einsum_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/tile_kernel.cu b/paddle/phi/kernels/gpu/tile_kernel.cu index 990877a8445cb..ba598862f5910 100644 --- a/paddle/phi/kernels/gpu/tile_kernel.cu +++ b/paddle/phi/kernels/gpu/tile_kernel.cu @@ -28,4 +28,6 @@ PD_REGISTER_KERNEL(tile, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} From 43226843efa22303f9f225b7f3bd029fb1d05f58 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 25 Jul 2022 15:32:21 +0000 Subject: [PATCH 2/2] pass the ci --- paddle/phi/kernels/funcs/eigen/broadcast.cc | 5 +++++ paddle/phi/kernels/funcs/eigen/broadcast.cu | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cc b/paddle/phi/kernels/funcs/eigen/broadcast.cc index 008c51249f249..c806cdeaad60b 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cc +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" @@ -75,6 +76,8 @@ struct EigenBroadcastGrad { INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::bfloat16); +INSTANTIATION(EigenBroadcast, dtype::complex); +INSTANTIATION(EigenBroadcast, dtype::complex); INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, int); @@ -82,6 +85,8 @@ INSTANTIATION(EigenBroadcast, int64_t); INSTANTIATION(EigenBroadcastGrad, bool); INSTANTIATION(EigenBroadcastGrad, float); INSTANTIATION(EigenBroadcastGrad, dtype::float16); +INSTANTIATION(EigenBroadcastGrad, dtype::complex); +INSTANTIATION(EigenBroadcastGrad, dtype::complex); INSTANTIATION(EigenBroadcastGrad, double); INSTANTIATION(EigenBroadcastGrad, int); INSTANTIATION(EigenBroadcastGrad, int64_t); diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cu b/paddle/phi/kernels/funcs/eigen/broadcast.cu index 742081a30c1a0..0b749f5c009a5 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cu +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" @@ -77,12 +78,16 @@ INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::bfloat16); INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, double); +INSTANTIATION(EigenBroadcast, dtype::complex); +INSTANTIATION(EigenBroadcast, dtype::complex); INSTANTIATION(EigenBroadcast, int); INSTANTIATION(EigenBroadcast, int64_t); INSTANTIATION(EigenBroadcastGrad, bool); INSTANTIATION(EigenBroadcastGrad, float); INSTANTIATION(EigenBroadcastGrad, dtype::float16); INSTANTIATION(EigenBroadcastGrad, double); +INSTANTIATION(EigenBroadcastGrad, dtype::complex); +INSTANTIATION(EigenBroadcastGrad, dtype::complex); INSTANTIATION(EigenBroadcastGrad, int); INSTANTIATION(EigenBroadcastGrad, int64_t); template struct EigenBroadcastGrad;