diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 53a9858316046..39260fd7b340f 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -513,6 +513,14 @@ def __init__(self, optimizer, hcg):
         self.pp_overlap = pp_config.sharding_comm_overlap
         self.pp_release_grads = pp_config.release_gradients
 
+        # Check nccl reduce_avg setting
+        self.use_reduce_avg = sharding_config.use_reduce_avg
+        if self.use_reduce_avg and (not is_avg_reduce_op_supported()):
+            self.use_reduce_avg = False
+            warnings.warn(
+                "nccl reduce_avg requires paddle compiled with cuda and nccl>=2.10.0, please check compilation setups."
+            )
+
         self._build_comm_buffers(acc_steps)
         # NOTE(shenliang03): Sort the comm_buffers by dst rank,
         # it will improve the performance in reduce communicate. Default
@@ -579,6 +587,7 @@ def _build_comm_buffers(self, acc_steps, group_size=256 * 1024 * 1024):
                 acc_steps,
                 act=HOOK_ACTION.REDUCE_SCATTER,
                 release_grads=self.pp_release_grads,
+                use_reduce_avg=self.use_reduce_avg,
             )
             self._comm_buffer_list.append(buffer)