Skip to content

Commit

Permalink
c_softmax_with_cross_entropy supports bf16 for xpu
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyk0314 committed Dec 29, 2023
1 parent 63776cf commit ec7c129
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
// reduce last dim
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const T* x,
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
return xpu::reduce_max<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
ret = phi::XPUReduce<phi::XPUContext, T>(
dev_ctx,
logits_2d,
std::vector<int64_t>(dims, dims + 1),
Expand Down Expand Up @@ -194,13 +198,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
{
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const T* x,
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
return xpu::reduce_sum<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
ret = phi::XPUReduce<phi::XPUContext, T>(
dev_ctx,
softmax_2d,
std::vector<int64_t>(dims, dims + 1),
Expand Down Expand Up @@ -323,13 +331,17 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
{
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const T* x,
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
return xpu::reduce_max<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
ret = phi::XPUReduce<phi::XPUContext, T>(
dev_ctx,
logits_2d,
std::vector<int64_t>(dims, dims + 1),
Expand Down Expand Up @@ -436,13 +448,17 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
{
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const T* x,
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
return xpu::reduce_sum<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
ret = phi::XPUReduce<phi::XPUContext, T>(
dev_ctx,
softmax_2d,
std::vector<int64_t>(dims, dims + 1),
Expand Down Expand Up @@ -567,9 +583,11 @@ PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy,
XPU,
ALL_LAYOUT,
ops::CSoftmaxWithCrossEntropyOp,
float) {}
float,
phi::dtype::bfloat16) {}
PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy_grad,
XPU,
ALL_LAYOUT,
ops::CSoftmaxWithCrossEntropyGrad,
float) {}
float,
phi::dtype::bfloat16) {}
5 changes: 3 additions & 2 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"c_softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_softmax_with_cross_entropy",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})},
{"c_softmax_with_cross_entropy_grad",
XPUKernelSet({phi::DataType::FLOAT32})},
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})},
{"c_reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_split",
XPUKernelSet({phi::DataType::FLOAT16,
Expand Down
2 changes: 2 additions & 0 deletions test/xpu/test_collective_softmax_with_cross_entropy_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def check_with_place(

support_types = get_xpu_op_support_types('c_softmax_with_cross_entropy')
for stype in support_types:
if stype == "bfloat16":
continue
create_test_class(
globals(),
XPUTestCSoftmaxWithCEOP,
Expand Down

0 comments on commit ec7c129

Please sign in to comment.