
Commit

update the logic
FeixLiu committed Jul 26, 2021
1 parent 7c6b80a commit 085510f
Showing 1 changed file with 15 additions and 5 deletions.
python/paddle/distributed/fleet/base/fleet_base.py (20 changes: 15 additions & 5 deletions)
@@ -269,18 +269,28 @@ def init(self, role_maker=None, is_collective=False, strategy=None):
         cg.set_comm_group('global', global_rank, global_world_size,
                           global_ring_id, global_ranks)

+        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
+        use_mp = use_sharding or use_tensor_parallel
+
         # hybrid group
-        if use_sharding is False: return
+        if use_mp is False: return

-        sharding_configs = self._user_defined_strategy.sharding_configs
-        mp_degree = int(sharding_configs['mp_degree'])
+        mp_degree_sharding = 1
+        mp_degree_tensor_parallel = 1
+        if use_sharding:
+            sharding_configs = self._user_defined_strategy.sharding_configs
+            mp_degree_sharding = int(sharding_configs['mp_degree'])

-        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
         if use_tensor_parallel:
             tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
-            mp_degree = int(tensor_parallel_configs[
+            mp_degree_tensor_parallel = int(tensor_parallel_configs[
                 'tensor_parallel_degree'])

+        if use_sharding and use_tensor_parallel:
+            assert mp_degree_sharding == mp_degree_tensor_parallel
+
+        mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
+
         if mp_degree > 1:
             assert global_world_size % mp_degree == 0
             # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
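
For readers skimming the diff: the change lets the hybrid communication group be built when either sharding or tensor parallelism is enabled (previously only sharding), reads the model-parallel degree from whichever strategy supplies it, and asserts the two agree when both are set. Below is a minimal, self-contained sketch of just that degree-resolution step; the resolve_mp_degree helper and the plain-dict configs are illustrative stand-ins for the real DistributedStrategy fields, not part of the commit.

# Illustrative sketch only: mirrors the mp_degree resolution introduced by this
# commit, with plain dicts standing in for DistributedStrategy's config fields.
def resolve_mp_degree(use_sharding, sharding_configs,
                      use_tensor_parallel, tensor_parallel_configs):
    """Return the model-parallel degree implied by the enabled strategies."""
    mp_degree_sharding = 1
    mp_degree_tensor_parallel = 1
    if use_sharding:
        mp_degree_sharding = int(sharding_configs['mp_degree'])
    if use_tensor_parallel:
        mp_degree_tensor_parallel = int(
            tensor_parallel_configs['tensor_parallel_degree'])

    # When both strategies are enabled they must request the same degree.
    if use_sharding and use_tensor_parallel:
        assert mp_degree_sharding == mp_degree_tensor_parallel

    # Sharding's value wins when sharding is on; otherwise use tensor parallel's.
    return mp_degree_sharding if use_sharding else mp_degree_tensor_parallel

# Example: tensor parallelism alone with degree 4 resolves to mp_degree == 4.
print(resolve_mp_degree(False, {}, True, {'tensor_parallel_degree': 4}))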

1 comment on commit 085510f

@paddle-bot-old

Congratulations! Your pull request passed all required CI checks. You can ask reviewer(s) to approve and merge. 🎉
