From 025d508347ad503d52c254afd1d52fca1bb889f6 Mon Sep 17 00:00:00 2001
From: Ghost Screaming
Date: Wed, 11 Oct 2023 02:09:35 -0500
Subject: [PATCH] Polish mp async allreduce implementation. (#57965)

* Remove paddle.ones, as PR 57574 has fixed the mp async allreduce overlap failure.
* Add a warning for mp async all_reduce when CUDA_DEVICE_MAX_CONNECTIONS is unset.
---
 .../distributed/fleet/layers/mpu/mp_layers.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index 7b59a6d5946403..57b3b3419a7663 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import paddle
 from paddle.autograd import PyLayer
 from paddle.base import core
@@ -20,6 +22,7 @@
 
 from ....communication.reduce import ReduceOp, _get_reduce_op
 from ...base import topology as tp
+from ...utils.log_util import logger
 from . import mp_ops
 from .random import get_rng_state_tracker
 
@@ -177,6 +180,9 @@ def forward(self, x):
         return output
 
 
+_raise_cuda_env_unset_warning = True
+
+
 class InnerOverlapLinear(paddle.autograd.PyLayer):
     @staticmethod
     def forward(
@@ -216,8 +222,17 @@ def backward(ctx, dy):
         task = ctx.model_parallel_group.process_group.all_reduce(
             dx, op_type, sync_op=False
         )
-        # TODO(GhostScreaming): remove it in future.
-        tmp = paddle.ones([512])
+        # Launch a small op to preempt GPU SMs so the async all_reduce overlaps with compute.
+        if int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1:
+            global _raise_cuda_env_unset_warning
+            if _raise_cuda_env_unset_warning:
+                logger.warning(
+                    "You set mp_async_allreduce=True, but you forgot to set the environment "
+                    "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may lead to performance "
+                    "loss. Export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance."
+                )
+            _raise_cuda_env_unset_warning = False
+        tmp = paddle.ones([512])
 
         if ctx.mp_fused_linear_param_grad_add:
             if not is_fused_linear_param_grad_add_supported():
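
Note for reviewers: the snippet below is a minimal, self-contained sketch of the warn-once guard this patch adds to InnerOverlapLinear.backward. It is not part of the change itself; the names mirror the patch, but the plain logging logger and the helper _check_cuda_max_connections are illustrative stand-ins for Paddle's logger and the inline check.

# Standalone sketch (not part of the patch) of the warn-once pattern above.
# A module-level flag ensures the CUDA_DEVICE_MAX_CONNECTIONS warning is emitted
# at most once per process, no matter how many backward passes run.
import logging
import os

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

_raise_cuda_env_unset_warning = True


def _check_cuda_max_connections():
    global _raise_cuda_env_unset_warning
    # Treat both an unset and a mis-set variable as a misconfiguration.
    if int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1:
        if _raise_cuda_env_unset_warning:
            logger.warning(
                "mp_async_allreduce is enabled but CUDA_DEVICE_MAX_CONNECTIONS is not "
                "set to 1; the async all_reduce may not overlap with compute."
            )
            _raise_cuda_env_unset_warning = False


if __name__ == "__main__":
    _check_cuda_max_connections()  # warns (unless the env var is already 1)
    _check_cuda_max_connections()  # silent: the flag was flipped on the first call

Flipping the flag after the first hit keeps the training log clean when backward runs thousands of times per epoch, while still surfacing the misconfiguration once.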