From 0ccb9cbe1566029847ed26352c61b5b2e009cfad Mon Sep 17 00:00:00 2001 From: Ruibin Cheung Date: Fri, 2 Feb 2024 20:34:00 +0800 Subject: [PATCH] [cherry-pick] adapt c_embedding to phi namespace for custom devices (#60774) (#61045) Co-authored-by: Tian <121000916+SylarTiaNII@users.noreply.github.com> --- paddle/phi/kernels/CMakeLists.txt | 4 + .../kernels/custom/c_embedding_grad_kernel.cc | 93 +++++++++++++++++++ .../phi/kernels/custom/c_embedding_kernel.cc | 84 +++++++++++++++++ test/legacy_test/c_embedding_op_base.py | 25 ++++- 4 files changed, 201 insertions(+), 5 deletions(-) create mode 100644 paddle/phi/kernels/custom/c_embedding_grad_kernel.cc create mode 100644 paddle/phi/kernels/custom/c_embedding_kernel.cc diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 82bc55a19fdf9..0e3882f0493d8 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -199,6 +199,10 @@ if(WITH_MKLDNN) "fusion/onednn/*.cc") endif() +if(WITH_CUSTOM_DEVICE) + set(cc_search_pattern ${cc_search_pattern} "custom/*.cc") +endif() + file( GLOB kernel_cc RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" diff --git a/paddle/phi/kernels/custom/c_embedding_grad_kernel.cc b/paddle/phi/kernels/custom/c_embedding_grad_kernel.cc new file mode 100644 index 0000000000000..ff61688513b13 --- /dev/null +++ b/paddle/phi/kernels/custom/c_embedding_grad_kernel.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_grad_kernel.h" +#include "glog/logging.h" +#include "paddle/phi/api/backward/backward_api.h" +#include "paddle/phi/api/include/api.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +template +void CEmbeddingGradKernel(const Context& dev_ctx, + const DenseTensor& w, + const DenseTensor& ids, + const DenseTensor& out_grad, + int64_t start_index, + DenseTensor* w_grad) { + w_grad->Resize(w.dims()); + dev_ctx.template Alloc(w_grad, w.dtype()); + const auto& index_type = ids.dtype(); + if (index_type == phi::DataType::INT32 || + index_type == phi::DataType::INT64) { + auto K = ids.numel(); + auto N = w.dims()[0]; + auto D = w.dims()[1]; + + auto x_tmp = std::make_shared(); + x_tmp->ShareDataWith(ids).Resize({K}); + auto w_tmp = std::make_shared(); + w_tmp->set_meta(w.meta()); + dev_ctx.Alloc(w_tmp.get(), w_tmp->dtype()); + auto out_grad_tmp = std::make_shared(); + out_grad_tmp->ShareDataWith(out_grad).Resize({K, D}); + paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp), + out_grad_tensor(out_grad_tmp); + + auto start_index_tensor = paddle::experimental::full_like( + x_tensor, start_index, x_tensor.dtype(), x_tensor.place()); + auto end_index_tensor = paddle::experimental::full_like( + x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place()); + auto ids_mask_tensor = paddle::experimental::logical_and( + x_tensor.greater_equal(start_index_tensor), + x_tensor.less_than(end_index_tensor)); + auto real_ids_tensor = (x_tensor - start_index_tensor) + .multiply(paddle::experimental::cast( + ids_mask_tensor, x_tensor.dtype())); + auto out_grad_tensor_mul_mask = + paddle::experimental::reshape(out_grad_tensor, {K, D}) + .multiply(paddle::experimental::reshape( + paddle::experimental::cast(ids_mask_tensor, w.dtype()), + {K, 1})); + paddle::Tensor w_grad_tensor; + paddle::experimental::embedding_grad(real_ids_tensor, + w_tensor, + out_grad_tensor_mul_mask, + -1, + false, + &w_grad_tensor); + w_grad->ShareDataWith( + *reinterpret_cast(w_grad_tensor.impl().get())); + + } else { + PADDLE_THROW(phi::errors::Unavailable( + "Custom Device c_embedding_grad ids only support int32 or int64.")); + } +} +#endif +} // namespace phi + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(c_embedding_grad, + Custom, + ALL_LAYOUT, + phi::CEmbeddingGradKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#endif diff --git a/paddle/phi/kernels/custom/c_embedding_kernel.cc b/paddle/phi/kernels/custom/c_embedding_kernel.cc new file mode 100644 index 0000000000000..0cacf61d46f3a --- /dev/null +++ b/paddle/phi/kernels/custom/c_embedding_kernel.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_kernel.h" +#include "glog/logging.h" +#include "paddle/phi/api/backward/backward_api.h" +#include "paddle/phi/api/include/api.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +template +void CEmbeddingKernel(const Context& dev_ctx, + const DenseTensor& w, + const DenseTensor& ids, + int64_t start_index, + int64_t vocab_size, + DenseTensor* out) { + const auto& index_type = ids.dtype(); + if (index_type == phi::DataType::INT32 || + index_type == phi::DataType::INT64) { + auto out_dims = out->dims(); + auto K = ids.numel(); + auto N = w.dims()[0]; + auto D = w.dims()[1]; + + auto x_tmp = std::make_shared(); + x_tmp->ShareDataWith(ids).Resize({K}); + auto w_tmp = std::make_shared(); + w_tmp->ShareDataWith(w).Resize({N, D}); + paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp); + + auto start_index_tensor = paddle::experimental::full_like( + x_tensor, start_index, x_tensor.dtype(), x_tensor.place()); + auto end_index_tensor = paddle::experimental::full_like( + x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place()); + auto ids_mask_tensor = paddle::experimental::logical_and( + x_tensor.greater_equal(start_index_tensor), + x_tensor.less_than(end_index_tensor)); + auto ids_tensor = (x_tensor - start_index_tensor) + .multiply(paddle::experimental::cast( + ids_mask_tensor, x_tensor.dtype())); + auto out_tensor = + paddle::experimental::reshape( + paddle::experimental::cast(ids_mask_tensor, w_tensor.dtype()), + {K, 1}) + .multiply(paddle::experimental::reshape( + paddle::experimental::embedding( + ids_tensor, w_tensor, -1, false), + {K, D})); + out->ShareDataWith( + *reinterpret_cast(out_tensor.impl().get())) + .Resize(out_dims); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "Custom Device c_embedding ids only support int32 or int64.")); + } +} +#endif +} // namespace phi + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(c_embedding, + Custom, + ALL_LAYOUT, + phi::CEmbeddingKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#endif diff --git a/test/legacy_test/c_embedding_op_base.py b/test/legacy_test/c_embedding_op_base.py index 83758b6bb0bc9..cfb9df8e69d22 100644 --- a/test/legacy_test/c_embedding_op_base.py +++ b/test/legacy_test/c_embedding_op_base.py @@ -34,10 +34,8 @@ def get_c_embedding(start, end, table, ids): return output -def c_embedding_wrapper(table, index, start_index=0): - return paddle._legacy_C_ops.c_embedding( - table, index, "start_index", start_index - ) +def c_embedding_wrapper(table, index, start_index=0, vocab_size=-1): + return paddle._C_ops.c_embedding(table, index, start_index, vocab_size) class TestCEmbeddingCPU(OpTest): @@ -58,11 +56,15 @@ def initcase(self): ) self.start_index = 10 self.end_index = self.start_index + 17 + self.vocab_size = 34 self.inputs = {'W': table, 'Ids': ids} np_out = get_c_embedding(self.start_index, self.end_index, table, ids) self.outputs = {'Out': np_out.reshape((2, 4, 64))} - self.attrs = {'start_index': self.start_index} + self.attrs = { + 'start_index': self.start_index, + 'vocab_size': self.vocab_size, + } if core.is_compiled_with_xpu(): self.__class__.use_xpu = True @@ -87,12 +89,20 @@ def test_check_output(self): self.check_output_with_place(core.CUDAPlace(0)) elif core.is_compiled_with_xpu(): self.check_output_with_place(core.XPUPlace(0)) + else: + current_place = paddle.framework._current_expected_place() + if isinstance(current_place, paddle.CustomPlace): + self.check_output_with_place(current_place) def test_check_grad(self): if core.is_compiled_with_cuda(): self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out') elif core.is_compiled_with_xpu(): self.check_grad_with_place(core.XPUPlace(0), ['W'], 'Out') + else: + current_place = paddle.framework._current_expected_place() + if isinstance(current_place, paddle.CustomPlace): + self.check_grad_with_place(current_place, ['W'], 'Out') def init_dtype(self): if core.is_compiled_with_cuda(): @@ -101,6 +111,11 @@ def init_dtype(self): elif core.is_compiled_with_xpu(): self.dtype = "float32" self.ids_dtype = "int64" + else: + current_place = paddle.framework._current_expected_place() + if isinstance(current_place, paddle.CustomPlace): + self.dtype = "float32" + self.ids_dtype = "int64" class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):