fix bugs (PaddlePaddle#43115)
haohongxiang authored and fuyou765 committed Jun 7, 2022
1 parent 342fa47 commit fa1310f
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -140,17 +140,12 @@ def broadcast_dp_parameters(model, hcg):
 
 
 def fused_allreduce_gradients(parameter_list, hcg):
-    if _in_legacy_dygraph():
-        data_parallel_group = None if hcg is None else hcg.get_data_parallel_group(
-        )
-        logger.debug("dp start fuse allreduce gradients")
-        with framework.no_grad():
-            _apply_collective_grads(parameter_list, data_parallel_group)
-    elif in_dygraph_mode():
-        assert hcg is None, "It's not support to use hcg in EagerDygraph now."
-        data_parallel_group = paddle.distributed.collective._get_default_group()
-        with framework.no_grad():
-            _apply_collective_grads_eager(parameter_list, data_parallel_group)
+    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
+    logger.debug("dp start fuse allreduce gradients")
+    apply_func = _apply_collective_grads_eager if in_dygraph_mode(
+    ) else _apply_collective_grads
+    with framework.no_grad():
+        apply_func(parameter_list, data_parallel_group)
 
 
 def sharding_reduce_gradients(parameter_list, hcg):
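A minimal usage sketch (not part of this commit), showing how the consolidated fused_allreduce_gradients path might be driven from a plain data-parallel dygraph loop. Everything outside the function call itself (the parallel-env setup, toy model, optimizer, and data) is illustrative only:

import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    fused_allreduce_gradients,
)

# Illustrative setup; launch with paddle.distributed.launch for multi-card runs.
dist.init_parallel_env()
model = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.SGD(parameters=model.parameters())

x = paddle.randn([4, 8])
loss = model(x).mean()
loss.backward()

# hcg=None: the consolidated path above resolves the communication group
# itself, falling back to the default data-parallel group.
fused_allreduce_gradients(list(model.parameters()), None)

opt.step()
opt.clear_grad()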
