diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index 1bacdfb8df02e6..b6130b55bf6737 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -55,6 +55,7 @@ class ParallelMode:
     TENSOR_PARALLEL = 1
     PIPELINE_PARALLEL = 2
     SHARDING_PARALLEL = 3
+    SEGMENT_PARALLEL = 4
 
 
 class CommunicateTopology:
@@ -277,24 +278,41 @@ def __init__(self, topology):
         _HYBRID_PARALLEL_GROUP = self
 
     def get_parallel_mode(self):
-        # there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
+        # there are five modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel / SegmentParallel
         # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
         # adding its parallel logic within that parallelism
         # when use sharding alone, it should have its own parallelism for its parallel logic
-        # TODO modify 3 others parallel to support sharding
+
+        # pp -> mp -> sep -> sharding -> dp
         if (
-            self._mp_degree == 1
-            and self._pp_degree == 1
-            and self._dp_degree == 1
+            self._pp_degree == 1
+            and self._mp_degree == 1
+            and self._sep_degree == 1
+            and self._sharding_degree == 1
+            and self._dp_degree > 1
+        ):
+            return ParallelMode.DATA_PARALLEL
+        elif (
+            self._pp_degree == 1
+            and self._mp_degree == 1
+            and self._sep_degree == 1
             and self._sharding_degree > 1
         ):
+            # sharding may coexist with dp
             return ParallelMode.SHARDING_PARALLEL
-        elif self._mp_degree == 1 and self._pp_degree == 1:
-            return ParallelMode.DATA_PARALLEL
-        elif self._mp_degree > 1 and self._pp_degree == 1:
+        elif (
+            self._pp_degree == 1
+            and self._mp_degree == 1
+            and self._sep_degree > 1
+        ):
+            # sep may coexist with dp and sharding
+            return ParallelMode.SEGMENT_PARALLEL
+        elif self._pp_degree == 1 and self._mp_degree > 1:
+            # tp may coexist with sep, dp and sharding
             # initialize the seed
             return ParallelMode.TENSOR_PARALLEL
         elif self._pp_degree > 1:
+            # pp may coexist with mp, sep, dp and sharding
             return ParallelMode.PIPELINE_PARALLEL
 
     def _check_vaild_topo(self):
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index 37b04f78b752b0..26d446f9db7952 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -285,6 +285,8 @@ def __init__(self, optimizer, hcg, strategy):
 
         self._sharding_enable = self._hcg.get_sharding_parallel_world_size() > 1
 
+        self._sep_enable = self._hcg.get_sep_parallel_world_size() > 1
+
         if (
             isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm)
             and not self._use_dp_mode
@@ -439,10 +441,7 @@ def _step(self, parameters_list):
                 moment2, src_rank, mp_group, mp_configs.sync_mode
             )
 
-    @no_grad()
-    @framework.dygraph_only
-    def step(self):
-        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
+    def _hybrid_sync_grad(self, parameter_list):
         dp_parameter_list = parameter_list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
@@ -452,10 +451,14 @@ def step(self):
                 dp_parameter_list = self._inner_opt.filter_parameters(
                     parameter_list, self._hcg
                 )
-
-        if self._dp_enable:
+        if self._dp_enable or self._sep_enable:
             fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
+    @no_grad()
+    @framework.dygraph_only
+    def step(self):
+        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
+        self._hybrid_sync_grad(parameter_list)
         self._step(parameter_list)
 
     @no_grad()
@@ -470,21 +473,7 @@ def minimize(
             else obtain_optimizer_parameters_list(self._inner_opt)
         )
         parameter_list = list(parameter_list)
-        dp_parameter_list = parameter_list
-        # Here sharding should use global parameter list
-        if self._sharding_enable:
-            assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
-
-            # dp later do not need to use global parameter list
-            if not g_shard_norm_align_dp:
-                dp_parameter_list = self._inner_opt.filter_parameters(
-                    parameter_list, self._hcg
-                )
-
-        if self._dp_enable:
-            fused_allreduce_gradients(dp_parameter_list, self._hcg)
-
+        self._hybrid_sync_grad(parameter_list)
         return self._inner_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set
         )
diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py
index f507e2f636884b..d50eb940b72bb1 100644
--- a/python/paddle/distributed/fleet/meta_parallel/__init__.py
+++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -26,5 +26,6 @@
 from .pipeline_parallel import PipelineParallel  # noqa: F401
 from .pipeline_parallel import PipelineParallelWithInterleave  # noqa: F401
 from .sharding_parallel import ShardingParallel  # noqa: F401
+from .segment_parallel import SegmentParallel  # noqa: F401
 
 __all__ = []
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 367169aead956e..d90185d36d4661 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -24,6 +24,7 @@
 from ..utils.hybrid_parallel_util import (
     broadcast_dp_parameters,
     broadcast_mp_parameters,
+    broadcast_sep_parameters,
     broadcast_sharding_parameters,
 )
 from ..utils.log_util import logger
@@ -138,6 +139,7 @@ def __init__(self, layers, hcg, strategy):
         super().__init__(layers, hcg, strategy)
         self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
         self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
+        self.use_sep_parallel = self._hcg.get_sep_parallel_world_size() > 1
         self.use_sharding_parallel = (
             self._hcg.get_sharding_parallel_world_size() > 1
         )
@@ -161,7 +163,13 @@ def __init__(self, layers, hcg, strategy):
         self.stage_id = self._hcg.get_stage_id()
         self.global_rank = self._hcg.get_global_rank()
         self.pp_group = self._hcg.get_pipe_parallel_group()
+
         self.dp_group = self._hcg.get_data_parallel_group()
+
+        # when sep is enabled, allreduce gradients over the fused dp-sep group
+        if self.use_sep_parallel:
+            self.dp_group = self._hcg.get_dp_sep_parallel_group()
+
         self.sharding_group = self._hcg.get_sharding_parallel_group()
 
         self._virtual_pp_world_size = None
@@ -251,6 +259,10 @@ def __init__(self, layers, hcg, strategy):
             logger.info("start broadcast mp parameters")
             broadcast_mp_parameters(self._layers, self._hcg)
 
+        if self.use_sep_parallel:
+            logger.info("start broadcast sep parameters")
+            broadcast_sep_parameters(self._layers, self._hcg)
+
         if self.use_sharding_parallel:
             logger.info("start broadcast sharding parameters")
             broadcast_sharding_parameters(self._layers, self._hcg)
diff --git a/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py b/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py
new file mode 100644
index 00000000000000..3a48543c50fbc6
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/segment_parallel.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils.hybrid_parallel_util import (
+    broadcast_dp_parameters,
+    broadcast_sep_parameters,
+    broadcast_sharding_parameters,
+)
+from ..utils.log_util import logger
+from .meta_parallel_base import MetaParallelBase
+
+__all__ = []
+
+
+class SegmentParallel(MetaParallelBase):
+    def __init__(self, layers, hcg, **kwargs):
+        super().__init__(layers, hcg, **kwargs)
+
+    def _prepare_for_model(self):
+        logger.info("start broadcast sep parameters")
+        broadcast_sep_parameters(self._layers, self._hcg)
+
+        if self._hcg.get_sharding_parallel_world_size() > 1:
+            logger.info("start broadcast sharding parameters")
+            broadcast_sharding_parameters(self._layers, self._hcg)
+
+        if self._hcg.get_data_parallel_world_size() > 1:
+            logger.info("start broadcast dp parameters")
+            broadcast_dp_parameters(self._layers, self._hcg)
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py b/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py
index d50d0b62e1bdf1..1a90d3ae9f7b25 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
+from ..utils.hybrid_parallel_util import (
+    broadcast_dp_parameters,
+    broadcast_sharding_parameters,
+)
 from ..utils.log_util import logger
 from .meta_parallel_base import MetaParallelBase
 
@@ -27,6 +30,8 @@ def _prepare_for_model(self):
         logger.info("start broadcast sharding parameters")
         broadcast_sharding_parameters(self._layers, self._hcg)
 
-        # TODO (JZ-LIANG) to support Sharding-DP
+        if self._hcg.get_data_parallel_world_size() > 1:
+            logger.info("start broadcast dp parameters")
+            broadcast_dp_parameters(self._layers, self._hcg)
 
         logger.info("sharding's parameters is ready")
diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
index 883533d8e1724d..6da9dae096d46c 100755
--- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
@@ -16,6 +16,7 @@
     broadcast_dp_parameters,
     broadcast_input_data,
     broadcast_mp_parameters,
+    broadcast_sep_parameters,
     broadcast_sharding_parameters,
 )
 from ..utils.log_util import logger
@@ -32,6 +33,10 @@ def _prepare_for_model(self):
         logger.info("start broadcast mp parameters")
         broadcast_mp_parameters(self._layers, self._hcg)
 
+        if self._hcg.get_sep_parallel_world_size() > 1:
+            logger.info("start broadcast sep parameters")
+            broadcast_sep_parameters(self._layers, self._hcg)
+
         if self._hcg.get_sharding_parallel_world_size() > 1:
             logger.info("start broadcast sharding parameters")
             broadcast_sharding_parameters(self._layers, self._hcg)
diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py
index c849a94dcea5d3..f7fc29b8d27ab7 100755
--- a/python/paddle/distributed/fleet/model.py
+++ b/python/paddle/distributed/fleet/model.py
@@ -20,6 +20,7 @@
     PipelineLayer,
     PipelineParallel,
     PipelineParallelWithInterleave,
+    SegmentParallel,
     ShardingParallel,
     TensorParallel,
 )
@@ -130,18 +131,6 @@ def distributed_model(model):
     if fleet_env._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
         model = ShardingParallel(model, fleet_env._hcg, strategy=strategy)
     elif fleet_env._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
-        # NOTE (JZ-LIANG) init parameters broadcast within sharding group
-        # normally it should be done inside DataParallel
-        if fleet_env.sharding_degree > 1:
-            from paddle.distributed.fleet.utils.hybrid_parallel_util import (
-                broadcast_sharding_parameters,
-            )
-
-            assert (
-                fleet_env.sharding_degree
-                == fleet_env._hcg.get_sharding_parallel_world_size()
-            )
-            broadcast_sharding_parameters(model, fleet_env._hcg)
         model = paddle.DataParallel(
             model,
             comm_buffer_size=strategy.fuse_grad_size_in_MB,
@@ -149,6 +138,8 @@ def distributed_model(model):
             find_unused_parameters=strategy.find_unused_parameters,
             group=fleet_env._hcg.get_data_parallel_group(),
         )
+    elif fleet_env._hcg.get_parallel_mode() == ParallelMode.SEGMENT_PARALLEL:
+        model = SegmentParallel(model, fleet_env._hcg, strategy=strategy)
     elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
         model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
     elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
diff --git a/python/paddle/distributed/fleet/optimizer.py b/python/paddle/distributed/fleet/optimizer.py
index 85c27fd138ef31..1f1439b3b0b051 100755
--- a/python/paddle/distributed/fleet/optimizer.py
+++ b/python/paddle/distributed/fleet/optimizer.py
@@ -72,12 +72,17 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
         if fleet_env._user_defined_strategy.hybrid_configs[
             "pp_configs"
         ].dp_comm_overlap:
+            # grad all-reduce of dp and sep will be fused
             hp_optim._dp_enable = False
+            hp_optim._sep_enable = False
 
         if fleet_env._user_defined_strategy.hybrid_configs[
             "pp_configs"
         ].sharding_comm_overlap:
             hp_optim._sharding_enable = False
+            assert (
+                not hp_optim._sep_enable
+            ), "sep parallel can not coexist with sharding_comm_overlap"
 
         return hp_optim
     else:
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index edeb585384b03f..1b5c3409d287ad 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -239,9 +239,27 @@ def fused_allreduce_gradients_with_group(
 
 
 def fused_allreduce_gradients(parameter_list, hcg):
-    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
-    logger.debug("dp start fuse allreduce gradients")
-    fused_allreduce_gradients_with_group(parameter_list, data_parallel_group)
+    group = None
+    scale = None
+    if hcg is not None:
+        dp_enabled = hcg.get_data_parallel_world_size() > 1
+        sep_enabled = hcg.get_sep_parallel_world_size() > 1
+        assert (
+            dp_enabled or sep_enabled
+        ), f"dp_enabled {dp_enabled}; sep_enabled {sep_enabled}"
+        group = None
+        # sep all reduce is not scaled
+        scale = 1.0
+        if dp_enabled:
+            group = hcg.get_data_parallel_group()
+            scale = group.nranks
+        if sep_enabled:
+            sep_group = hcg.get_sep_parallel_group()
+            dp_sep_group = hcg.get_dp_sep_parallel_group()
+            group = sep_group if group is None else dp_sep_group
+
+    logger.debug("dp or sep start fuse allreduce gradients")
+    fused_allreduce_gradients_with_group(parameter_list, group, scale=scale)
 
 
 def broadcast_sharding_parameters(model, hcg):
@@ -254,6 +272,14 @@ def broadcast_sharding_parameters(model, hcg):
     )
 
 
+def broadcast_sep_parameters(model, hcg):
+    # TODO: to save memory, use un-fused broadcast to avoid potential OOM
+    logger.debug("sep start init parameters sync")
+    sep_group = hcg.get_sep_parallel_group()
+    src_rank = hcg.get_sep_parallel_group_src_rank()
+    sync_params_buffers(model, sep_group, src_rank, is_model_parallel=False)
+
+
 def unwrap_optimizer(optimizer, optimizer_instances=()):
     _inner_opt = optimizer
     while isinstance(_inner_opt, optimizer_instances):
diff --git a/test/collective/fleet/hybrid_parallel_sep_model.py b/test/collective/fleet/hybrid_parallel_sep_model.py
index 3bb83f27dc01d3..c45913c5fa4e89 100644
--- a/test/collective/fleet/hybrid_parallel_sep_model.py
+++ b/test/collective/fleet/hybrid_parallel_sep_model.py
@@ -18,15 +18,152 @@
 import numpy as np
 
 import paddle
+from paddle.autograd import PyLayer
 from paddle.distributed import fleet
 
+vocab_size = 20
+hidden_size = 10
+inner_size = 8
+output_size = 10
+seq_length = 10
+batch_size = 4
 
-class TestDistMPTraining(unittest.TestCase):
-    def setUp(self):
-        random.seed(2023)
-        np.random.seed(2023)
-        paddle.seed(2023)
+
+class Concat(PyLayer):
+    @staticmethod
+    def forward(ctx, inp, axis, group):
+        inputs = []
+        paddle.distributed.all_gather(inputs, inp, group=group)
+        with paddle.no_grad():
+            cat = paddle.concat(inputs, axis=axis)
+        ctx.args_axis = axis
+        ctx.args_group = group
+        return cat
+
+    @staticmethod
+    def backward(ctx, grad):
+        axis = ctx.args_axis
+        group = ctx.args_group
+        with paddle.no_grad():
+            grads = paddle.split(
+                grad, paddle.distributed.get_world_size(group), axis=axis
+            )
+            grad = grads[paddle.distributed.get_rank(group)]
+        return grad
+
+
+class Split(PyLayer):
+    @staticmethod
+    def forward(ctx, inp, axis, group):
+        with paddle.no_grad():
+            inps = paddle.split(
+                inp, paddle.distributed.get_world_size(group), axis=axis
+            )
+            inp = inps[paddle.distributed.get_rank(group)]
+
+        ctx.args_axis = axis
+        ctx.args_group = group
+        return inp
+
+    @staticmethod
+    def backward(ctx, grad):
+        axis = ctx.args_axis
+        group = ctx.args_group
+        grads = []
+        paddle.distributed.all_gather(grads, grad, group=group)
+        with paddle.no_grad():
+            grad = paddle.concat(grads, axis=axis)
+        return grad
+
+
+class SimpleNet(paddle.nn.Layer):
+    def __init__(
+        self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+    ):
+        super().__init__()
+        self.linear1 = paddle.nn.Linear(
+            hidden_size,
+            inner_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(np_fc1)
+            ),
+            bias_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.linear2 = paddle.nn.Linear(
+            inner_size,
+            hidden_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(np_fc2)
+            ),
+            bias_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.linear3 = paddle.nn.Linear(
+            hidden_size,
+            output_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+            bias_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.embedding = paddle.nn.Embedding(
+            vocab_size,
+            hidden_size,
+            weight_attr=paddle.nn.initializer.Constant(value=0.5),
+        )
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.linear3(x)
+        x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
+        return x
+
+
+class SEPModel(paddle.nn.Layer):
+    def __init__(
+        self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+    ):
+        super().__init__()
+        self._net = SimpleNet(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        self._hcg = fleet.get_hybrid_communicate_group()
+
+    def forward(self, x):
+        x = Split.apply(x, axis=1, group=self._hcg.get_sep_parallel_group())
+        x = self._net.forward(x)
+        x = Concat.apply(x, axis=1, group=self._hcg.get_sep_parallel_group())
+        loss = x.mean()
+        return loss
+
+
+class DPModel(paddle.nn.Layer):
+    def __init__(
+        self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+    ):
+        super().__init__()
+        self._net = SimpleNet(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+
+    def forward(self, x):
+        x = self._net.forward(x)
+        loss = x.mean()
+        return loss
+
+
+class TestDistSEPTraining(unittest.TestCase):
+    def setUp(self):
         self.strategy = fleet.DistributedStrategy()
         self.strategy.hybrid_configs = {
             "sharding_degree": 1,
@@ -46,6 +183,58 @@ def test_basic_hcg(self):
         assert hcg.get_dp_sep_parallel_group() is not None
         assert hcg.get_pp_mp_parallel_group() is not None
 
+    def train_batch(self, batch, model, optimizer):
+        output = model(batch)
+        loss = output.mean()
+        loss.backward()
+        optimizer.step()
+        optimizer.clear_grad()
+        return loss
+
+    def build_optimizer(self, model):
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=0.001, parameters=model.parameters()
+        )
+        return optimizer
+
+    def build_model(self, model_cls):
+        paddle.seed(2023)
+        np.random.seed(2023)
+        random.seed(2023)
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model = model_cls(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+
+        return model
+
+    def test_sep_train(self):
+        sep_model = self.build_model(SEPModel)
+        sep_model = fleet.distributed_model(sep_model)
+        sep_optimizer = self.build_optimizer(sep_model)
+        sep_optimizer = fleet.distributed_optimizer(sep_optimizer)
+        dp_model = self.build_model(DPModel)
+        dp_optimizer = self.build_optimizer(dp_model)
+
+        for _ in range(5):
+            np_data = np.random.randint(
+                0,
+                vocab_size,
+                (
+                    batch_size,
+                    seq_length,
+                ),
+            )
+            batch = paddle.to_tensor(np_data)
+            loss_sep = self.train_batch(batch, sep_model, sep_optimizer)
+            loss_dp = self.train_batch(batch, dp_model, dp_optimizer)
+
+            np.testing.assert_allclose(
+                loss_sep.numpy(), loss_dp.numpy(), rtol=1e-3
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
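
The topology.py hunk above resolves the hybrid mode with a fixed priority: pipeline, then tensor, then segment, then sharding, then plain data parallel, and the lower-priority degrees may stay greater than one underneath the chosen mode. The following is a minimal standalone restatement of that priority, a sketch rather than part of the patch; the string names stand in for the ParallelMode constants, and, like the real method, nothing is selected when every degree is 1.

def parallel_mode(pp=1, mp=1, sep=1, sharding=1, dp=1):
    # Same priority as get_parallel_mode() above: pp -> mp -> sep -> sharding -> dp.
    if pp > 1:
        return "PIPELINE_PARALLEL"  # pp may coexist with mp, sep, sharding and dp
    if mp > 1:
        return "TENSOR_PARALLEL"  # mp may coexist with sep, sharding and dp
    if sep > 1:
        return "SEGMENT_PARALLEL"  # sep may coexist with sharding and dp
    if sharding > 1:
        return "SHARDING_PARALLEL"  # sharding may coexist with dp
    if dp > 1:
        return "DATA_PARALLEL"
    return None  # single card: the real method also falls through


assert parallel_mode(sep=2, dp=2) == "SEGMENT_PARALLEL"
assert parallel_mode(sep=2, sharding=2) == "SEGMENT_PARALLEL"
assert parallel_mode(pp=2, sep=2) == "PIPELINE_PARALLEL"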
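
For context on how the pieces connect at runtime, the sketch below wires a model through fleet with only segment parallelism enabled, so distributed_model() takes the new SEGMENT_PARALLEL branch and the hybrid optimizer keeps _sep_enable set. It is a sketch under assumptions: the hybrid_configs key name "sep_degree" is not shown in the hunks above and is assumed, and the script is assumed to be started with python -m paddle.distributed.launch on two devices. A real sep model is also expected to shard its sequence dimension itself, as the Split/Concat PyLayers in the test do; the plain Linear here only exercises the wiring.

# Sketch only: "sep_degree" is an assumed hybrid_configs key, and the script
# is assumed to run under `python -m paddle.distributed.launch` on 2 devices.
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 1,
    "sep_degree": 2,  # assumed key name for segment parallelism
}
fleet.init(is_collective=True, strategy=strategy)

# A real sep model splits its sequence dimension across the sep group
# (see the Split/Concat PyLayers in the test); this layer only shows the flow.
model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.001, parameters=model.parameters()
)

# SEGMENT_PARALLEL mode: SegmentParallel broadcasts parameters over the sep
# group, and step() allreduces gradients because _sep_enable is True.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

x = paddle.randn([4, 16])
loss = model(x).mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()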