segment parallel parameter sync and grad sync (#57061)
* sep grad all sync

* polish

* polish

* polish

* polish

* add test

* add test, and align with dp

* follow comment

* polish
liuzhenhai93 authored Sep 14, 2023
1 parent 0ade0f8 commit cc045c0
Showing 11 changed files with 332 additions and 51 deletions.
34 changes: 26 additions & 8 deletions python/paddle/distributed/fleet/base/topology.py
@@ -55,6 +55,7 @@ class ParallelMode:
TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2
SHARDING_PARALLEL = 3
SEGMENT_PARALLEL = 4


class CommunicateTopology:
@@ -277,24 +278,41 @@ def __init__(self, topology):
_HYBRID_PARALLEL_GROUP = self

    def get_parallel_mode(self):
-       # there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
        # there are five modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel / SepParallel
        # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
        # adding its parallel logic within that parallelism
        # when use sharding alone, it should have its own parallelism for its parallel logic
        # TODO modify 3 others parallel to support sharding

        # pp -> mp -> sep -> sharding -> dp
        if (
-           self._mp_degree == 1
-           and self._pp_degree == 1
-           and self._dp_degree == 1
            self._pp_degree == 1
            and self._mp_degree == 1
            and self._sep_degree == 1
            and self._sharding_degree == 1
            and self._dp_degree > 1
        ):
            return ParallelMode.DATA_PARALLEL
        elif (
            self._pp_degree == 1
            and self._mp_degree == 1
            and self._sep_degree == 1
            and self._sharding_degree > 1
        ):
            # sharding may coexist with dp
            return ParallelMode.SHARDING_PARALLEL
-       elif self._mp_degree == 1 and self._pp_degree == 1:
-           return ParallelMode.DATA_PARALLEL
-       elif self._mp_degree > 1 and self._pp_degree == 1:
        elif (
            self._pp_degree == 1
            and self._mp_degree == 1
            and self._sep_degree > 1
        ):
            # sep may coexist with dp and sharding
            return ParallelMode.SEGMENT_PARALLEL
        elif self._pp_degree == 1 and self._mp_degree > 1:
            # tp may coexist with sep, dp and sharding
            # initialize the seed
            return ParallelMode.TENSOR_PARALLEL
        elif self._pp_degree > 1:
            # pp may coexist with mp, sep, dp and sharding
            return ParallelMode.PIPELINE_PARALLEL

def _check_vaild_topo(self):
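Note on the hunk above: mode resolution follows a fixed precedence. The sketch below is an illustration only, with plain integers standing in for the HybridCommunicateGroup degrees; it is not code from this commit.

def resolve_parallel_mode(pp, mp, sep, sharding, dp):
    # precedence: pp -> mp -> sep -> sharding -> dp
    if pp > 1:
        return "PIPELINE_PARALLEL"   # pp may coexist with mp, sep, sharding and dp
    if mp > 1:
        return "TENSOR_PARALLEL"     # tp may coexist with sep, sharding and dp
    if sep > 1:
        return "SEGMENT_PARALLEL"    # sep may coexist with sharding and dp
    if sharding > 1:
        return "SHARDING_PARALLEL"   # sharding may coexist with dp
    if dp > 1:
        return "DATA_PARALLEL"
    return None                      # no hybrid parallelism configured

assert resolve_parallel_mode(pp=1, mp=1, sep=4, sharding=1, dp=2) == "SEGMENT_PARALLEL"
assert resolve_parallel_mode(pp=2, mp=1, sep=2, sharding=1, dp=1) == "PIPELINE_PARALLEL"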
@@ -285,6 +285,8 @@ def __init__(self, optimizer, hcg, strategy):

self._sharding_enable = self._hcg.get_sharding_parallel_world_size() > 1

self._sep_enable = self._hcg.get_sep_parallel_world_size() > 1

if (
isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm)
and not self._use_dp_mode
@@ -439,10 +441,7 @@ def _step(self, parameters_list):
moment2, src_rank, mp_group, mp_configs.sync_mode
)

-    @no_grad()
-    @framework.dygraph_only
-    def step(self):
-        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
    def _hybrid_sync_grad(self, parameter_list):
dp_parameter_list = parameter_list
if self._sharding_enable:
assert isinstance(self._inner_opt, DygraphShardingOptimizer)
@@ -452,10 +451,14 @@ def step(self):
dp_parameter_list = self._inner_opt.filter_parameters(
parameter_list, self._hcg
)

-        if self._dp_enable:
        if self._dp_enable or self._sep_enable:
fused_allreduce_gradients(dp_parameter_list, self._hcg)

@no_grad()
@framework.dygraph_only
def step(self):
parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
self._hybrid_sync_grad(parameter_list)
self._step(parameter_list)

@no_grad()
@@ -470,21 +473,7 @@ def minimize(
else obtain_optimizer_parameters_list(self._inner_opt)
)
parameter_list = list(parameter_list)
-        dp_parameter_list = parameter_list
-        # Here sharding should use global parameter list
-        if self._sharding_enable:
-            assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
-
-            # dp later do not need to use global parameter list
-            if not g_shard_norm_align_dp:
-                dp_parameter_list = self._inner_opt.filter_parameters(
-                    parameter_list, self._hcg
-                )
-
-        if self._dp_enable:
-            fused_allreduce_gradients(dp_parameter_list, self._hcg)
-
        self._hybrid_sync_grad(parameter_list)
return self._inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set
)
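Both step() and minimize() now route gradient synchronization through _hybrid_sync_grad. A minimal standalone restatement of that policy, with the collectives injected as callables so the sketch runs without a distributed setup (not the optimizer's actual code):

def hybrid_sync_grad(parameter_list, *, sharding_enable, dp_enable, sep_enable,
                     reduce_gradients, filter_parameters, fused_allreduce,
                     norm_align_dp=False):
    dp_parameter_list = parameter_list
    if sharding_enable:
        # sharding reduces gradients over the full parameter list first
        reduce_gradients(parameter_list)
        if not norm_align_dp:
            # dp/sep afterwards only need the locally held shards
            dp_parameter_list = filter_parameters(parameter_list)
    if dp_enable or sep_enable:
        # one fused all-reduce now covers both dp and sep gradients
        fused_allreduce(dp_parameter_list)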
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -26,5 +26,6 @@
from .pipeline_parallel import PipelineParallel # noqa: F401
from .pipeline_parallel import PipelineParallelWithInterleave # noqa: F401
from .sharding_parallel import ShardingParallel # noqa: F401
from .segment_parallel import SegmentParallel # noqa: F401

__all__ = []
12 changes: 12 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -24,6 +24,7 @@
from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_mp_parameters,
broadcast_sep_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
@@ -138,6 +139,7 @@ def __init__(self, layers, hcg, strategy):
super().__init__(layers, hcg, strategy)
self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
self.use_sep_parallel = self._hcg.get_sep_parallel_world_size() > 1
self.use_sharding_parallel = (
self._hcg.get_sharding_parallel_world_size() > 1
)
@@ -161,7 +163,13 @@ def __init__(self, layers, hcg, strategy):
self.stage_id = self._hcg.get_stage_id()
self.global_rank = self._hcg.get_global_rank()
self.pp_group = self._hcg.get_pipe_parallel_group()

self.dp_group = self._hcg.get_data_parallel_group()

# fused sep and dp
if self.use_sep_parallel:
self.dp_group = self._hcg.get_dp_sep_parallel_group()

self.sharding_group = self._hcg.get_sharding_parallel_group()

self._virtual_pp_world_size = None
@@ -251,6 +259,10 @@ def __init__(self, layers, hcg, strategy):
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)

if self.use_sep_parallel:
logger.info("start broadcast sep parameters")
broadcast_sep_parameters(self._layers, self._hcg)

if self.use_sharding_parallel:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)
40 changes: 40 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/segment_parallel.py
@@ -0,0 +1,40 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_sep_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase

__all__ = []


class SegmentParallel(MetaParallelBase):
def __init__(self, layers, hcg, **kwargs):
super().__init__(layers, hcg, **kwargs)

def _prepare_for_model(self):
logger.info("start broadcast sep parameters")
broadcast_sep_parameters(self._layers, self._hcg)

if self._hcg.get_sharding_parallel_world_size() > 1:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)

if self._hcg.get_data_parallel_world_size() > 1:
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)
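A hedged usage sketch for the new wrapper: with a hybrid config in which only the sep degree is greater than one, fleet.distributed_model is expected to wrap the layers in SegmentParallel (see the model.py change below), which then broadcasts parameters in sep -> sharding -> dp order as above. The config keys and degree values here are assumptions for illustration, not taken from this commit.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sep_degree": 4,  # segment parallel only (assumed key name)
}
fleet.init(is_collective=True, strategy=strategy)
# model = fleet.distributed_model(model)  # -> SegmentParallel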
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase

@@ -27,6 +30,8 @@ def _prepare_for_model(self):
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)

-        # TODO (JZ-LIANG) to support Sharding-DP
if self._hcg.get_data_parallel_world_size() > 1:
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)

logger.info("sharding's parameters is ready")
@@ -16,6 +16,7 @@
broadcast_dp_parameters,
broadcast_input_data,
broadcast_mp_parameters,
broadcast_sep_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
@@ -32,6 +33,10 @@ def _prepare_for_model(self):
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)

if self._hcg.get_sep_parallel_world_size() > 1:
logger.info("start broadcast sep parameters")
broadcast_sep_parameters(self._layers, self._hcg)

if self._hcg.get_sharding_parallel_world_size() > 1:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)
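Across the wrappers touched in this commit, parameter broadcast at model-preparation time follows one order, with each step taken only when that world size is greater than one. A compact, illustrative restatement:

def prepare_model_broadcasts(world_sizes):
    # world_sizes: e.g. {"mp": 2, "sep": 2, "sharding": 1, "dp": 2}
    order = ["mp", "sep", "sharding", "dp"]  # sep slots in between mp and sharding
    return [dim for dim in order if world_sizes.get(dim, 1) > 1]

assert prepare_model_broadcasts({"mp": 2, "sep": 2, "dp": 2}) == ["mp", "sep", "dp"]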
15 changes: 3 additions & 12 deletions python/paddle/distributed/fleet/model.py
@@ -20,6 +20,7 @@
PipelineLayer,
PipelineParallel,
PipelineParallelWithInterleave,
SegmentParallel,
ShardingParallel,
TensorParallel,
)
@@ -130,25 +131,15 @@ def distributed_model(model):
if fleet_env._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
model = ShardingParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
-        # NOTE (JZ-LIANG) init parameters broadcast within sharding group
-        # normally it should be done inside DataParallel
-        if fleet_env.sharding_degree > 1:
-            from paddle.distributed.fleet.utils.hybrid_parallel_util import (
-                broadcast_sharding_parameters,
-            )
-
-            assert (
-                fleet_env.sharding_degree
-                == fleet_env._hcg.get_sharding_parallel_world_size()
-            )
-            broadcast_sharding_parameters(model, fleet_env._hcg)
model = paddle.DataParallel(
model,
comm_buffer_size=strategy.fuse_grad_size_in_MB,
last_comm_buffer_size=strategy.last_comm_group_size_MB,
find_unused_parameters=strategy.find_unused_parameters,
group=fleet_env._hcg.get_data_parallel_group(),
)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.SEGMENT_PARALLEL:
model = SegmentParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
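The dispatch in distributed_model now covers five modes. The mapping, written out as a dict purely for illustration (the branch structure above is what actually runs):

import paddle
from paddle.distributed.fleet.base.topology import ParallelMode
from paddle.distributed.fleet.meta_parallel import (
    PipelineParallel,
    SegmentParallel,
    ShardingParallel,
    TensorParallel,
)

MODE_TO_WRAPPER = {
    ParallelMode.SHARDING_PARALLEL: ShardingParallel,
    ParallelMode.DATA_PARALLEL: paddle.DataParallel,  # sharding broadcast no longer special-cased here
    ParallelMode.SEGMENT_PARALLEL: SegmentParallel,   # new in this commit
    ParallelMode.TENSOR_PARALLEL: TensorParallel,
    ParallelMode.PIPELINE_PARALLEL: PipelineParallel,  # or PipelineParallelWithInterleave
}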
5 changes: 5 additions & 0 deletions python/paddle/distributed/fleet/optimizer.py
@@ -72,12 +72,17 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
if fleet_env._user_defined_strategy.hybrid_configs[
"pp_configs"
].dp_comm_overlap:
# grad all-reduce of dp and sep will be fused
hp_optim._dp_enable = False
hp_optim._sep_enable = False

if fleet_env._user_defined_strategy.hybrid_configs[
"pp_configs"
].sharding_comm_overlap:
hp_optim._sharding_enable = False
assert (
not hp_optim._sep_enable
), "sep parallel can not coexist with sharding_comm_overlap"

return hp_optim
else:
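A plain restatement of the overlap rules enforced above, as a self-contained function whose flag names mirror the diff (illustration only):

def apply_overlap_flags(dp_comm_overlap, sharding_comm_overlap,
                        dp_enable, sep_enable, sharding_enable):
    if dp_comm_overlap:
        # pp comm-overlap buckets already all-reduce dp and sep grads,
        # so the optimizer must not all-reduce them again
        dp_enable = False
        sep_enable = False
    if sharding_comm_overlap:
        sharding_enable = False
        assert not sep_enable, "sep parallel can not coexist with sharding_comm_overlap"
    return dp_enable, sep_enable, sharding_enable

assert apply_overlap_flags(True, False, True, True, True) == (False, False, True)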
32 changes: 29 additions & 3 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -239,9 +239,27 @@ def fused_allreduce_gradients_with_group(


def fused_allreduce_gradients(parameter_list, hcg):
-    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
-    logger.debug("dp start fuse allreduce gradients")
-    fused_allreduce_gradients_with_group(parameter_list, data_parallel_group)
group = None
scale = None
if hcg is not None:
dp_enabled = hcg.get_data_parallel_world_size() > 1
sep_enabled = hcg.get_sep_parallel_world_size() > 1
assert (
dp_enabled or sep_enabled
), f"dp_enabled {dp_enabled}; sep_enabled {sep_enabled}"
group = None
# sep all reduce is not scaled
scale = 1.0
if dp_enabled:
group = hcg.get_data_parallel_group()
scale = group.nranks
if sep_enabled:
sep_group = hcg.get_sep_parallel_group()
dp_sep_group = hcg.get_dp_sep_parallel_group()
group = sep_group if group is None else dp_sep_group

logger.debug("dp or sep start fuse allreduce gradients")
fused_allreduce_gradients_with_group(parameter_list, group, scale=scale)


def broadcast_sharding_parameters(model, hcg):
Expand All @@ -254,6 +272,14 @@ def broadcast_sharding_parameters(model, hcg):
)


def broadcast_sep_parameters(model, hcg):
# TODO To save memory, use un-fused broadcast to avoid potential OOM
logger.debug("sep start init parameters sync")
sep_group = hcg.get_sep_parallel_group()
src_rank = hcg.get_sep_parallel_group_src_rank()
sync_params_buffers(model, sep_group, src_rank, is_model_parallel=False)


def unwrap_optimizer(optimizer, optimizer_instances=()):
_inner_opt = optimizer
while isinstance(_inner_opt, optimizer_instances):
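On the group and scale selection in fused_allreduce_gradients above: a minimal sketch with the hcg replaced by two booleans and placeholder strings instead of real communication groups (illustration only):

def pick_group_and_scale(dp_enabled, sep_enabled, dp_world_size):
    assert dp_enabled or sep_enabled
    group, scale = None, 1.0  # a sep-only all-reduce is not scaled
    if dp_enabled:
        group, scale = "dp_group", float(dp_world_size)
    if sep_enabled:
        # sep alone uses the sep group; dp + sep share one fused dp+sep group,
        # still scaled by the dp world size only
        group = "sep_group" if group is None else "dp_sep_group"
    return group, scale

assert pick_group_and_scale(False, True, 1) == ("sep_group", 1.0)
assert pick_group_and_scale(True, True, 2) == ("dp_sep_group", 2.0)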
