segment parallel parameter sync and grad sync #57061

Merged · 15 commits · Sep 14, 2023
Changes from 9 commits
34 changes: 26 additions & 8 deletions python/paddle/distributed/fleet/base/topology.py
@@ -55,6 +55,7 @@ class ParallelMode:
TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2
SHARDING_PARALLEL = 3
SEGMENT_PARALLEL = 4


class CommunicateTopology:
@@ -277,24 +278,41 @@ def __init__(self, topology):
_HYBRID_PARALLEL_GROUP = self

def get_parallel_mode(self):
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# there are five modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel / SepParallel
# NOTE when sharding is combined with another parallelism, sharding should act like an optimizer and
# add its parallel logic within that parallelism
# when sharding is used alone, it should have its own parallelism for its parallel logic
# TODO modify the 3 other parallel modes to support sharding

# pp -> mp -> sep -> sharding -> dp
if (
self._mp_degree == 1
and self._pp_degree == 1
and self._dp_degree == 1
self._pp_degree == 1
and self._mp_degree == 1
and self._sep_degree == 1
and self._sharding_degree == 1
and self._dp_degree > 1
):
return ParallelMode.DATA_PARALLEL
elif (
self._pp_degree == 1
and self._mp_degree == 1
and self._sep_degree == 1
and self._sharding_degree > 1
):
# sharding may coexist with dp
Collaborator
If sharding_degree > 1 and dp_degree > 1, isn't this logic different from before?

liuzhenhai93 (Contributor, Author) commented Sep 11, 2023
Yes, the previous version was not written cleanly; the sharding > 1 and dp > 1 case is now handled inside sharding_parallel.

return ParallelMode.SHARDING_PARALLEL
elif self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
elif (
self._pp_degree == 1
and self._mp_degree == 1
and self._sep_degree > 1
):
# sep may coexist with dp and sharding
return ParallelMode.SEGMENT_PARALLEL
elif self._pp_degree == 1 and self._mp_degree > 1:
# tp may coexist with sep, dp and sharding
# initialize the seed
return ParallelMode.TENSOR_PARALLEL
elif self._pp_degree > 1:
# pp may coexist with mp, sep, dp and sharding
return ParallelMode.PIPELINE_PARALLEL

def _check_vaild_topo(self):
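For quick reference, the precedence the new get_parallel_mode follows (pp -> mp -> sep -> sharding -> dp) can be sketched in plain Python. The function below is illustrative only, not Paddle code, and takes the degrees as hypothetical keyword arguments.

def select_parallel_mode(pp=1, mp=1, sep=1, sharding=1, dp=1):
    # Higher-priority axes win; each axis may coexist with the ones after it.
    if pp > 1:
        return "PIPELINE_PARALLEL"  # pp may coexist with mp, sep, sharding and dp
    if mp > 1:
        return "TENSOR_PARALLEL"  # tp may coexist with sep, sharding and dp
    if sep > 1:
        return "SEGMENT_PARALLEL"  # sep may coexist with sharding and dp
    if sharding > 1:
        return "SHARDING_PARALLEL"  # sharding may coexist with dp
    if dp > 1:
        return "DATA_PARALLEL"
    return None  # no hybrid parallelism enabled

assert select_parallel_mode(sep=2, dp=2) == "SEGMENT_PARALLEL"
assert select_parallel_mode(sharding=2, dp=2) == "SHARDING_PARALLEL"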
@@ -31,7 +31,7 @@

from ...base.topology import ParallelMode
from ...utils.hybrid_parallel_util import (
fused_allreduce_gradients,
fused_allreduce_gradients_with_group,
unwrap_optimizer,
)
from ...utils.log_util import logger
@@ -271,6 +271,8 @@ def __init__(self, optimizer, hcg, strategy):

self._sharding_enable = self._hcg.get_sharding_parallel_world_size() > 1

self._sep_enabled = self._hcg.get_sep_parallel_world_size() > 1

if (
isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm)
and not self._use_dp_mode
@@ -425,6 +427,21 @@ def _step(self, parameters_list):
moment2, src_rank, mp_group, mp_configs.sync_mode
)

def _fused_allreduce_gradients(self, parameter_list):
if (not self._dp_enable) and (not self._sep_enabled):
return
group = None
# sep all reduce is not scaled
scale = 1.0
if self._dp_enable:
group = self._hcg.get_data_parallel_group()
scale = group.nranks
if self._sep_enabled:
sep_group = self._hcg.get_sep_parallel_group()
dp_sep_group = self._hcg.get_dp_sep_parallel_group()
group = sep_group if group is None else dp_sep_group
fused_allreduce_gradients_with_group(parameter_list, group, scale=scale)

@no_grad()
@framework.dygraph_only
def step(self):
@@ -433,8 +450,7 @@ def step(self):
assert isinstance(self._inner_opt, DygraphShardingOptimizer)
self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)

if self._dp_enable:
fused_allreduce_gradients(list(parameters_list), self._hcg)
self._fused_allreduce_gradients(list(parameters_list))

self._step(parameters_list)

@@ -455,8 +471,7 @@ def minimize(
assert isinstance(self._inner_opt, DygraphShardingOptimizer)
self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)

if self._dp_enable:
fused_allreduce_gradients(list(parameter_list), self._hcg)
self._fused_allreduce_gradients(list(parameter_list))

return self._inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set
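The new _fused_allreduce_gradients collapses the dp and sep gradient all-reduce into a single collective. A minimal sketch of the group and scale selection, using stand-in group objects with an nranks attribute rather than the real HCG API:

def pick_allreduce_group(dp_enable, sep_enable, dp_group, sep_group, dp_sep_group):
    # Returns (group, scale); (None, None) means no all-reduce is needed here.
    if not dp_enable and not sep_enable:
        return None, None
    group, scale = None, 1.0  # a sep-only all-reduce is not scaled
    if dp_enable:
        group = dp_group
        scale = dp_group.nranks  # average gradients over dp replicas
    if sep_enable:
        # with both dp and sep enabled, all-reduce once over the fused dp+sep group
        group = sep_group if group is None else dp_sep_group
    return group, scale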
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -26,5 +26,6 @@
from .pipeline_parallel import PipelineParallel # noqa: F401
from .pipeline_parallel import PipelineParallelWithInterleave # noqa: F401
from .sharding_parallel import ShardingParallel # noqa: F401
from .segment_parallel import SegmentParallel # noqa: F401

__all__ = []
12 changes: 12 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -24,6 +24,7 @@
from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_mp_parameters,
broadcast_sep_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
@@ -138,6 +139,7 @@ def __init__(self, layers, hcg, strategy):
super().__init__(layers, hcg, strategy)
self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
self.use_sep_parallel = self._hcg.get_sep_parallel_world_size() > 1
self.use_sharding_parallel = (
self._hcg.get_sharding_parallel_world_size() > 1
)
@@ -161,7 +163,13 @@ def __init__(self, layers, hcg, strategy):
self.stage_id = self._hcg.get_stage_id()
self.global_rank = self._hcg.get_global_rank()
self.pp_group = self._hcg.get_pipe_parallel_group()

self.dp_group = self._hcg.get_data_parallel_group()

# fuse sep and dp into one communication group
if self.use_sep_parallel:
self.dp_group = self._hcg.get_dp_sep_parallel_group()

self.sharding_group = self._hcg.get_sharding_parallel_group()

self._virtual_pp_world_size = None
@@ -251,6 +259,10 @@ def __init__(self, layers, hcg, strategy):
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)

if self.use_sep_parallel:
logger.info("start broadcast sep parameters")
broadcast_sep_parameters(self._layers, self._hcg)

if self.use_sharding_parallel:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)
40 changes: 40 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/segment_parallel.py
@@ -0,0 +1,40 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_sep_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase

__all__ = []


class SegmentParallel(MetaParallelBase):
def __init__(self, layers, hcg, **kwargs):
super().__init__(layers, hcg, **kwargs)

def _prepare_for_model(self):
logger.info("start broadcast sep parameters")
broadcast_sep_parameters(self._layers, self._hcg)

if self._hcg.get_sharding_parallel_world_size() > 1:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)

if self._hcg.get_data_parallel_world_size() > 1:
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)
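For context, a hypothetical end-to-end sketch of how a model would end up wrapped in this new class: with pp, mp and sharding degrees of 1 and a sep degree greater than 1, get_parallel_mode() returns SEGMENT_PARALLEL and fleet.distributed_model picks SegmentParallel. The sep_degree key in hybrid_configs is an assumption taken from this PR series, not a documented guarantee, and the model and optimizer below are placeholders.

import paddle
import paddle.distributed.fleet as fleet

# launched with e.g. python -m paddle.distributed.launch across dp_degree * sep_degree ranks
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 1,
    "sep_degree": 2,  # assumed key for segment parallel
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(1024, 1024)  # placeholder model
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# distributed_model sees SEGMENT_PARALLEL and wraps the layers in SegmentParallel,
# which broadcasts sep, then sharding, then dp parameters as shown above.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)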
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase

@@ -27,6 +30,8 @@ def _prepare_for_model(self):
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)

# TODO (JZ-LIANG) to support Sharding-DP
if self._hcg.get_data_parallel_world_size() > 1:
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)

logger.info("sharding's parameters is ready")
@@ -16,6 +16,7 @@
broadcast_dp_parameters,
broadcast_input_data,
broadcast_mp_parameters,
broadcast_sep_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
@@ -32,6 +33,10 @@ def _prepare_for_model(self):
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)

if self._hcg.get_sep_parallel_world_size() > 1:
logger.info("start broadcast sep parameters")
broadcast_sep_parameters(self._layers, self._hcg)

if self._hcg.get_sharding_parallel_world_size() > 1:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)
15 changes: 3 additions & 12 deletions python/paddle/distributed/fleet/model.py
@@ -20,6 +20,7 @@
PipelineLayer,
PipelineParallel,
PipelineParallelWithInterleave,
SegmentParallel,
ShardingParallel,
TensorParallel,
)
@@ -130,25 +131,15 @@ def distributed_model(model):
if fleet_env._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
model = ShardingParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
# NOTE (JZ-LIANG) init parameters broadcast within sharding group
# normally it should be done inside DataParallel
if fleet_env.sharding_degree > 1:
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
broadcast_sharding_parameters,
)

assert (
fleet_env.sharding_degree
== fleet_env._hcg.get_sharding_parallel_world_size()
)
broadcast_sharding_parameters(model, fleet_env._hcg)
Collaborator
Why is this code being removed here?

liuzhenhai93 (Contributor, Author) commented Sep 11, 2023
The sharding > 1 and dp > 1 case is now handled inside sharding_parallel.

model = paddle.DataParallel(
model,
comm_buffer_size=strategy.fuse_grad_size_in_MB,
last_comm_buffer_size=strategy.last_comm_group_size_MB,
find_unused_parameters=strategy.find_unused_parameters,
group=fleet_env._hcg.get_data_parallel_group(),
)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.SEGMENT_PARALLEL:
model = SegmentParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
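Taken together, the dispatch in distributed_model now covers five modes; the mapping below is a compact illustration only, the real code being the if/elif chain above.

# Illustrative only: ParallelMode -> model wrapper chosen by distributed_model().
MODE_TO_WRAPPER = {
    "SHARDING_PARALLEL": "ShardingParallel",
    "DATA_PARALLEL": "paddle.DataParallel",
    "SEGMENT_PARALLEL": "SegmentParallel",  # new in this PR
    "TENSOR_PARALLEL": "TensorParallel",
    "PIPELINE_PARALLEL": "PipelineParallel / PipelineParallelWithInterleave",
}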
5 changes: 5 additions & 0 deletions python/paddle/distributed/fleet/optimizer.py
@@ -72,12 +72,17 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
if fleet_env._user_defined_strategy.hybrid_configs[
"pp_configs"
].dp_comm_overlap:
# grad all-reduce of dp and sep will be fused
hp_optim._dp_enable = False
hp_optim._sep_enabled = False

if fleet_env._user_defined_strategy.hybrid_configs[
"pp_configs"
].sharding_comm_overlap:
hp_optim._sharding_enable = False
assert (
not hp_optim._sep_enabled
), "sep parallel can not coexist with sharding_comm_overlap"

return hp_optim
else:
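The interaction of the pp overlap flags with the new sep flag can be summarized in a small sketch; the attribute names are taken from the diff, but the function itself is illustrative, not Paddle code.

def apply_pp_overlap_flags(hp_optim, dp_comm_overlap, sharding_comm_overlap):
    if dp_comm_overlap:
        # dp and sep grad all-reduce are fused into the pp comm buffers,
        # so the optimizer must not all-reduce them again in step()/minimize()
        hp_optim._dp_enable = False
        hp_optim._sep_enabled = False
    if sharding_comm_overlap:
        hp_optim._sharding_enable = False
        # the sharding overlap path does not handle sep gradients
        assert (
            not hp_optim._sep_enabled
        ), "sep parallel can not coexist with sharding_comm_overlap"
    return hp_optim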
8 changes: 8 additions & 0 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -254,6 +254,14 @@ def broadcast_sharding_parameters(model, hcg):
)


def broadcast_sep_parameters(model, hcg):
# TODO: to save memory, use un-fused broadcast to avoid potential OOM
logger.debug("sep start init parameters sync")
sep_group = hcg.get_sep_parallel_group()
src_rank = hcg.get_sep_parallel_group_src_rank()
sync_params_buffers(model, sep_group, src_rank, is_model_parallel=False)


def unwrap_optimizer(optimizer, optimizer_instances=()):
_inner_opt = optimizer
while isinstance(_inner_opt, optimizer_instances):