diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index a88ff6107a2f5..d0020a2776be7 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -269,18 +269,28 @@ def init(self, role_maker=None, is_collective=False, strategy=None):
         cg.set_comm_group('global', global_rank, global_world_size,
                           global_ring_id, global_ranks)
 
+        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
+        use_mp = use_sharding or use_tensor_parallel
+
         # hybrid group
-        if use_sharding is False: return
+        if use_mp is False: return
 
-        sharding_configs = self._user_defined_strategy.sharding_configs
-        mp_degree = int(sharding_configs['mp_degree'])
+        mp_degree_sharding = 1
+        mp_degree_tensor_parallel = 1
+        if use_sharding:
+            sharding_configs = self._user_defined_strategy.sharding_configs
+            mp_degree_sharding = int(sharding_configs['mp_degree'])
 
-        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
         if use_tensor_parallel:
             tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
-            mp_degree = int(tensor_parallel_configs[
+            mp_degree_tensor_parallel = int(tensor_parallel_configs[
                 'tensor_parallel_degree'])
 
+        if use_sharding and use_tensor_parallel:
+            assert mp_degree_sharding == mp_degree_tensor_parallel
+
+        mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
+
         if mp_degree > 1:
             assert global_world_size % mp_degree == 0
             # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
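
For reference, a minimal usage sketch of the path this hunk enables. The strategy attribute and config key names (tensor_parallel, tensor_parallel_configs['tensor_parallel_degree'], sharding_configs['mp_degree']) come from the hunk itself; the degree values and device layout are illustrative assumptions, not part of this change.

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()

    # Tensor parallel only: with this change, mp_degree is taken from
    # tensor_parallel_configs['tensor_parallel_degree'] even though
    # sharding is disabled.
    strategy.tensor_parallel = True
    strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}

    # If sharding were also enabled, the new assert requires
    # sharding_configs['mp_degree'] to equal tensor_parallel_degree, e.g.:
    # strategy.sharding = True
    # strategy.sharding_configs = {"mp_degree": 2}

    fleet.init(is_collective=True, strategy=strategy)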