Completes bfloat16 dtype for collective api in eager mode #45844

Merged: 7 commits, Oct 11, 2022
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupGloo.cc
@@ -88,6 +88,9 @@ namespace distributed {
case experimental::DataType::BOOL: \
func<bool>(args); \
break; \
case experimental::DataType::BFLOAT16: \
func<bfloat16>(args); \
Contributor:
Isn't this a problem? Does gloo support bfloat16?

Contributor Author (@HermitSun, Sep 8, 2022):
Without this line it errors out; once I added it, things ran 🤔 Judging from yesterday's test results it seems fine.

Contributor:
No, it will of course error out without this line. My point is that gloo does not seem to support bf16 internally, so I'm curious why this can pass the tests.

Contributor Author (@HermitSun, Sep 8, 2022):
Could it be that Paddle's bf16 tensors actually hold uint16? All signs suggest it does not really use bf16 on the host. Since the data is actually uint16, it runs.

Contributor:
Then what is the difference between the bfloat16 that NCCL supports and just transferring it as uint16?

Contributor Author (@HermitSun, Sep 8, 2022):
NCCL seems to check CUDA to decide whether bf16 can be used, whereas gloo probably just uses uint16 directly?

break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
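To make the thread above concrete, here is a minimal sketch (not part of this PR) of the workaround the unit tests below use to build a bfloat16 tensor in eager mode. It assumes, as the discussion suggests, that Paddle stores bfloat16 data in plain uint16 words on the host, which is why a backend that only moves bytes can handle it like any other 16-bit type:

import numpy as np
import paddle

data = np.random.uniform(-1, 1, size=[4]).astype("float32")
# Workaround from the tests: build a float32 tensor, then cast to "uint16",
# the storage dtype Paddle currently uses for bfloat16.
t = paddle.to_tensor(data, "float32").cast("uint16")
print(t.numpy().dtype)            # expected: uint16 (raw 16-bit words on the host)
print(t.cast("float32").numpy())  # values rounded to bfloat16 precision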
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -996,6 +996,9 @@ void* GetPointerByOffset(void* raw_pointer,
} else if (type == experimental::DataType::BOOL) {
return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::BFLOAT16) {
Contributor:
AllReduce uint16 data?

Contributor Author (@HermitSun):
As the code below shows, they use uint16 to represent bf16 for some reason 🤔

#if defined(PADDLE_CUDA_BF16)
__nv_bfloat16 tmp = __float2bfloat16(val);
x = *reinterpret_cast<uint16_t*>(&tmp);
#else
std::memcpy(&x, reinterpret_cast<char*>(&val) + 2, 2);
#endif

Contributor Author (@HermitSun):
And it seems that we cannot use to_tensor or cast to get a uint16 tensor.

Contributor Author (@HermitSun):
This issue mentions the uint16 problem: #34927

return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
4 changes: 2 additions & 2 deletions paddle/fluid/platform/device/gpu/nccl_helper.h
@@ -59,7 +59,7 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclUint8;
} else if (type == framework::proto::VarType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
#if NCCL_VERSION_CODE >= 21000
} else if (type == framework::proto::VarType::BF16) {
return ncclBfloat16;
#endif
@@ -86,7 +86,7 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
return ncclInt8;
} else if (type == experimental::DataType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
#if NCCL_VERSION_CODE >= 21000
} else if (type == experimental::DataType::BFLOAT16) {
return ncclBfloat16;
#endif
35 changes: 18 additions & 17 deletions python/paddle/distributed/collective.py
@@ -478,7 +478,8 @@ def is_initialized():

Check whether the distributed environment has been initialized

Returns (bool): `True` if distributed environment has been initialized, otherwise `False`.
Returns:
`True` if distributed environment has been initialized, otherwise `False`.

Examples:
.. code-block:: python
@@ -626,7 +627,7 @@ def broadcast(tensor, src, group=None, sync_op=True):

Args:
tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
src (int): The source rank.
group (Group, optional): The group instance return by new_group or None for global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -709,7 +710,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):

Args:
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
group (Group, optional): The group instance return by new_group or None for global default group.
@@ -817,7 +818,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):

Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
group (Group, optional): The group instance return by new_group or None for global default group.
@@ -999,9 +1000,9 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):

Args:
tensor (Tensor): The output Tensor. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool. Default value is None.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
src (int): The source rank id. Default value is 0.
group (Group, optional): The group instance return by new_group or None for global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -1096,7 +1097,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):

Args:
in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
data type of the input Tensors.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
@@ -1197,7 +1198,7 @@ def alltoall_single(in_tensor,
``alltoall_single`` is only supported in eager mode.

Args:
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
@@ -1286,7 +1287,7 @@ def send(tensor, dst=0, group=None, sync_op=True):

Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
dst (int): The destination rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -1352,7 +1353,7 @@ def recv(tensor, src=0, group=None, sync_op=True):

Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
src (int): The source rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -1435,7 +1436,7 @@ def isend(tensor, dst, group=None):

Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
dst (int): The destination rank.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

@@ -1485,7 +1486,7 @@ def irecv(tensor, src=None, group=None):

Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
src (int): The source rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

@@ -1594,7 +1595,7 @@ def batch_isend_irecv(p2p_op_list):
corresponding tasks. NCCL are currently supported.

Args:
p2p_op_list: A list of point-to-point operations(type of each operator is
p2p_op_list (List[P2POp]): A list of point-to-point operations(type of each operator is
``paddle.distributed.P2POp``). The order of the isend/irecv in the list
matters and it needs to match with corresponding isend/irecv on the
remote end.
@@ -1668,9 +1669,9 @@ def reduce_scatter(tensor,
Reduces, then scatters a list of tensors to all processes in a group

Args:
tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (Group, optional): The group instance return by new_group or None for global
default group. Default: None.
@@ -1736,9 +1737,9 @@ def _reduce_scatter_base(output,
Reduces, then scatters a flattened tensor to all processes in a group.

Args:
output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
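As an illustration of what the docstring updates above describe (not code from this PR), an eager-mode collective call on bfloat16 data might look like the sketch below. It reuses the uint16-cast workaround from the unit tests further down and would be launched with `python -m paddle.distributed.launch` on two or more GPUs:

import numpy as np
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

data = np.random.uniform(-1, 1, size=[2, 3]).astype("float32")
# bfloat16 tensors are created through float32 and a cast to "uint16",
# the storage dtype Paddle uses for bfloat16 (same trick as in the tests below).
t = paddle.to_tensor(data, "float32").cast("uint16")

dist.all_reduce(t)                 # sums the bfloat16 tensors across all ranks
print(t.cast("float32").numpy())   # cast back to float32 to inspect the result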
18 changes: 9 additions & 9 deletions python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
@@ -71,14 +71,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_allreduce_api MODULES test_collective_allreduce_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_allreduce_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_alltoall_api MODULES test_collective_alltoall_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_alltoall_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
bash_test_modules(
@@ -98,7 +98,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_alltoall_single_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_alltoall_single_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
@@ -125,7 +125,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_broadcast_api MODULES test_collective_broadcast_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_broadcast_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
@@ -154,7 +154,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api
ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_isend_irecv_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
@@ -187,7 +187,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_reduce_api MODULES test_collective_reduce_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_reduce_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
bash_test_modules(
@@ -207,7 +207,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_reduce_scatter_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_reduce_scatter_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
@@ -221,7 +221,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_scatter_api MODULES test_collective_scatter_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_scatter_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
@@ -235,7 +235,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_sendrecv_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base

@@ -24,10 +25,18 @@ def __init__(self):

def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
tensor_list = []
paddle.distributed.all_gather(tensor_list, tindata)
return [tensor.numpy() for tensor in tensor_list]
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
dist.all_gather(tensor_list, tindata)
return [
tensor.cast("float32").numpy() for tensor in tensor_list
]
else:
tindata = paddle.to_tensor(indata)
dist.all_gather(tensor_list, tindata)
return [tensor.numpy() for tensor in tensor_list]


if __name__ == "__main__":
@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base

@@ -24,9 +25,15 @@ def __init__(self):

def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
paddle.distributed.all_reduce(tindata)
return [tindata.numpy()]
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
dist.all_reduce(tindata)
return [tindata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
dist.all_reduce(tindata)
return [tindata.numpy()]


if __name__ == "__main__":
@@ -13,23 +13,31 @@
# limitations under the License.

import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import test_collective_api_base as test_base


class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
class TestCollectiveAllToAllAPI(test_base.TestCollectiveAPIRunnerBase):

def __init__(self):
self.global_ring_id = 0

def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
tindata = paddle.split(tindata, 2, axis=0)
toutdata = []
paddle.distributed.alltoall(tindata, toutdata)
return [data.numpy() for data in toutdata]
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
tindata = paddle.split(tindata, 2, axis=0)
dist.alltoall(tindata, toutdata)
return [data.cast("float32").numpy() for data in toutdata]
else:
tindata = paddle.to_tensor(indata)
tindata = paddle.split(tindata, 2, axis=0)
dist.alltoall(tindata, toutdata)
return [data.numpy() for data in toutdata]


if __name__ == "__main__":
runtime_main(TestCollectiveAllToAllAPI, "alltoall")
test_base.runtime_main(TestCollectiveAllToAllAPI, "alltoall")
@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base

@@ -24,10 +25,17 @@ def __init__(self):

def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
toutdata = paddle.to_tensor(indata)
paddle.distributed.alltoall_single(tindata, toutdata)
return [toutdata.numpy()]
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
toutdata = paddle.to_tensor(tindata, "float32").cast("uint16")
dist.alltoall_single(tindata, toutdata)
return [toutdata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
toutdata = paddle.to_tensor(indata)
dist.alltoall_single(tindata, toutdata)
return [toutdata.numpy()]


if __name__ == "__main__":
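One thing the bfloat16 paths in the tests above have in common is that results are cast back to float32 before comparison, so any host-side check needs a tolerance that matches bfloat16's 8-bit effective mantissa. A standalone NumPy sketch of such a comparison, for illustration only and not necessarily how test_collective_api_base implements it:

import numpy as np

def to_bf16(x):
    # Truncate float32 values to bfloat16 precision (keep the upper 16 bits).
    return ((x.astype(np.float32).view(np.uint32) >> 16) << 16).view(np.float32)

rank0 = np.random.uniform(-1, 1, size=[8]).astype("float32")
rank1 = np.random.uniform(-1, 1, size=[8]).astype("float32")

result = to_bf16(to_bf16(rank0) + to_bf16(rank1))   # simulated bf16 all_reduce sum
reference = rank0 + rank1                           # float32 reference
# Loose tolerance to absorb two rounds of bfloat16 truncation.
np.testing.assert_allclose(result, reference, rtol=2e-2, atol=2e-2)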