Skip to content

Commit

Permalink
Update op support gpu impl (#39386)
Browse files Browse the repository at this point in the history
* find gpu kernel in pten factory; test=develop

* check in functional kernel first; test=develop
  • Loading branch information
phlrain authored Feb 8, 2022
1 parent 196dbfc commit ba88265
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/ops/compat/signatures.h"

namespace pten {
Expand Down Expand Up @@ -598,6 +599,17 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
}

bool OpSupportGPU(const std::string& op_type) {
// check in new Function kernel first
auto& kernel_factory = pten::KernelFactory::Instance();
auto kernel_key_map =
kernel_factory.SelectKernelMap(pten::TransToPtenKernelName(op_type));
for (auto& kernel : kernel_key_map) {
if (platform::is_gpu_place(
pten::TransToFluidPlace(kernel.first.backend()))) {
return true;
}
}

auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
Expand All @@ -609,6 +621,7 @@ bool OpSupportGPU(const std::string& op_type) {
return true;
}
}

return false;
}

Expand Down

0 comments on commit ba88265

Please sign in to comment.