Skip to content

Commit

Permalink
[PluggableDevice] custom kernel supports multi cpp_dtype registering (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Aganlengzi authored Feb 10, 2022
1 parent 2a5d858 commit 63d2333
Show file tree
Hide file tree
Showing 3 changed files with 695 additions and 63 deletions.
114 changes: 76 additions & 38 deletions paddle/fluid/framework/custom_kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ limitations under the License. */
// user kernel function
namespace custom_kernel {

// Here we use dot <CPU, ANY, UINT8> for test
// This test will fail when these two kernels are aupported in framework
// Here we use fake_dot for test
// input 3: two Tensors and one std::vector<Tensor>
// attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*>
template <typename T>
void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
template <typename T, typename Context>
void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
const paddle::Tensor& y,
const std::vector<paddle::Tensor>& fake_input_vec,
bool fake_attr_bool, int fake_attr_int, float fake_attr_float,
Expand Down Expand Up @@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
}
} // namespace custom_kernel

PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8,
custom_kernel::FakeDot<uint8_t>) {
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8);
}
PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float,
double, int, int64_t, int8_t, uint8_t) {}

// Upper code will store dot kernels info into OpKernelInfoMap
TEST(CustomKernel, custom_kernel_dot) {
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataLayout layout = pten::DataLayout::ALL_LAYOUT;

// 1.custom kernel info parsed and store
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find("dot") !=
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) !=
paddle::OpKernelInfoMap::Instance().GetMap().end());

// 2.info check
EXPECT_EQ(
1, static_cast<int>(paddle::OpKernelInfoMap::Instance()["dot"].size()));
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetBackend() ==
6, static_cast<int>(paddle::OpKernelInfoMap::Instance()[op_name].size()));
// index 0
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataLayout() ==
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataType() ==
dtype);

// 3.register
EXPECT_TRUE(pten::KernelFactory::Instance().kernels().end() !=
pten::KernelFactory::Instance().kernels().find("dot"));

pten::KernelKey kernel_key(backend, layout, dtype);
EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) ==
pten::KernelFactory::Instance().kernels()["dot"].end());

EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() ==
pten::DataType::FLOAT32);
// index 5
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() ==
pten::DataType::UINT8);

// 3.before register
auto& kernel_factory_instance = pten::KernelFactory::Instance();
auto& kernels = pten::KernelFactory::Instance().kernels();
EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name));

// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
// registering
auto& fake_dot_kernels = kernels[op_name];

EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) ==
fake_dot_kernels.end());

// register
paddle::framework::RegisterKernelWithMetaInfoMap(
paddle::OpKernelInfoMap::Instance());

EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) !=
pten::KernelFactory::Instance().kernels()["dot"].end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) !=
fake_dot_kernels.end());

// 4.kernel select
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
op_name, kernel_key);
auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
op_name, pten::KernelKey(backend, layout, pten::DataType::UINT8));

// 5.prepare parameters for kernel
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
Expand Down Expand Up @@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
// test OpKernelInfoHelper
TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper;
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataType dtype = pten::DataType::FLOAT32;

auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0];

Expand All @@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
OpKernelInfoHelper::GetKernelKey(op_kernel_info));

paddle::CustomKernelFunc kernel_fn =
PD_PT_KERNEL(custom_kernel::FakeDot<uint8_t>);
PD_PT_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info));

void* variadic_func = PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<uint8_t>);
void* variadic_func =
PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(variadic_func,
OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info));

Expand Down
Loading

0 comments on commit 63d2333

Please sign in to comment.