[bf16] add bf16 kernel: elementwise_div (#39602)
* add elementwise_div

* refine rocm

* refine code

* refine op register

* solve conflict

* refine unittest

* refine unittest precision

* add rocm
zhangbo9674 authored Feb 23, 2022
1 parent 1fcaab4 commit ca4df33
Showing 5 changed files with 64 additions and 1 deletion.
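
Before the per-file diffs, a minimal usage sketch of what this commit enables: elementwise division on bfloat16 GPU tensors. This is illustrative only, not part of the change set; it assumes a Paddle build containing this commit, a CUDA 11+ (or supported ROCm) device, and Python-side support for casting to a bfloat16 dtype. Shapes and values are arbitrary.

# Hypothetical usage sketch (not part of the diff): bfloat16 elementwise division.
# Assumes a build containing this commit and Python-side bfloat16 casting support.
import paddle

paddle.set_device('gpu')  # the bf16 kernels below are GPU-only (CUDA 11+ or ROCm)

x = paddle.rand([12, 13], dtype='float32')
y = paddle.rand([12, 13], dtype='float32') + 0.1   # keep the divisor away from zero

x_bf16 = paddle.cast(x, 'bfloat16')
y_bf16 = paddle.cast(y, 'bfloat16')

out = paddle.divide(x_bf16, y_bf16)   # dispatches to the new bf16 elementwise_div kernel
print(out.dtype)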
6 changes: 6 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -53,6 +53,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -65,6 +67,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -78,6 +82,8 @@ REGISTER_OP_CUDA_KERNEL(
float>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
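
The hunks above register three bfloat16 kernels: the forward ElementwiseDivKernel, the first-order ElementwiseDivGradKernel, and the second-order ElementwiseDivDoubleGradKernel. The sketch below shows, in float32 dygraph code for clarity, the computation each of these covers for z = x / y; with this commit the same path exists for bfloat16 inputs. Values are arbitrary, and the mapping of individual paddle.grad calls onto specific kernels is a simplification.

# Sketch of what the three registered kernels compute (float32 shown for clarity).
import paddle

x = paddle.to_tensor([2.0, 4.0], stop_gradient=False)
y = paddle.to_tensor([1.0, 2.0], stop_gradient=False)

z = paddle.divide(x, y)                        # forward: z = x / y

# First-order gradients: dz/dx = 1 / y, dz/dy = -x / y**2
(dx,) = paddle.grad(z, x, create_graph=True)

# Differentiating dx = 1 / y once more gives -1 / y**2 (the double-grad path).
(d2,) = paddle.grad(dx, y)
print(dx.numpy(), d2.numpy())                  # [1. 0.5] [-1. -0.25]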
12 changes: 12 additions & 0 deletions paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h
@@ -105,6 +105,18 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
return float16(__shfl_xor_sync(mask, val.to_half(), width));
}

template <>
__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
bfloat16 val,
int width) {
#if defined(PADDLE_CUDA_BF16)
return bfloat16(__shfl_xor_sync(mask, static_cast<nv_bfloat16>(val), width));
#else
PADDLE_ENFORCE(
false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11.");
#endif
}

template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<float> val, int width) {
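
The CudaShuffleXorSync specialization added above lets warp-level code exchange bfloat16 values between lanes; Paddle's warp reductions instantiate this template for every registered dtype, which is presumably why a bf16 overload is needed once bf16 kernels exist. The snippet below only simulates the XOR "butterfly" access pattern that __shfl_xor_sync enables inside a 32-lane warp; it is not Paddle code.

# Simulation of a warp-wide XOR butterfly sum: after log2(32) = 5 shuffle steps,
# every lane holds the sum of all 32 lane values. Illustrative only.
import random

lanes = [random.random() for _ in range(32)]   # one value per warp lane
expected = sum(lanes)

vals = list(lanes)
offset = 16
while offset > 0:
    # __shfl_xor_sync(mask, val, offset): lane i reads the value held by lane i ^ offset
    vals = [vals[i] + vals[i ^ offset] for i in range(32)]
    offset //= 2

assert all(abs(v - expected) < 1e-9 for v in vals)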
7 changes: 7 additions & 0 deletions paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h
@@ -91,6 +91,13 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
return float16(__shfl_xor(static_cast<float>(val), width));
}

template <>
__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
bfloat16 val,
int width) {
return bfloat16(__shfl_xor(static_cast<float>(val), width));
}

template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<float> val, int width) {
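
On the ROCm path the bfloat16 value is simply round-tripped through float for the shuffle. That round trip is lossless because a bfloat16 is the upper 16 bits of an IEEE-754 float32 — the same fact the unit test later in this diff relies on when it stores bf16 data as uint16. A small numpy illustration follows (plain truncation; real converters such as the test's convert_float_to_uint16 helper may round rather than truncate):

# Why bfloat16 survives a cast through float32: it is the top 16 bits of the float32.
import numpy as np

x = np.array([3.14159], dtype=np.float32)
bits = x.view(np.uint32)

bf16_bits = (bits >> 16).astype(np.uint16)        # sign + 8 exponent bits + top 7 mantissa bits
back = (bf16_bits.astype(np.uint32) << 16).view(np.float32)

print(bits[0], bf16_bits[0], back[0])             # back[0] is 3.140625, the bf16-truncated value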
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/math_kernel.cu
@@ -92,6 +92,7 @@ DEFINE_CUDA_ELEMENTWISE_OP(Divide)
} // namespace phi

using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

@@ -128,6 +129,7 @@ PD_REGISTER_KERNEL(divide_raw,
int,
int64_t,
float16,
bfloat16,
complex64,
complex128) {}
PD_REGISTER_KERNEL(multiply_raw,
38 changes: 37 additions & 1 deletion python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
@@ -18,7 +18,7 @@
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16


class ElementwiseDivOp(OpTest):
@@ -55,6 +55,42 @@ def init_dtype(self):
pass


@unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
"core is not compiled with CUDA and cudnn version need larger than 8.1.0")
class TestElementwiseDivOpBF16(OpTest):
def setUp(self):
self.op_type = "elementwise_div"
self.dtype = np.uint16

x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
y = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)

out = np.divide(x, y)

self.inputs = {
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y)
}
self.outputs = {'Out': convert_float_to_uint16(out)}

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)

def test_check_grad_normal(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Y'], 'Out')

def test_check_grad_ignore_x(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['Y'], 'Out', no_grad_set=set("X"))

def test_check_grad_ignore_y(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', no_grad_set=set('Y'))


@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseDivOp_scalar(ElementwiseDivOp):
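
The new TestElementwiseDivOpBF16 case above feeds bf16 data (stored as uint16) through the op and asks check_grad_with_place to compare the kernel's gradients against numerically estimated ones. For reference, the analytic gradients being checked for Out = X / Y are dX = dOut / Y and dY = -dOut * X / Y**2; a float32 numpy sketch using the same data ranges as the test:

# Reference forward/backward values for elementwise division, mirroring the
# ranges used by the bf16 unit test above (float32 shown for readability).
import numpy as np

x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
y = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
grad_out = np.ones([12, 13], dtype=np.float32)   # illustrative upstream gradient

out = x / y
grad_x = grad_out / y                  # dX = dOut / Y
grad_y = -grad_out * x / (y * y)       # dY = -dOut * X / Y**2

print(out.mean(), grad_x.mean(), grad_y.mean())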
