From ba8826574a2bbda2b4bb01ff795f680349d9c07b Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Tue, 8 Feb 2022 18:22:44 +0800 Subject: [PATCH] Update op support gpu impl (#39386) * find gpu kernel in pten factory; test=develop * check in functional kernel first; test=develop --- paddle/fluid/framework/operator.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5ab14a1daba226..0f558b46872a2d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -32,6 +32,7 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler.h" #include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/ops/compat/signatures.h" namespace pten { @@ -598,6 +599,17 @@ std::vector ExecutionContext::MultiOutput( } bool OpSupportGPU(const std::string& op_type) { + // check in new Function kernel first + auto& kernel_factory = pten::KernelFactory::Instance(); + auto kernel_key_map = + kernel_factory.SelectKernelMap(pten::TransToPtenKernelName(op_type)); + for (auto& kernel : kernel_key_map) { + if (platform::is_gpu_place( + pten::TransToFluidPlace(kernel.first.backend()))) { + return true; + } + } + auto& all_kernels = OperatorWithKernel::AllOpKernels(); auto it = all_kernels.find(op_type); if (it == all_kernels.end()) { @@ -609,6 +621,7 @@ bool OpSupportGPU(const std::string& op_type) { return true; } } + return false; }