From d049b52c6d7b82ced6b47111a4eaea1d668c315c Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 6 Sep 2023 11:19:50 +0000 Subject: [PATCH 1/9] sep grad all sync --- .../hybrid_parallel_optimizer.py | 22 ++++++++++++++----- .../fleet/meta_parallel/pipeline_parallel.py | 6 +++++ python/paddle/distributed/fleet/optimizer.py | 5 +++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index 6e11c19a40f80..3294dea79a50a 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -31,7 +31,7 @@ from ...base.topology import ParallelMode from ...utils.hybrid_parallel_util import ( - fused_allreduce_gradients, + fused_allreduce_gradients_with_group, unwrap_optimizer, ) from ...utils.log_util import logger @@ -271,6 +271,8 @@ def __init__(self, optimizer, hcg, strategy): self._sharding_enable = self._hcg.get_sharding_parallel_world_size() > 1 + self._sep_enabled = self._hcg.get_sep_parallel_world_size() > 1 + if ( isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm) and not self._use_dp_mode @@ -425,6 +427,18 @@ def _step(self, parameters_list): moment2, src_rank, mp_group, mp_configs.sync_mode ) + def _fused_allreduce_gradients(self, parameter_list): + if (not self._dp_enable) and (not self._sep_enabled): + return + group = None + if self._dp_enable: + group = self._hcg.get_data_parallel_group() + if self._sep_enabled: + sep_group = self._hcg.get_sep_parallel_group() + dp_sep_group = self._hcg.get_dp_sep_parallel_group() + group = sep_group if group is None else dp_sep_group + fused_allreduce_gradients_with_group(parameter_list, group) + @no_grad() @framework.dygraph_only def step(self): @@ -433,8 +447,7 @@ def step(self): assert isinstance(self._inner_opt, DygraphShardingOptimizer) self._inner_opt.reduce_gradients(list(parameters_list), self._hcg) - if self._dp_enable: - fused_allreduce_gradients(list(parameters_list), self._hcg) + self._fused_allreduce_gradients(list(parameters_list)) self._step(parameters_list) @@ -455,8 +468,7 @@ def minimize( assert isinstance(self._inner_opt, DygraphShardingOptimizer) self._inner_opt.reduce_gradients(list(parameter_list), self._hcg) - if self._dp_enable: - fused_allreduce_gradients(list(parameter_list), self._hcg) + self._fused_allreduce_gradients(list(parameter_list)) return self._inner_opt.minimize( loss, startup_program, parameter_list, no_grad_set diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index c194ab8d4e5e2..9e2a891483629 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -161,7 +161,13 @@ def __init__(self, layers, hcg, strategy): self.stage_id = self._hcg.get_stage_id() self.global_rank = self._hcg.get_global_rank() self.pp_group = self._hcg.get_pipe_parallel_group() + self.dp_group = self._hcg.get_data_parallel_group() + + # fused sep and dp + if self._hcg.get_sep_parallel_world_size() > 1: + self.dp_group = self._hcg.get_dp_sep_parallel_group() + self.sharding_group = self._hcg.get_sharding_parallel_group() self._virtual_pp_world_size = None diff --git 
a/python/paddle/distributed/fleet/optimizer.py b/python/paddle/distributed/fleet/optimizer.py
index 85c27fd138ef3..edba794280b82 100755
--- a/python/paddle/distributed/fleet/optimizer.py
+++ b/python/paddle/distributed/fleet/optimizer.py
@@ -72,12 +72,17 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
         if fleet_env._user_defined_strategy.hybrid_configs[
             "pp_configs"
         ].dp_comm_overlap:
+            # grad all-reduce of dp and sep will be fused
             hp_optim._dp_enable = False
+            hp_optim._sep_enabled = False
 
         if fleet_env._user_defined_strategy.hybrid_configs[
             "pp_configs"
         ].sharding_comm_overlap:
             hp_optim._sharding_enable = False
+            assert (
+                not hp_optim._sep_enabled
+            ), "sep parallel can not coexist with sharding_comm_overlap"
 
         return hp_optim
     else:

From 6edbe8cb7433eccd9b886222c6e3f4410157bc81 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Thu, 7 Sep 2023 05:33:40 +0000
Subject: [PATCH 2/9] polish

---
 .../paddle/distributed/fleet/base/topology.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index bec592e6bb534..277053a5ae6c8 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -55,6 +55,7 @@ class ParallelMode:
     TENSOR_PARALLEL = 1
     PIPELINE_PARALLEL = 2
     SHARDING_PARALLEL = 3
+    SEP_PARALLEL = 4
 
 
 class CommunicateTopology:
@@ -277,21 +278,27 @@ def __init__(self, topology):
         _HYBRID_PARALLEL_GROUP = self
 
     def get_parallel_mode(self):
-        # there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
+        # there are five modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel / SepParallel
         # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
         # adding its parallel logic within that parallelism
         # when use sharding alone, it should have its own parallelism for its parallel logic
         # TODO modify 3 others parallel to support sharding
         if (
-            self._mp_degree == 1
-            and self._pp_degree == 1
+            self._pp_degree == 1
+            and self._mp_degree == 1
+            and self._sep_degree == 1
             and self._dp_degree == 1
-            and self._sharding_degree > 1
         ):
             return ParallelMode.SHARDING_PARALLEL
-        elif self._mp_degree == 1 and self._pp_degree == 1:
+        elif (
+            self._pp_degree == 1
+            and self._mp_degree == 1
+            and self._sep_degree == 1
+        ):
             return ParallelMode.DATA_PARALLEL
-        elif self._mp_degree > 1 and self._pp_degree == 1:
+        elif self._pp_degree == 1 and self._mp_degree == 1:
+            return ParallelMode.SEP_PARALLEL
+        elif self._pp_degree == 1:
             # initialize the seed
             return ParallelMode.TENSOR_PARALLEL
         elif self._pp_degree > 1:

From a5700a37214a32dbef96fe3e9df054620146a45a Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Thu, 7 Sep 2023 06:27:29 +0000
Subject: [PATCH 3/9] polish

---
 .../paddle/distributed/fleet/base/topology.py | 27 +++++++++----
 .../fleet/meta_parallel/pipeline_parallel.py | 8 +++-
 .../fleet/meta_parallel/segment_parallel.py | 39 +++++++++++++++++++
 .../fleet/meta_parallel/sharding_parallel.py | 9 ++++-
 .../fleet/meta_parallel/tensor_parallel.py | 5 +++
 python/paddle/distributed/fleet/model.py | 15 ++-----
 .../fleet/utils/hybrid_parallel_util.py | 8 ++++
 7 files changed, 88 insertions(+), 23 deletions(-)
 create mode 100644 python/paddle/distributed/fleet/meta_parallel/segment_parallel.py

diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index 277053a5ae6c8..da6ceee08257e
100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -55,7 +55,7 @@ class ParallelMode: TENSOR_PARALLEL = 1 PIPELINE_PARALLEL = 2 SHARDING_PARALLEL = 3 - SEP_PARALLEL = 4 + SEGMENT_PARALLEL = 4 class CommunicateTopology: @@ -282,26 +282,37 @@ def get_parallel_mode(self): # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and # adding its parallel logic within that parallelism # when use sharding alone, it should have its own parallelism for its parallel logic - # TODO modify 3 others parallel to support sharding + + # pp -> mp -> sep -> sharding -> dp if ( self._pp_degree == 1 and self._mp_degree == 1 and self._sep_degree == 1 - and self._dp_degree == 1 + and self._sharding_degree == 1 + and self._dp_degree > 1 ): - return ParallelMode.SHARDING_PARALLEL + return ParallelMode.DATA_PARALLEL elif ( self._pp_degree == 1 and self._mp_degree == 1 and self._sep_degree == 1 + and self._sharding_degree > 1 ): - return ParallelMode.DATA_PARALLEL - elif self._pp_degree == 1 and self._mp_degree == 1: - return ParallelMode.SEP_PARALLEL - elif self._pp_degree == 1: + # sharding may coexist with dp + return ParallelMode.SHARDING_PARALLEL + elif ( + self._pp_degree == 1 + and self._mp_degree == 1 + and self._sep_degree > 1 + ): + # sep may coexist with dp and sharding + return ParallelMode.SEGMENT_PARALLEL + elif self._pp_degree == 1 and self._mp_degree > 1: + # tp may coexist with sep、dp and sharding # initialize the seed return ParallelMode.TENSOR_PARALLEL elif self._pp_degree > 1: + # tp may coexist with mp、sep、dp and sharding return ParallelMode.PIPELINE_PARALLEL def _check_vaild_topo(self): diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 9e2a891483629..c0b0e2c258638 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -24,6 +24,7 @@ from ..utils.hybrid_parallel_util import ( broadcast_dp_parameters, broadcast_mp_parameters, + broadcast_sep_parameters, broadcast_sharding_parameters, ) from ..utils.log_util import logger @@ -138,6 +139,7 @@ def __init__(self, layers, hcg, strategy): super().__init__(layers, hcg, strategy) self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1 self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1 + self.use_sep_parallel = self._hcg.get_sep_parallel_world_size() > 1 self.use_sharding_parallel = ( self._hcg.get_sharding_parallel_world_size() > 1 ) @@ -165,7 +167,7 @@ def __init__(self, layers, hcg, strategy): self.dp_group = self._hcg.get_data_parallel_group() # fused sep and dp - if self._hcg.get_sep_parallel_world_size() > 1: + if self.use_sep_parallel: self.dp_group = self._hcg.get_dp_sep_parallel_group() self.sharding_group = self._hcg.get_sharding_parallel_group() @@ -257,6 +259,10 @@ def __init__(self, layers, hcg, strategy): logger.info("start broadcast mp parameters") broadcast_mp_parameters(self._layers, self._hcg) + if self.use_sep_parallel: + logger.info("start broadcast mp parameters") + broadcast_sep_parameters(self._layers, self._hcg) + if self.use_sharding_parallel: logger.info("start broadcast sharding parameters") broadcast_sharding_parameters(self._layers, self._hcg) diff --git a/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py b/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py 
new file mode 100644 index 0000000000000..ff2024ced2a4d --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils.hybrid_parallel_util import ( + broadcast_dp_parameters, + broadcast_sep_parameters, + broadcast_sharding_parameters, +) +from ..utils.log_util import logger +from .meta_parallel_base import MetaParallelBase + +__all__ = [] + + +class SegmentParallel(MetaParallelBase): + def __init__(self, layers, hcg, **kwargs): + super().__init__(layers, hcg, **kwargs) + + def _prepare_for_model(self): + broadcast_sep_parameters(self._layers, self._hcg) + + if self._hcg.get_sharding_parallel_world_size() > 1: + logger.info("start broadcast sharding parameters") + broadcast_sharding_parameters(self._layers, self._hcg) + + if self._hcg.get_data_parallel_world_size() > 1: + logger.info("start broadcast sharding parameters") + broadcast_dp_parameters(self._layers, self._hcg) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py b/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py index d50d0b62e1bdf..1a90d3ae9f7b2 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ..utils.hybrid_parallel_util import broadcast_sharding_parameters +from ..utils.hybrid_parallel_util import ( + broadcast_dp_parameters, + broadcast_sharding_parameters, +) from ..utils.log_util import logger from .meta_parallel_base import MetaParallelBase @@ -27,6 +30,8 @@ def _prepare_for_model(self): logger.info("start broadcast sharding parameters") broadcast_sharding_parameters(self._layers, self._hcg) - # TODO (JZ-LIANG) to support Sharding-DP + if self._hcg.get_data_parallel_world_size() > 1: + logger.info("start broadcast dp parameters") + broadcast_dp_parameters(self._layers, self._hcg) logger.info("sharding's parameters is ready") diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py index 883533d8e1724..6da9dae096d46 100755 --- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py @@ -16,6 +16,7 @@ broadcast_dp_parameters, broadcast_input_data, broadcast_mp_parameters, + broadcast_sep_parameters, broadcast_sharding_parameters, ) from ..utils.log_util import logger @@ -32,6 +33,10 @@ def _prepare_for_model(self): logger.info("start broadcast mp parameters") broadcast_mp_parameters(self._layers, self._hcg) + if self._hcg.get_sep_parallel_world_size() > 1: + logger.info("start broadcast sep parameters") + broadcast_sep_parameters(self._layers, self._hcg) + if self._hcg.get_sharding_parallel_world_size() > 1: logger.info("start broadcast sharding parameters") broadcast_sharding_parameters(self._layers, self._hcg) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index c849a94dcea5d..1213107b11b23 100755 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,6 +20,7 @@ PipelineLayer, PipelineParallel, PipelineParallelWithInterleave, + SegmentParallel, ShardingParallel, TensorParallel, ) @@ -130,18 +131,6 @@ def distributed_model(model): if fleet_env._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL: model = ShardingParallel(model, fleet_env._hcg, strategy=strategy) elif fleet_env._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL: - # NOTE (JZ-LIANG) init parameters broadcast within sharding group - # normally it should be done inside DataParallel - if fleet_env.sharding_degree > 1: - from paddle.distributed.fleet.utils.hybrid_parallel_util import ( - broadcast_sharding_parameters, - ) - - assert ( - fleet_env.sharding_degree - == fleet_env._hcg.get_sharding_parallel_world_size() - ) - broadcast_sharding_parameters(model, fleet_env._hcg) model = paddle.DataParallel( model, comm_buffer_size=strategy.fuse_grad_size_in_MB, @@ -149,6 +138,8 @@ def distributed_model(model): find_unused_parameters=strategy.find_unused_parameters, group=fleet_env._hcg.get_data_parallel_group(), ) + elif fleet_env._hcg.get_parallel_mode() == ParallelMode.SEP_PARALLEL: + model = SegmentParallel(model, fleet_env._hcg, strategy=strategy) elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL: model = TensorParallel(model, fleet_env._hcg, strategy=strategy) elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index 340ace6ed7b80..d7c49c58d3b08 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ 
b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -254,6 +254,14 @@ def broadcast_sharding_parameters(model, hcg):
     )
 
 
+def broadcast_sep_parameters(model, hcg):
+    # TODO: to save memory, use un-fused broadcast to avoid potential OOM
+    logger.debug("sep start init parameters sync")
+    sep_group = hcg.get_sep_parallel_group()
+    src_rank = hcg.get_sep_parallel_group_src_rank()
+    sync_params_buffers(model, sep_group, src_rank, is_model_parallel=False)
+
+
 def unwrap_optimizer(optimizer, optimizer_instances=()):
     _inner_opt = optimizer
     while isinstance(_inner_opt, optimizer_instances):

From 07ede603cfd9f15714a74c1794bf384370523b80 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Thu, 7 Sep 2023 06:29:57 +0000
Subject: [PATCH 4/9] polish

---
 python/paddle/distributed/fleet/meta_parallel/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py
index f507e2f636884..d50eb940b72bb 100644
--- a/python/paddle/distributed/fleet/meta_parallel/__init__.py
+++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -26,5 +26,6 @@
 from .pipeline_parallel import PipelineParallel  # noqa: F401
 from .pipeline_parallel import PipelineParallelWithInterleave  # noqa: F401
 from .sharding_parallel import ShardingParallel  # noqa: F401
+from .segment_parallel import SegmentParallel  # noqa: F401
 
 __all__ = []

From 08a9137a5d77949d097105a4eb4d3781a3908012 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Thu, 7 Sep 2023 06:33:23 +0000
Subject: [PATCH 5/9] polish

---
 .../distributed/fleet/meta_parallel/pipeline_parallel.py | 2 +-
 .../paddle/distributed/fleet/meta_parallel/segment_parallel.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index c0b0e2c258638..52ed996d2803e 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -260,7 +260,7 @@ def __init__(self, layers, hcg, strategy):
             broadcast_mp_parameters(self._layers, self._hcg)
 
         if self.use_sep_parallel:
-            logger.info("start broadcast mp parameters")
+            logger.info("start broadcast sep parameters")
             broadcast_sep_parameters(self._layers, self._hcg)
 
         if self.use_sharding_parallel:
diff --git a/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py b/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py
index ff2024ced2a4d..3a48543c50fbc 100644
--- a/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py
@@ -28,6 +28,7 @@ def __init__(self, layers, hcg, **kwargs):
         super().__init__(layers, hcg, **kwargs)
 
     def _prepare_for_model(self):
+        logger.info("start broadcast sep parameters")
         broadcast_sep_parameters(self._layers, self._hcg)
 
         if self._hcg.get_sharding_parallel_world_size() > 1:
@@ -35,5 +36,5 @@ def _prepare_for_model(self):
             broadcast_sharding_parameters(self._layers, self._hcg)
 
         if self._hcg.get_data_parallel_world_size() > 1:
-            logger.info("start broadcast sharding parameters")
+            logger.info("start broadcast dp parameters")
             broadcast_dp_parameters(self._layers, self._hcg)

From 67d8da802351558059273dd6e821de37d5bd1792 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Thu, 7 Sep 2023 09:14:20 +0000
Subject: [PATCH 6/9] add test

---
.../paddle/distributed/fleet/base/topology.py | 2 +- .../fleet/hybrid_parallel_sep_model.py | 197 +++++++++++++++++- 2 files changed, 193 insertions(+), 6 deletions(-) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index da6ceee08257e..21916c770e287 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -312,7 +312,7 @@ def get_parallel_mode(self): # initialize the seed return ParallelMode.TENSOR_PARALLEL elif self._pp_degree > 1: - # tp may coexist with mp、sep、dp and sharding + # pp may coexist with mp、sep、dp and sharding return ParallelMode.PIPELINE_PARALLEL def _check_vaild_topo(self): diff --git a/test/collective/fleet/hybrid_parallel_sep_model.py b/test/collective/fleet/hybrid_parallel_sep_model.py index 3bb83f27dc01d..96abc2d1e50bb 100644 --- a/test/collective/fleet/hybrid_parallel_sep_model.py +++ b/test/collective/fleet/hybrid_parallel_sep_model.py @@ -18,15 +18,150 @@ import numpy as np import paddle +from paddle.autograd import PyLayer from paddle.distributed import fleet +vocab_size = 20 +hidden_size = 10 +inner_size = 8 +output_size = 10 +seq_length = 10 +batch_size = 4 -class TestDistMPTraining(unittest.TestCase): - def setUp(self): - random.seed(2023) - np.random.seed(2023) - paddle.seed(2023) +class Concat(PyLayer): + @staticmethod + def forward(ctx, inp, axis, group): + inputs = [] + paddle.distributed.all_gather(inputs, inp, group=group) + with paddle.no_grad(): + cat = paddle.concat(inputs, axis=axis) + ctx.args_axis = axis + ctx.args_group = group + return cat + + @staticmethod + def backward(ctx, grad): + axis = ctx.args_axis + group = ctx.args_group + with paddle.no_grad(): + grads = paddle.split( + grad, paddle.distributed.get_world_size(group), axis=axis + ) + grad = grads[paddle.distributed.get_rank(group)] + return grad + + +class Split(PyLayer): + @staticmethod + def forward(ctx, inp, axis, group): + with paddle.no_grad(): + inps = paddle.split( + inp, paddle.distributed.get_world_size(group), axis=axis + ) + inp = inps[paddle.distributed.get_rank(group)] + + ctx.args_axis = axis + ctx.args_group = group + return inp + + @staticmethod + def backward(ctx, grad): + axis = ctx.args_axis + group = ctx.args_group + grads = [] + paddle.distributed.all_gather(grads, grad, group=group) + with paddle.no_grad(): + grad = paddle.concat(grads, axis=axis) + return grad + + +class SimpleNet(paddle.nn.Layer): + def __init__( + self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ): + super().__init__() + self.linear1 = paddle.nn.Linear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc1) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear2 = paddle.nn.Linear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc2) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=0.5), + ) + + def forward(self, x): + 
x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = paddle.matmul(x, self.embedding.weight, transpose_y=True) + return x + + +class SEPModel(paddle.nn.Layer): + def __init__( + self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ): + self._net = SimpleNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + self._hcg = fleet.get_hybrid_communicate_group() + + def forward(self, x): + x = Split.Apply(x, axis=1, group=self._hcg.get_sep_parallel_group()) + x = self._net.forward(x) + x = Concat.Apply(x, axis=1, group=self._hcg.get_sep_parallel_group()) + loss = x.mean() + return loss + + +class DPModel(paddle.nn.Layer): + def __init__( + self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ): + self._net = SimpleNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + + def forward(self, x): + x = self._net.forward(x) + loss = x.mean() + return loss + + +class TestDistSEPTraining(unittest.TestCase): + def setUp(self): self.strategy = fleet.DistributedStrategy() self.strategy.hybrid_configs = { "sharding_degree": 1, @@ -46,6 +181,58 @@ def test_basic_hcg(self): assert hcg.get_dp_sep_parallel_group() is not None assert hcg.get_pp_mp_parallel_group() is not None + def train_batch(self, batch, model, optimizer): + output = model(batch) + loss = output.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + def build_optimizer(self, model): + optimizer = paddle.optimizer.SGD( + learning_rate=0.001, parameters=model.parameters() + ) + return optimizer + + def build_model(self, model_cls): + paddle.seed(2023) + np.random.seed(2023) + random.seed(2023) + np_fc1 = np.random.random_sample((hidden_size, inner_size)) + np_fc2 = np.random.random_sample((inner_size, hidden_size)) + + model = model_cls( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + + return model + + def test_sep_train(self): + sep_model = self.build_model(SEPModel) + sep_model = fleet.distributed_model(sep_model) + sep_optimizer = self.build_optimizer(sep_model) + sep_optimizer = fleet.distributed_optimizer(sep_optimizer) + dp_model = self.build_model(DPModel) + dp_optimizer = self.build_optimizer(dp_model) + + for _ in range(5): + np_data = np.random.randint( + 0, + vocab_size, + ( + batch_size, + seq_length, + ), + ) + batch = paddle.to_tensor(np_data) + loss_sep = self.train_batch(batch, sep_model, sep_optimizer) + loss_dp = self.train_batch(batch, dp_model, dp_optimizer) + + np.testing.assert_allclose( + loss_sep.numpy(), loss_dp.numpy(), rtol=1e-6 + ) + if __name__ == "__main__": unittest.main() From a882e9ebd5ca6d89f09fd5676c5f0a3563727717 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Thu, 7 Sep 2023 12:22:28 +0000 Subject: [PATCH 7/9] add test, and align with dp --- .../dygraph_optimizer/hybrid_parallel_optimizer.py | 5 ++++- python/paddle/distributed/fleet/model.py | 2 +- test/collective/fleet/hybrid_parallel_sep_model.py | 8 +++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index 3294dea79a50a..68bbe371ea7be 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -431,13 +431,16 @@ def 
_fused_allreduce_gradients(self, parameter_list): if (not self._dp_enable) and (not self._sep_enabled): return group = None + # sep all reduce is not scaled + scale = 1.0 if self._dp_enable: group = self._hcg.get_data_parallel_group() + scale = group.nranks if self._sep_enabled: sep_group = self._hcg.get_sep_parallel_group() dp_sep_group = self._hcg.get_dp_sep_parallel_group() group = sep_group if group is None else dp_sep_group - fused_allreduce_gradients_with_group(parameter_list, group) + fused_allreduce_gradients_with_group(parameter_list, group, scale=scale) @no_grad() @framework.dygraph_only diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index 1213107b11b23..f7fc29b8d27ab 100755 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -138,7 +138,7 @@ def distributed_model(model): find_unused_parameters=strategy.find_unused_parameters, group=fleet_env._hcg.get_data_parallel_group(), ) - elif fleet_env._hcg.get_parallel_mode() == ParallelMode.SEP_PARALLEL: + elif fleet_env._hcg.get_parallel_mode() == ParallelMode.SEGMENT_PARALLEL: model = SegmentParallel(model, fleet_env._hcg, strategy=strategy) elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL: model = TensorParallel(model, fleet_env._hcg, strategy=strategy) diff --git a/test/collective/fleet/hybrid_parallel_sep_model.py b/test/collective/fleet/hybrid_parallel_sep_model.py index 96abc2d1e50bb..c45913c5fa4e8 100644 --- a/test/collective/fleet/hybrid_parallel_sep_model.py +++ b/test/collective/fleet/hybrid_parallel_sep_model.py @@ -133,15 +133,16 @@ class SEPModel(paddle.nn.Layer): def __init__( self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 ): + super().__init__() self._net = SimpleNet( vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 ) self._hcg = fleet.get_hybrid_communicate_group() def forward(self, x): - x = Split.Apply(x, axis=1, group=self._hcg.get_sep_parallel_group()) + x = Split.apply(x, axis=1, group=self._hcg.get_sep_parallel_group()) x = self._net.forward(x) - x = Concat.Apply(x, axis=1, group=self._hcg.get_sep_parallel_group()) + x = Concat.apply(x, axis=1, group=self._hcg.get_sep_parallel_group()) loss = x.mean() return loss @@ -150,6 +151,7 @@ class DPModel(paddle.nn.Layer): def __init__( self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 ): + super().__init__() self._net = SimpleNet( vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 ) @@ -230,7 +232,7 @@ def test_sep_train(self): loss_dp = self.train_batch(batch, dp_model, dp_optimizer) np.testing.assert_allclose( - loss_sep.numpy(), loss_dp.numpy(), rtol=1e-6 + loss_sep.numpy(), loss_dp.numpy(), rtol=1e-3 ) From cb65c7a0968fcbcbcad54159f8d3e8b7b2f1c206 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 11 Sep 2023 11:09:39 +0000 Subject: [PATCH 8/9] follow comment --- .../hybrid_parallel_optimizer.py | 14 ++----------- .../fleet/utils/hybrid_parallel_util.py | 21 ++++++++++++++++--- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index 68bbe371ea7be..8a18d42b5ba6b 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ 
-31,7 +31,7 @@ from ...base.topology import ParallelMode from ...utils.hybrid_parallel_util import ( - fused_allreduce_gradients_with_group, + fused_allreduce_gradients, unwrap_optimizer, ) from ...utils.log_util import logger @@ -430,17 +430,7 @@ def _step(self, parameters_list): def _fused_allreduce_gradients(self, parameter_list): if (not self._dp_enable) and (not self._sep_enabled): return - group = None - # sep all reduce is not scaled - scale = 1.0 - if self._dp_enable: - group = self._hcg.get_data_parallel_group() - scale = group.nranks - if self._sep_enabled: - sep_group = self._hcg.get_sep_parallel_group() - dp_sep_group = self._hcg.get_dp_sep_parallel_group() - group = sep_group if group is None else dp_sep_group - fused_allreduce_gradients_with_group(parameter_list, group, scale=scale) + fused_allreduce_gradients(self, parameter_list, self._hcg) @no_grad() @framework.dygraph_only diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index d7c49c58d3b08..ada149d50b9bb 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py @@ -239,9 +239,24 @@ def fused_allreduce_gradients_with_group( def fused_allreduce_gradients(parameter_list, hcg): - data_parallel_group = None if hcg is None else hcg.get_data_parallel_group() - logger.debug("dp start fuse allreduce gradients") - fused_allreduce_gradients_with_group(parameter_list, data_parallel_group) + dp_enabled = hcg.get_data_parallel_world_size() > 1 + sep_enabled = hcg.get_sep_parallel_world_size() > 1 + assert ( + dp_enabled or sep_enabled + ), f"dp_enabled {dp_enabled}; sep_enabled {sep_enabled}" + group = None + # sep all reduce is not scaled + scale = 1.0 + if dp_enabled: + group = hcg.get_data_parallel_group() + scale = group.nranks + if sep_enabled: + sep_group = hcg.get_sep_parallel_group() + dp_sep_group = hcg.get_dp_sep_parallel_group() + group = sep_group if group is None else dp_sep_group + + logger.debug("dp or sep start fuse allreduce gradients") + fused_allreduce_gradients_with_group(parameter_list, group, scale=scale) def broadcast_sharding_parameters(model, hcg): From 310ba7ef5ca0dc76d36ed813ebb482b2a8febbcc Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Thu, 14 Sep 2023 03:49:59 +0000 Subject: [PATCH 9/9] polish --- .../hybrid_parallel_optimizer.py | 36 ++++++------------- python/paddle/distributed/fleet/optimizer.py | 4 +-- .../fleet/utils/hybrid_parallel_util.py | 31 ++++++++-------- 3 files changed, 29 insertions(+), 42 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index 771296aa0a865..26d446f9db795 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -285,7 +285,7 @@ def __init__(self, optimizer, hcg, strategy): self._sharding_enable = self._hcg.get_sharding_parallel_world_size() > 1 - self._sep_enabled = self._hcg.get_sep_parallel_world_size() > 1 + self._sep_enable = self._hcg.get_sep_parallel_world_size() > 1 if ( isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm) @@ -441,15 +441,7 @@ def _step(self, parameters_list): moment2, src_rank, mp_group, mp_configs.sync_mode ) - def _fused_allreduce_gradients(self, 
parameter_list):
-        if (not self._dp_enable) and (not self._sep_enabled):
-            return
-        fused_allreduce_gradients(self, parameter_list, self._hcg)
-
-    @no_grad()
-    @framework.dygraph_only
-    def step(self):
-        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
+    def _hybrid_sync_grad(self, parameter_list):
         dp_parameter_list = parameter_list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
@@ -459,9 +451,14 @@ def step(self):
                 dp_parameter_list = self._inner_opt.filter_parameters(
                     parameter_list, self._hcg
                 )
+        if self._dp_enable or self._sep_enable:
+            fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
-        self._fused_allreduce_gradients(dp_parameter_list)
-
+    @no_grad()
+    @framework.dygraph_only
+    def step(self):
+        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
+        self._hybrid_sync_grad(parameter_list)
         self._step(parameter_list)
 
     @no_grad()
@@ -476,20 +473,7 @@ def minimize(
             else obtain_optimizer_parameters_list(self._inner_opt)
         )
         parameter_list = list(parameter_list)
-        dp_parameter_list = parameter_list
-        # Here sharding should use global parameter list
-        if self._sharding_enable:
-            assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
-
-            # dp later do not need to use global parameter list
-            if not g_shard_norm_align_dp:
-                dp_parameter_list = self._inner_opt.filter_parameters(
-                    parameter_list, self._hcg
-                )
-
-        self._fused_allreduce_gradients(dp_parameter_list)
-
+        self._hybrid_sync_grad(parameter_list)
         return self._inner_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set
         )
diff --git a/python/paddle/distributed/fleet/optimizer.py b/python/paddle/distributed/fleet/optimizer.py
index edba794280b82..1f1439b3b0b05 100755
--- a/python/paddle/distributed/fleet/optimizer.py
+++ b/python/paddle/distributed/fleet/optimizer.py
@@ -74,14 +74,14 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
         ].dp_comm_overlap:
             # grad all-reduce of dp and sep will be fused
             hp_optim._dp_enable = False
-            hp_optim._sep_enabled = False
+            hp_optim._sep_enable = False
 
         if fleet_env._user_defined_strategy.hybrid_configs[
             "pp_configs"
         ].sharding_comm_overlap:
             hp_optim._sharding_enable = False
             assert (
-                not hp_optim._sep_enabled
+                not hp_optim._sep_enable
             ), "sep parallel can not coexist with sharding_comm_overlap"
 
         return hp_optim
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index ff1c7c6a9529a..1b5c3409d287a 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -239,21 +239,24 @@ def fused_allreduce_gradients_with_group(
 
 
 def fused_allreduce_gradients(parameter_list, hcg):
-    dp_enabled = hcg.get_data_parallel_world_size() > 1
-    sep_enabled = hcg.get_sep_parallel_world_size() > 1
-    assert (
-        dp_enabled or sep_enabled
-    ), f"dp_enabled {dp_enabled}; sep_enabled {sep_enabled}"
     group = None
-    # sep all reduce is not scaled
-    scale = 1.0
-    if dp_enabled:
-        group = hcg.get_data_parallel_group()
-        scale = group.nranks
-    if sep_enabled:
-        sep_group = hcg.get_sep_parallel_group()
-        dp_sep_group = hcg.get_dp_sep_parallel_group()
-        group = sep_group if group is None else dp_sep_group
+    scale = None
+    if hcg is not None:
+        dp_enabled = hcg.get_data_parallel_world_size() > 1
+        sep_enabled = hcg.get_sep_parallel_world_size() > 1
+        assert (
+            dp_enabled or sep_enabled
+        ),
f"dp_enabled {dp_enabled}; sep_enabled {sep_enabled}" + group = None + # sep all reduce is not scaled + scale = 1.0 + if dp_enabled: + group = hcg.get_data_parallel_group() + scale = group.nranks + if sep_enabled: + sep_group = hcg.get_sep_parallel_group() + dp_sep_group = hcg.get_dp_sep_parallel_group() + group = sep_group if group is None else dp_sep_group logger.debug("dp or sep start fuse allreduce gradients") fused_allreduce_gradients_with_group(parameter_list, group, scale=scale)
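Taken together, the series replaces the old dp-only gradient all-reduce with one selection rule in fused_allreduce_gradients: data parallelism alone uses the dp group and scales gradients by the dp world size, segment parallelism alone uses the sep group without scaling, and when both are active the fused dp-sep group is used (still scaled by the dp world size). The stand-alone sketch below restates that rule for illustration only; the FakeHCG class and pick_allreduce_group helper are hypothetical stand-ins for Paddle's HybridCommunicateGroup and the patched helper, not real Paddle APIs.

from dataclasses import dataclass


@dataclass
class FakeHCG:
    # Hypothetical stand-in for HybridCommunicateGroup; only the pieces
    # needed to illustrate the group-selection rule are modelled.
    dp_world_size: int
    sep_world_size: int

    def get_data_parallel_group(self):
        return f"dp_group(nranks={self.dp_world_size})"

    def get_sep_parallel_group(self):
        return f"sep_group(nranks={self.sep_world_size})"

    def get_dp_sep_parallel_group(self):
        return f"dp_sep_group(nranks={self.dp_world_size * self.sep_world_size})"


def pick_allreduce_group(hcg):
    # Restates the rule implemented by the final patch:
    #   dp only  -> dp group, gradients scaled by the dp world size
    #   sep only -> sep group, unscaled
    #   dp + sep -> fused dp-sep group, scaled by the dp world size
    dp_enabled = hcg.dp_world_size > 1
    sep_enabled = hcg.sep_world_size > 1
    assert dp_enabled or sep_enabled, "nothing to all-reduce"
    group = None
    scale = 1.0  # sep-only all-reduce is not scaled
    if dp_enabled:
        group = hcg.get_data_parallel_group()
        scale = float(hcg.dp_world_size)
    if sep_enabled:
        group = (
            hcg.get_sep_parallel_group()
            if group is None
            else hcg.get_dp_sep_parallel_group()
        )
    return group, scale


if __name__ == "__main__":
    for dp, sep in [(4, 1), (1, 2), (4, 2)]:
        print((dp, sep), "->", pick_allreduce_group(FakeHCG(dp, sep)))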
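The reworked get_parallel_mode in topology.py resolves the mode by a fixed precedence, pp -> mp -> sep -> sharding -> dp, as stated in the comment added by patch 3: the first parallelism in that order whose degree exceeds one names the mode, and the finer-grained kinds coexist underneath it. A minimal restatement with plain integers follows; the enum values mirror ParallelMode, but the helper itself is illustrative rather than Paddle code, and it folds the all-ones (single card) case into data parallel for simplicity.

DATA_PARALLEL, TENSOR_PARALLEL, PIPELINE_PARALLEL, SHARDING_PARALLEL, SEGMENT_PARALLEL = range(5)


def parallel_mode(pp=1, mp=1, sep=1, sharding=1, dp=1):
    # Precedence used by get_parallel_mode after patch 3:
    # pp, then mp (tensor), then sep, then sharding, then dp.
    if pp > 1:
        return PIPELINE_PARALLEL  # pp may coexist with mp, sep, sharding and dp
    if mp > 1:
        return TENSOR_PARALLEL    # tp may coexist with sep, sharding and dp
    if sep > 1:
        return SEGMENT_PARALLEL   # sep may coexist with sharding and dp
    if sharding > 1:
        return SHARDING_PARALLEL  # sharding may coexist with dp
    return DATA_PARALLEL          # dp > 1, or the all-ones case (folded in here)


assert parallel_mode(pp=2, sep=2, dp=2) == PIPELINE_PARALLEL
assert parallel_mode(sep=2, sharding=2) == SEGMENT_PARALLEL
assert parallel_mode(sharding=2, dp=2) == SHARDING_PARALLEL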