Fix hang of hybrid parallel in new_group (#33141)
* fix hang of hybrid parallel

* fix new_group for hang problem
ForFishes authored Jun 4, 2021
1 parent d523dff commit 1e9299a
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions in python/paddle/distributed/collective.py
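
The hunk below adds a TODO about "hang caused by cross-creation of new_group". That cross-creation is the call pattern used when a hybrid-parallel topology is set up: every process calls paddle.distributed.new_group for every sub-group in the communication list, including groups it does not belong to. A minimal sketch of that pattern follows; the 4-GPU layout and the group lists are illustrative assumptions, not taken from this commit.

    # Illustrative only: the "cross-creation" call pattern behind the hang.
    # Assumes a 4-process launch (e.g. via python -m paddle.distributed.launch);
    # the group lists below are made-up examples.
    import paddle.distributed as dist

    dist.init_parallel_env()
    global_rank = dist.get_rank()

    # e.g. two model-parallel groups on 4 GPUs
    group_lists = [[0, 1], [2, 3]]

    my_group = None
    for ranks in group_lists:
        # Every rank calls new_group for every list, even lists it is not in;
        # ranks outside the list only get a placeholder Group object back.
        comm_group = dist.new_group(ranks=ranks)
        if global_rank in ranks:
            my_group = comm_group
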
@@ -239,31 +239,37 @@ def new_group(ranks=None, backend=None):
     if global_rank not in ranks:
         gp = Group(-1, -1, ring_id, ranks)
         _group_map[ring_id] = gp
         return gp
 
-    ranks = sorted(ranks)
-    group_rank = ranks.index(global_rank)
-    group_size = len(ranks)
-    gp = Group(group_rank, group_size, ring_id, ranks)
-    _group_map[ring_id] = gp
-
-    if group_size < 2:
-        return gp
-
-    strategy = core.ParallelStrategy()
-    strategy.nranks = group_size
-    strategy.local_rank = group_rank
-    strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
-    strategy.current_endpoint = genv.current_endpoint
-    strategy.nrings = 1
-
-    if core.is_compiled_with_cuda():
-        place = core.CUDAPlace(genv.device_id)
-        core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
-    else:
-        assert False, ("no cuda device found")
-    # need to barrier to construct group
-    barrier(gp)
+    ranks = sorted(ranks)
+    group_rank = ranks.index(global_rank)
+    group_size = len(ranks)
+    gp = Group(group_rank, group_size, ring_id, ranks)
+    _group_map[ring_id] = gp
+
+    if group_size >= 2:
+        strategy = core.ParallelStrategy()
+        strategy.nranks = group_size
+        strategy.local_rank = group_rank
+        strategy.trainer_endpoints = [
+            genv.trainer_endpoints[i] for i in ranks
+        ]
+        strategy.current_endpoint = genv.current_endpoint
+        strategy.nrings = 1
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(genv.device_id)
+            core.NCCLParallelContext(strategy,
+                                     place).init_with_ring_id(ring_id)
+        else:
+            assert False, ("no cuda device found")
+    else:
+        return gp
+
+    # TODO(shenliang03): This is a temporary solution to solve the problem of
+    # hang caused by cross-creation of new_group
+    tmp = fill_constant([0], dtype="int32", value="1")
+    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
+    paddle.distributed.wait(tmp)
     return gp
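
In effect, the removed per-group barrier(gp) is replaced by a rendezvous on the default communicator: every rank that reaches the tail of new_group all-reduces a throwaway tensor and then waits on it. A rough standalone equivalent in user-level dygraph code (illustrative only, not the literal code path inside collective.py; the helper name is made up):

    import paddle
    import paddle.distributed as dist

    def _global_sync():
        # All-reduce a dummy tensor on the default (global) communicator and
        # block until it finishes, so no process runs ahead of the others.
        tmp = paddle.to_tensor([1], dtype="int32")
        dist.all_reduce(tmp, use_calc_stream=True)
        dist.wait(tmp)
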

