Commit

[cherry-pick] adapt c_embedding to phi namespace for custom devices (PaddlePaddle#60774) (PaddlePaddle#61045)

Co-authored-by: Tian <121000916+SylarTiaNII@users.noreply.github.com>
BeingGod and SylarTiaNII authored Feb 2, 2024
1 parent 60325a1 commit 0ccb9cb
Showing 4 changed files with 201 additions and 5 deletions.
4 changes: 4 additions & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -199,6 +199,10 @@ if(WITH_MKLDNN)
       "fusion/onednn/*.cc")
 endif()
 
+if(WITH_CUSTOM_DEVICE)
+  set(cc_search_pattern ${cc_search_pattern} "custom/*.cc")
+endif()
+
 file(
   GLOB kernel_cc
   RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
93 changes: 93 additions & 0 deletions paddle/phi/kernels/custom/c_embedding_grad_kernel.cc
@@ -0,0 +1,93 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/c_embedding_grad_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T, typename Context>
void CEmbeddingGradKernel(const Context& dev_ctx,
                          const DenseTensor& w,
                          const DenseTensor& ids,
                          const DenseTensor& out_grad,
                          int64_t start_index,
                          DenseTensor* w_grad) {
  w_grad->Resize(w.dims());
  dev_ctx.template Alloc(w_grad, w.dtype());
  const auto& index_type = ids.dtype();
  if (index_type == phi::DataType::INT32 ||
      index_type == phi::DataType::INT64) {
    auto K = ids.numel();
    auto N = w.dims()[0];
    auto D = w.dims()[1];

    auto x_tmp = std::make_shared<phi::DenseTensor>();
    x_tmp->ShareDataWith(ids).Resize({K});
    auto w_tmp = std::make_shared<phi::DenseTensor>();
    w_tmp->set_meta(w.meta());
    dev_ctx.Alloc(w_tmp.get(), w_tmp->dtype());
    auto out_grad_tmp = std::make_shared<phi::DenseTensor>();
    out_grad_tmp->ShareDataWith(out_grad).Resize({K, D});
    paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp),
        out_grad_tensor(out_grad_tmp);

    auto start_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
    auto end_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
    auto ids_mask_tensor = paddle::experimental::logical_and(
        x_tensor.greater_equal(start_index_tensor),
        x_tensor.less_than(end_index_tensor));
    auto real_ids_tensor = (x_tensor - start_index_tensor)
                               .multiply(paddle::experimental::cast(
                                   ids_mask_tensor, x_tensor.dtype()));
    auto out_grad_tensor_mul_mask =
        paddle::experimental::reshape(out_grad_tensor, {K, D})
            .multiply(paddle::experimental::reshape(
                paddle::experimental::cast(ids_mask_tensor, w.dtype()),
                {K, 1}));
    paddle::Tensor w_grad_tensor;
    paddle::experimental::embedding_grad(real_ids_tensor,
                                         w_tensor,
                                         out_grad_tensor_mul_mask,
                                         -1,
                                         false,
                                         &w_grad_tensor);
    w_grad->ShareDataWith(
        *reinterpret_cast<phi::DenseTensor*>(w_grad_tensor.impl().get()));

  } else {
    PADDLE_THROW(phi::errors::Unavailable(
        "Custom Device c_embedding_grad ids only support int32 or int64."));
  }
}
#endif
} // namespace phi

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(c_embedding_grad,
                   Custom,
                   ALL_LAYOUT,
                   phi::CEmbeddingGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#endif
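Rather than requiring each custom device to ship a native kernel, the grad kernel is composed from existing paddle::experimental ops: ids outside this shard's row range [start_index, start_index + N) are masked, the masked rows carry zero out_grad, and embedding_grad scatter-adds the remaining rows into the local weight gradient. A minimal NumPy sketch of the same computation (illustrative only; the helper name and the use of np.add.at are my own, not part of the commit):

import numpy as np

def c_embedding_grad_ref(w, ids, out_grad, start_index):
    # w: [N, D] local shard of the embedding table; ids: any integer array;
    # out_grad: ids.shape + [D]. Returns the gradient w.r.t. the shard.
    N, D = w.shape
    flat_ids = ids.reshape(-1)
    grad2d = out_grad.reshape(-1, D)
    # Mask ids that this shard does not own.
    mask = (flat_ids >= start_index) & (flat_ids < start_index + N)
    # Out-of-range ids map to row 0, but their gradient rows are zeroed
    # below, so they contribute nothing (mirrors real_ids_tensor above).
    local_ids = (flat_ids - start_index) * mask
    w_grad = np.zeros_like(w)
    np.add.at(w_grad, local_ids, grad2d * mask[:, None])  # scatter-add
    return w_grad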
84 changes: 84 additions & 0 deletions paddle/phi/kernels/custom/c_embedding_kernel.cc
@@ -0,0 +1,84 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/c_embedding_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T, typename Context>
void CEmbeddingKernel(const Context& dev_ctx,
                      const DenseTensor& w,
                      const DenseTensor& ids,
                      int64_t start_index,
                      int64_t vocab_size,
                      DenseTensor* out) {
  const auto& index_type = ids.dtype();
  if (index_type == phi::DataType::INT32 ||
      index_type == phi::DataType::INT64) {
    auto out_dims = out->dims();
    auto K = ids.numel();
    auto N = w.dims()[0];
    auto D = w.dims()[1];

    auto x_tmp = std::make_shared<phi::DenseTensor>();
    x_tmp->ShareDataWith(ids).Resize({K});
    auto w_tmp = std::make_shared<phi::DenseTensor>();
    w_tmp->ShareDataWith(w).Resize({N, D});
    paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp);

    auto start_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
    auto end_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
    auto ids_mask_tensor = paddle::experimental::logical_and(
        x_tensor.greater_equal(start_index_tensor),
        x_tensor.less_than(end_index_tensor));
    auto ids_tensor = (x_tensor - start_index_tensor)
                          .multiply(paddle::experimental::cast(
                              ids_mask_tensor, x_tensor.dtype()));
    auto out_tensor =
        paddle::experimental::reshape(
            paddle::experimental::cast(ids_mask_tensor, w_tensor.dtype()),
            {K, 1})
            .multiply(paddle::experimental::reshape(
                paddle::experimental::embedding(
                    ids_tensor, w_tensor, -1, false),
                {K, D}));
    out->ShareDataWith(
        *reinterpret_cast<phi::DenseTensor*>(out_tensor.impl().get()))
        .Resize(out_dims);
  } else {
    PADDLE_THROW(phi::errors::Unavailable(
        "Custom Device c_embedding ids only support int32 or int64."));
  }
}
#endif
} // namespace phi

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(c_embedding,
                   Custom,
                   ALL_LAYOUT,
                   phi::CEmbeddingKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#endif
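The forward kernel follows the same masking pattern: look up (ids - start_index) for locally owned ids and zero the rows whose ids fall outside the shard; summing the partial outputs across ranks is typically left to the caller. A NumPy sketch with a tiny worked example (illustrative only, names hypothetical):

import numpy as np

def c_embedding_ref(w, ids, start_index):
    # w: [N, D] local shard; ids: any integer array. Rows for ids outside
    # [start_index, start_index + N) come back as zeros (mirrors
    # ids_mask_tensor above).
    N, D = w.shape
    flat_ids = ids.reshape(-1)
    mask = (flat_ids >= start_index) & (flat_ids < start_index + N)
    local_ids = (flat_ids - start_index) * mask
    out = w[local_ids] * mask[:, None]
    return out.reshape(*ids.shape, D)

# Tiny example: a 4-row shard holding global rows 10..13.
w = np.arange(8, dtype=np.float32).reshape(4, 2)
ids = np.array([[9, 10], [13, 14]])
print(c_embedding_ref(w, ids, start_index=10))
# ids 9 and 14 are out of shard -> zero rows; 10 -> w[0], 13 -> w[3]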
25 changes: 20 additions & 5 deletions test/legacy_test/c_embedding_op_base.py
@@ -34,10 +34,8 @@ def get_c_embedding(start, end, table, ids):
     return output
 
 
-def c_embedding_wrapper(table, index, start_index=0):
-    return paddle._legacy_C_ops.c_embedding(
-        table, index, "start_index", start_index
-    )
+def c_embedding_wrapper(table, index, start_index=0, vocab_size=-1):
+    return paddle._C_ops.c_embedding(table, index, start_index, vocab_size)
 
 
 class TestCEmbeddingCPU(OpTest):
@@ -58,11 +56,15 @@ def initcase(self):
         )
         self.start_index = 10
         self.end_index = self.start_index + 17
+        self.vocab_size = 34
 
         self.inputs = {'W': table, 'Ids': ids}
         np_out = get_c_embedding(self.start_index, self.end_index, table, ids)
         self.outputs = {'Out': np_out.reshape((2, 4, 64))}
-        self.attrs = {'start_index': self.start_index}
+        self.attrs = {
+            'start_index': self.start_index,
+            'vocab_size': self.vocab_size,
+        }
         if core.is_compiled_with_xpu():
             self.__class__.use_xpu = True
 
@@ -87,12 +89,20 @@ def test_check_output(self):
             self.check_output_with_place(core.CUDAPlace(0))
         elif core.is_compiled_with_xpu():
             self.check_output_with_place(core.XPUPlace(0))
+        else:
+            current_place = paddle.framework._current_expected_place()
+            if isinstance(current_place, paddle.CustomPlace):
+                self.check_output_with_place(current_place)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out')
         elif core.is_compiled_with_xpu():
             self.check_grad_with_place(core.XPUPlace(0), ['W'], 'Out')
+        else:
+            current_place = paddle.framework._current_expected_place()
+            if isinstance(current_place, paddle.CustomPlace):
+                self.check_grad_with_place(current_place, ['W'], 'Out')
 
     def init_dtype(self):
         if core.is_compiled_with_cuda():
@@ -101,6 +111,11 @@ def init_dtype(self):
         elif core.is_compiled_with_xpu():
             self.dtype = "float32"
             self.ids_dtype = "int64"
+        else:
+            current_place = paddle.framework._current_expected_place()
+            if isinstance(current_place, paddle.CustomPlace):
+                self.dtype = "float32"
+                self.ids_dtype = "int64"
 
 
 class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):
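The new else branches let these same OpTest classes exercise a pluggable device when Paddle is built without CUDA or XPU. A hedged sketch of the trigger (the plugin name "custom_cpu" is an assumption taken from Paddle's example plugin, not from this commit):

import paddle
from paddle.framework import _current_expected_place

# Assumes a CustomDevice plugin registered as "custom_cpu" is installed.
# Selecting it makes the default place a CustomPlace, which the new
# else-branches route to check_output_with_place / check_grad_with_place.
paddle.set_device("custom_cpu")
place = _current_expected_place()
assert isinstance(place, paddle.CustomPlace)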
