c_softmax_with_cross_entropy support bf16 for xpu
zhangyk0314 committed Jan 4, 2024
1 parent 7b616c4 commit f0776c1
Showing 4 changed files with 75 additions and 30 deletions.

@@ -102,13 +102,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
     // reduce last dim
     int dims[1] = {1};
     auto f = [](xpu::Context* ctx,
-                const XPUType* x,
-                XPUType* y,
+                const T* x,
+                T* y,
                 const std::vector<int>& xdims,
                 const std::vector<int>& reduce_dims) {
-      return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
+      return xpu::reduce_max<XPUType>(ctx,
+                                      reinterpret_cast<const XPUType*>(x),
+                                      reinterpret_cast<XPUType*>(y),
+                                      xdims,
+                                      reduce_dims);
     };
-    ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+    ret = phi::XPUReduce<phi::XPUContext, T>(
         dev_ctx,
         logits_2d,
         std::vector<int64_t>(dims, dims + 1),
@@ -194,13 +198,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
   {
     int dims[1] = {1};
     auto f = [](xpu::Context* ctx,
-                const XPUType* x,
-                XPUType* y,
+                const T* x,
+                T* y,
                 const std::vector<int>& xdims,
                 const std::vector<int>& reduce_dims) {
-      return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
+      return xpu::reduce_sum<XPUType>(ctx,
+                                      reinterpret_cast<const XPUType*>(x),
+                                      reinterpret_cast<XPUType*>(y),
+                                      xdims,
+                                      reduce_dims);
     };
-    ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+    ret = phi::XPUReduce<phi::XPUContext, T>(
         dev_ctx,
         softmax_2d,
         std::vector<int64_t>(dims, dims + 1),
@@ -323,13 +331,17 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
   {
     int dims[1] = {1};
     auto f = [](xpu::Context* ctx,
-                const XPUType* x,
-                XPUType* y,
+                const T* x,
+                T* y,
                 const std::vector<int>& xdims,
                 const std::vector<int>& reduce_dims) {
-      return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
+      return xpu::reduce_max<XPUType>(ctx,
+                                      reinterpret_cast<const XPUType*>(x),
+                                      reinterpret_cast<XPUType*>(y),
+                                      xdims,
+                                      reduce_dims);
     };
-    ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+    ret = phi::XPUReduce<phi::XPUContext, T>(
         dev_ctx,
         logits_2d,
         std::vector<int64_t>(dims, dims + 1),
@@ -436,13 +448,17 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
   {
     int dims[1] = {1};
     auto f = [](xpu::Context* ctx,
-                const XPUType* x,
-                XPUType* y,
+                const T* x,
+                T* y,
                 const std::vector<int>& xdims,
                 const std::vector<int>& reduce_dims) {
-      return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
+      return xpu::reduce_sum<XPUType>(ctx,
+                                      reinterpret_cast<const XPUType*>(x),
+                                      reinterpret_cast<XPUType*>(y),
+                                      xdims,
+                                      reduce_dims);
     };
-    ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+    ret = phi::XPUReduce<phi::XPUContext, T>(
         dev_ctx,
         softmax_2d,
         std::vector<int64_t>(dims, dims + 1),
@@ -567,9 +583,11 @@ PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy,
                           XPU,
                           ALL_LAYOUT,
                           ops::CSoftmaxWithCrossEntropyOp,
-                          float) {}
+                          float,
+                          phi::dtype::bfloat16) {}
 PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy_grad,
                           XPU,
                           ALL_LAYOUT,
                           ops::CSoftmaxWithCrossEntropyGrad,
-                          float) {}
+                          float,
+                          phi::dtype::bfloat16) {}

5 changes: 3 additions & 2 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -143,9 +143,10 @@ XPUOpMap& get_kl3_ops() {
                     phi::DataType::BFLOAT16,
                     phi::DataType::INT32,
                     phi::DataType::INT64})},
-      {"c_softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"c_softmax_with_cross_entropy",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})},
       {"c_softmax_with_cross_entropy_grad",
-       XPUKernelSet({phi::DataType::FLOAT32})},
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})},
       {"c_reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
       {"c_split",
        XPUKernelSet({phi::DataType::FLOAT16,

20 changes: 15 additions & 5 deletions test/xpu/collective_softmax_with_cross_entropy_op_xpu.py
@@ -17,6 +17,7 @@
 import sys
 
 import numpy as np
+from op_test import convert_float_to_uint16
 from test_collective_base_xpu import (
     DataTypeCast,
     TestCollectiveRunnerBase,
@@ -44,7 +45,7 @@ def get_model(self, main_prog, startup_program, rank):
             logits = data(
                 name="Logits",
                 shape=[self.batch_size, self.local_elements],
-                dtype='float32',
+                dtype=self.dtype,
             )
             label = data(
                 name="Label", shape=[self.batch_size, 1], dtype='int32'
@@ -110,6 +111,7 @@ def run_trainer(self, args):
         self.initCommunicator(
             startup_prog, rank, self.nranks, True, current_endpoint, endpoints
         )
+        self.dtype = args["dtype"]
         np_dtype = DataTypeCast(args["dtype"])
         loss, softmax = self.get_model(train_prog, startup_prog, rank)
         device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
@@ -126,15 +128,23 @@
             dtype='int32',
         )
         # use FAKE loss_grad here, only to examine the correctness of grad func
-        loss_grad = np.random.uniform(
+        loss_grad_fp32 = np.random.uniform(
             low=-10.0, high=10.0, size=(self.batch_size, 1)
-        ).astype(np_dtype)
+        ).astype(np.float32)
+        if args["dtype"] == "bfloat16":
+            loss_grad = convert_float_to_uint16(loss_grad_fp32)
+        else:
+            loss_grad = loss_grad_fp32.astype(np_dtype)
 
         # each xpu uses own half of logits
         np.random.seed(os.getpid())
-        logits = np.random.uniform(
+        logits_fp32 = np.random.uniform(
             low=-40.0, high=40.0, size=(self.batch_size, self.local_elements)
-        ).astype(np_dtype)
+        ).astype(np.float32)
+        if args["dtype"] == "bfloat16":
+            logits = convert_float_to_uint16(logits_fp32)
+        else:
+            logits = logits_fp32.astype(np_dtype)
         out = exe.run(
             train_prog,
             feed={'Logits': logits, 'Label': label, 'Loss@GRAD': loss_grad},
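
A note on the bfloat16 input path above: numpy has no native bfloat16 dtype, so the test appears to carry bfloat16 tensors as uint16 arrays holding the raw bf16 bit pattern; the trainer therefore generates float32 data and passes it through convert_float_to_uint16 only when dtype is bfloat16. A minimal numpy sketch of the idea (the helper names below are illustrative, not the op_test API, and the real converter may round rather than truncate):

import numpy as np

def float32_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 word: sign, 8 exponent bits,
    # and the top 7 mantissa bits -- the bfloat16 layout.
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(b):
    # Reverse: place the 16 stored bits back into the high half of a float32.
    b = np.asarray(b, dtype=np.uint16)
    return (b.astype(np.uint32) << 16).view(np.float32)

logits = np.random.uniform(-40.0, 40.0, size=(4, 8)).astype(np.float32)
fed = float32_to_bf16_bits(logits)     # what a bfloat16 run would feed
seen = bf16_bits_to_float32(fed)       # what the kernel effectively computes on
print(np.max(np.abs(seen - logits)))   # up to about one bf16 ulp, ~2**-7 relative
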
26 changes: 21 additions & 5 deletions test/xpu/test_collective_softmax_with_cross_entropy_xpu.py
@@ -21,6 +21,7 @@
     create_test_class,
     get_xpu_op_support_types,
 )
+from op_test import convert_uint16_to_float
 from test_collective_base_xpu import DataTypeCast, TestDistBase
 
 import paddle
Expand Down Expand Up @@ -154,15 +155,30 @@ def check_with_place(
# get real result
loss0, softmax0, logits_grad0 = tr0_out
loss1, softmax1, logits_grad1 = tr1_out
if dtype == "bfloat16":
loss0 = convert_uint16_to_float(loss0)
softmax0 = convert_uint16_to_float(softmax0)
logits_grad0 = convert_uint16_to_float(logits_grad0)
loss1 = convert_uint16_to_float(loss1)
softmax1 = convert_uint16_to_float(softmax1)
logits_grad1 = convert_uint16_to_float(logits_grad1)
softmax = np.concatenate((softmax0, softmax1), axis=1)
logits_grad = np.concatenate((logits_grad0, logits_grad1), axis=1)

# compare results
rtol = 1e-6
np.testing.assert_allclose(loss0, need_loss, rtol=rtol)
np.testing.assert_allclose(loss1, need_loss, rtol=rtol)
np.testing.assert_allclose(softmax, need_softmax, rtol=rtol)
np.testing.assert_allclose(logits_grad, need_logits_grad, rtol=rtol)
atol = 0
if dtype == "bfloat16":
rtol = 0.1
atol = 0.1
np.testing.assert_allclose(loss0, need_loss, rtol=rtol, atol=atol)
np.testing.assert_allclose(loss1, need_loss, rtol=rtol, atol=atol)
np.testing.assert_allclose(
softmax, need_softmax, rtol=rtol, atol=atol
)
np.testing.assert_allclose(
logits_grad, need_logits_grad, rtol=rtol, atol=atol
)


support_types = get_xpu_op_support_types('c_softmax_with_cross_entropy')
@@ -171,7 +187,7 @@ def check_with_place(
         globals(),
         XPUTestCSoftmaxWithCEOP,
         stype,
-        ignore_device_version=[core.XPUVersion.XPU1],
+        ignore_device_version=[core.XPUVersion.XPU1, core.XPUVersion.XPU3],
     )
 
 if __name__ == '__main__':
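
A note on the loosened tolerances above: bfloat16 keeps only 8 significand bits, so each value can already differ from its float32 reference by roughly one part in 2**7, and that error compounds through the softmax reduction and the collective communication, which is presumably why the bfloat16 case relaxes the check from rtol=1e-6 to rtol=atol=0.1 (the absolute term also covers softmax entries near zero, where a purely relative comparison is ill-conditioned). A rough numeric illustration in plain numpy, under the same truncation assumption as the sketch after the previous file:

import numpy as np

# One bf16 ulp near 1.0 is 2**-7, versus ~1.2e-7 for float32.
print(2.0**-7, np.finfo(np.float32).eps)

# Round-tripping a value through bf16 truncation shows the per-element
# error the relaxed rtol/atol must absorb.
x = np.array([0.3333333], dtype=np.float32)
bits = (x.view(np.uint32) >> 16).astype(np.uint16)
x_back = (bits.astype(np.uint32) << 16).view(np.float32)
print((abs(x_back - x) / x).item())  # roughly 4e-3, comfortably inside rtol = 0.1
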
