diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc index 42b4db460bbfc9..000f4d38564f60 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc @@ -102,13 +102,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { // reduce last dim int dims[1] = {1}; auto f = [](xpu::Context* ctx, - const XPUType* x, - XPUType* y, + const T* x, + T* y, const std::vector& xdims, const std::vector& reduce_dims) { - return xpu::reduce_max(ctx, x, y, xdims, reduce_dims); + return xpu::reduce_max(ctx, + reinterpret_cast(x), + reinterpret_cast(y), + xdims, + reduce_dims); }; - ret = phi::XPUReduce( + ret = phi::XPUReduce( dev_ctx, logits_2d, std::vector(dims, dims + 1), @@ -194,13 +198,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { { int dims[1] = {1}; auto f = [](xpu::Context* ctx, - const XPUType* x, - XPUType* y, + const T* x, + T* y, const std::vector& xdims, const std::vector& reduce_dims) { - return xpu::reduce_sum(ctx, x, y, xdims, reduce_dims); + return xpu::reduce_sum(ctx, + reinterpret_cast(x), + reinterpret_cast(y), + xdims, + reduce_dims); }; - ret = phi::XPUReduce( + ret = phi::XPUReduce( dev_ctx, softmax_2d, std::vector(dims, dims + 1), @@ -323,13 +331,17 @@ struct CSoftmaxWithCrossEntropyFunctor { { int dims[1] = {1}; auto f = [](xpu::Context* ctx, - const XPUType* x, - XPUType* y, + const T* x, + T* y, const std::vector& xdims, const std::vector& reduce_dims) { - return xpu::reduce_max(ctx, x, y, xdims, reduce_dims); + return xpu::reduce_max(ctx, + reinterpret_cast(x), + reinterpret_cast(y), + xdims, + reduce_dims); }; - ret = phi::XPUReduce( + ret = phi::XPUReduce( dev_ctx, logits_2d, std::vector(dims, dims + 1), @@ -436,13 +448,17 @@ struct CSoftmaxWithCrossEntropyFunctor { { int dims[1] = {1}; auto f = [](xpu::Context* ctx, - const XPUType* x, - XPUType* y, + const T* x, + T* y, const std::vector& xdims, const std::vector& reduce_dims) { - return xpu::reduce_sum(ctx, x, y, xdims, reduce_dims); + return xpu::reduce_sum(ctx, + reinterpret_cast(x), + reinterpret_cast(y), + xdims, + reduce_dims); }; - ret = phi::XPUReduce( + ret = phi::XPUReduce( dev_ctx, softmax_2d, std::vector(dims, dims + 1), @@ -567,9 +583,11 @@ PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy, XPU, ALL_LAYOUT, ops::CSoftmaxWithCrossEntropyOp, - float) {} + float, + phi::dtype::bfloat16) {} PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy_grad, XPU, ALL_LAYOUT, ops::CSoftmaxWithCrossEntropyGrad, - float) {} + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 20c649ee4ba978..efc38e7fb15b5d 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -143,9 +143,10 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"c_softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_softmax_with_cross_entropy", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})}, {"c_softmax_with_cross_entropy_grad", - XPUKernelSet({phi::DataType::FLOAT32})}, + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})}, {"c_reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_split", XPUKernelSet({phi::DataType::FLOAT16, diff --git a/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py b/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py index 21d333d222ef98..bfb89fd978074f 100644 --- a/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py +++ b/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py @@ -167,6 +167,8 @@ def check_with_place( support_types = get_xpu_op_support_types('c_softmax_with_cross_entropy') for stype in support_types: + if stype == "bfloat16": + continue create_test_class( globals(), XPUTestCSoftmaxWithCEOP,