supports mp and dp hybrid (#34377)
FeixLiu authored Jul 27, 2021
1 parent 846be13 commit 937e21a
Showing 2 changed files with 21 additions and 3 deletions.
22 changes: 19 additions & 3 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -269,11 +269,27 @@ def init(self, role_maker=None, is_collective=False, strategy=None):
             cg.set_comm_group('global', global_rank, global_world_size,
                               global_ring_id, global_ranks)

+            use_tensor_parallel = self._user_defined_strategy.tensor_parallel
+            use_mp = use_sharding or use_tensor_parallel
+
             # hybrid group
-            if use_sharding is False: return
+            if use_mp is False: return

+            mp_degree_sharding = 1
+            mp_degree_tensor_parallel = 1
+            if use_sharding:
+                sharding_configs = self._user_defined_strategy.sharding_configs
+                mp_degree_sharding = int(sharding_configs['mp_degree'])
+
+            if use_tensor_parallel:
+                tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
+                mp_degree_tensor_parallel = int(tensor_parallel_configs[
+                    'tensor_parallel_degree'])
+
+            if use_sharding and use_tensor_parallel:
+                assert mp_degree_sharding == mp_degree_tensor_parallel
+
-            sharding_configs = self._user_defined_strategy.sharding_configs
-            mp_degree = int(sharding_configs['mp_degree'])
+            mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel

             if mp_degree > 1:
                 assert global_world_size % mp_degree == 0
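With this change the hybrid communication group is built whenever either sharding or tensor_parallel is enabled, instead of only when sharding is on, and mp_degree is taken from whichever config is active (the two degrees must match when both are set). Below is a minimal usage sketch of the new tensor-parallel-only path; the degree value and launch setup are illustrative assumptions, not taken from this commit.

# Sketch: enable model parallelism through tensor_parallel alone (no sharding).
# Assumes the script is launched across a number of GPUs divisible by the
# chosen tensor_parallel_degree, as required by the assert in the diff above.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}

# After this commit, fleet.init() derives mp_degree from
# tensor_parallel_configs and sets up the hybrid comm group even though
# strategy.sharding stays False.
fleet.init(is_collective=True, strategy=strategy)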
2 changes: 2 additions & 0 deletions (second changed file)
@@ -84,6 +84,8 @@ def setUp(self):
             "mp_degree": self.model_parallel_size,
             "sharding_degree": 2,
         }
+        strategy.tensor_parallel = True
+        strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
         fleet.init(is_collective=True, strategy=strategy)

     def get_program(self):
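The hunk above shows only part of setUp; creating the strategy and enabling sharding happen in lines the hunk does not include. The following is a hedged reconstruction of the combined configuration the updated test exercises, with the strategy creation, the sharding flag, and the degree values assumed for illustration.

# Sketch of the combined sharding + tensor_parallel configuration
# (strategy creation, the sharding flag, and the degree values are
# assumptions; only the lines marked in the hunk above are confirmed
# by this commit).
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "mp_degree": 2,  # must equal tensor_parallel_degree below
    "sharding_degree": 2,
}
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}

# fleet.init() now asserts that sharding's mp_degree equals
# tensor_parallel_degree when both features are enabled.
fleet.init(is_collective=True, strategy=strategy)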
