From a7e725cac020e6fc0b8b4fe494918707a6d92079 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 13 Nov 2023 11:39:35 +0800 Subject: [PATCH 01/11] [cherry-pick]Stage1 no fuse opt frl (#58925) * reslove conflivt * format * no fuse for optimizer * bug fix * fix utest * fix utest --------- Co-authored-by: Yuang Liu --- .../framework/distributed_strategy.proto | 5 +- .../dygraph_sharding_optimizer.py | 157 ++++++++-- .../fleet/meta_parallel/pipeline_parallel.py | 19 +- .../fleet/meta_parallel/pp_utils/utils.py | 2 + .../sharding/group_sharded_storage.py | 14 + .../fleet/utils/tensor_fusion_helper.py | 295 +++++++++++++++++- ...rid_parallel_sharding_model_with_fusion.py | 189 +++++++++++ ...test_parallel_dygraph_sharding_parallel.py | 4 + 8 files changed, 631 insertions(+), 54 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 5d03b58fad7f46..2e84b3fe1ceda2 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -69,9 +69,12 @@ message PpConfig { message DygraphShardingConfig { optional bool split_param = 1 [ default = false ]; + optional bool tensor_fusion = 2 [ default = false ]; + optional int32 accumulate_steps = 3 [ default = 1 ]; + optional bool comm_overlap = 4 [ default = false ]; + optional bool fuse_optimizer = 5 [ default = true ]; } - message HybridConfig { optional int32 dp_degree = 1 [ default = -1 ]; optional int32 mp_degree = 2 [ default = 1 ]; diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index da231f67395bfa..1d9a43bb90f976 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -31,6 +31,7 @@ HOOK_ACTION, FusedCommBuffer, assign_group_by_size, + fused_parameters, ) g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1)) @@ -86,39 +87,90 @@ def __init__(self, optimizer, hcg): self._parameter_list.append(param) else: self._parameter_list = optimizer._parameter_list + self._origin_parameter_list = self._parameter_list self._inner_opt = optimizer self._hcg = hcg self._sharding_world_size = self._hcg.get_sharding_parallel_world_size() self._sharding_rank = self._hcg.get_sharding_parallel_rank() + strategy = fleet.fleet._user_defined_strategy + self.tensor_fusion = strategy.hybrid_configs[ + 'sharding_configs' + ].tensor_fusion + self.accumulate_steps = strategy.hybrid_configs[ + 'sharding_configs' + ].accumulate_steps + self.comm_overlap = strategy.hybrid_configs[ + 'sharding_configs' + ].comm_overlap + self.fuse_optimizer = strategy.hybrid_configs[ + 'sharding_configs' + ].fuse_optimizer + pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap + if self.tensor_fusion or self.comm_overlap: + assert ( + not pp_overlap + ), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time." 
+ + self._use_main_grad = hasattr(self._parameter_list[0], "main_grad") + self._rank2decay = {} + self._rank2fused = {} + self._comm_buffers = [] + self._rank2params = self._partition_parameters() self._param2rank = self._map_param_to_rank() - if self._using_param_groups: - param_groups = [ - {"params": []} for _ in range(len(optimizer._param_groups)) - ] - for idx, pg in enumerate(optimizer._param_groups): - param_groups[idx].update( - {k: v for k, v in pg.items() if k != 'params'} + if not self.tensor_fusion: + if self._using_param_groups: + param_groups = [ + {"params": []} for _ in range(len(optimizer._param_groups)) + ] + for idx, pg in enumerate(optimizer._param_groups): + param_groups[idx].update( + {k: v for k, v in pg.items() if k != 'params'} + ) + for param in self._rank2params[self._sharding_rank]: + group_id = self._param_2_group_id[id(param)] + param_groups[group_id]['params'].append(param) + + self._set_inner_opt_attr('_param_groups', param_groups) + self._set_inner_opt_attr( + '_parameter_list', self._rank2params[self._sharding_rank] + ) + self._param_groups = self._parameter_list + else: + self._set_inner_opt_attr( + '_param_groups', self._rank2params[self._sharding_rank] + ) + self._set_inner_opt_attr( + '_parameter_list', self._rank2params[self._sharding_rank] ) - for param in self._rank2params[self._sharding_rank]: - group_id = self._param_2_group_id[id(param)] - param_groups[group_id]['params'].append(param) - self._set_inner_opt_attr('_param_groups', param_groups) - self._set_inner_opt_attr( - '_parameter_list', self._rank2params[self._sharding_rank] - ) - self._param_groups = self._parameter_list else: - self._set_inner_opt_attr( - '_param_groups', self._rank2params[self._sharding_rank] - ) - self._set_inner_opt_attr( - '_parameter_list', self._rank2params[self._sharding_rank] + self._tensor_fusion() + + decay_params = [ + p.name for p in self._rank2decay[self._sharding_rank] + ] + fused_params = self._rank2fused[self._sharding_rank] + apply_decay_param_fun = lambda x: x in decay_params + + all_fused_params = [] + for v in self._rank2fused.values(): + all_fused_params += v + self._parameter_list = all_fused_params + self._param_groups = all_fused_params + + self._set_inner_opt_attr('_parameter_list', fused_params) + self._set_inner_opt_attr('_param_groups', fused_params) + origin_decay_param_fun = getattr( + self._inner_opt, '_apply_decay_param_fun', None ) + if origin_decay_param_fun is not None: + self._set_inner_opt_attr( + '_apply_decay_param_fun', apply_decay_param_fun + ) def clear_grad(self, set_to_zero=True): """ @@ -133,7 +185,37 @@ def clear_grad(self, set_to_zero=True): p.main_grad._clear() p.main_grad = None elif not hasattr(p, "main_grad"): - p.clear_gradient(set_to_zero) + if self.tensor_fusion: + if set_to_zero: + p.grad.zero_() + else: + p.grad._clear() + p.grad = None + else: + p.clear_gradient(set_to_zero) + + def _tensor_fusion(self): + comm_group = self._hcg.get_sharding_parallel_group() + for i in range(self._sharding_world_size): + params = self._rank2params[i] + dst = comm_group.ranks[i] + # TODO(sharding dev): make scale_after_comm a field to be configured by user + decay_fused, all_fused, all_buffer = fused_parameters( + params, + use_main_grad=self._use_main_grad, + fuse_param=True, + comm_overlap=self.comm_overlap, + comm_group=comm_group, + dst=dst, + acc_step=self.accumulate_steps, + scale_after_comm=False, + ) + if self.comm_overlap: + self._comm_buffers += all_buffer + self._rank2decay[i] = decay_fused + self._rank2fused[i] = all_fused 
+ for p in all_fused: + self._param2rank[p.name] = i def filter_parameters(self, parameter_list, hcg): sharding_parallel_rank = hcg.get_sharding_parallel_rank() @@ -197,6 +279,12 @@ def _map_param_to_rank(self): def reduce_gradients(self, parameter_list, hcg): # TODO merge grad / nrank with dp + logger.debug("sharding start gradients sync") + if self.comm_overlap: + for buffer in self._comm_buffers: + buffer.scale_grads() + return + with framework.no_grad(): sharding_nrank = hcg.get_sharding_parallel_group().nranks for param in parameter_list: @@ -236,7 +324,12 @@ def _sharding_sync_parameters(self): with framework.no_grad(): # TODO detach not need (?) - for rank, params in self._rank2params.items(): + valid_rank_to_params = ( + self._rank2params + if not self.tensor_fusion + else self._rank2fused + ) + for rank, params in valid_rank_to_params.items(): for param in params: paddle.distributed.broadcast( param, @@ -289,7 +382,12 @@ def step(self): origin_clip = self._inner_opt._grad_clip if not self._using_param_groups: params_grads = [] - for param in self._parameter_list: + target_param_list = ( + self._origin_parameter_list + if (not self.tensor_fusion or not self.fuse_optimizer) + else self._parameter_list + ) + for param in target_param_list: if ( hasattr(param, "regularizer") and param.regularizer is not None @@ -309,9 +407,14 @@ def step(self): params_grads = self._inner_opt._grad_clip(params_grads) # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize self._set_inner_opt_attr('_grad_clip', None) - update_param_names = [ - p.name for p in self._rank2params[self._sharding_rank] - ] + + rank_params = ( + self._rank2params[self._sharding_rank] + if (not self.tensor_fusion or not self.fuse_optimizer) + else self._rank2fused[self._sharding_rank] + ) + update_param_names = [p.name for p in rank_params] + update_params_grads = [ (p, g) for p, g in params_grads if p.name in update_param_names ] @@ -525,7 +628,7 @@ def reduce_gradients(self, parameter_list, hcg): comm_buffer._copy_grad_to_buffer(param) comm_buffer._comm_grads() - comm_buffer.scale_and_split_grads() + comm_buffer.scale_grads() def _sharding_sync_parameters(self): """ diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 44e72ad668415c..5e9b11fbc4520f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -18,6 +18,11 @@ import paddle from paddle import framework +from paddle.distributed.fleet.utils.tensor_fusion_helper import ( + HOOK_ACTION, + FusedCommBuffer, + assign_group_by_size, +) from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer from ..utils import timer_helper as timer @@ -27,11 +32,6 @@ broadcast_sharding_parameters, ) from ..utils.log_util import logger -from ..utils.tensor_fusion_helper import ( - HOOK_ACTION, - FusedCommBuffer, - assign_group_by_size, -) from .meta_parallel_base import MetaParallelBase from .parallel_layers.pp_layers import PipelineLayer from .pp_utils import p2p_communication as p2p @@ -320,8 +320,8 @@ def fused_gradient( if act == HOOK_ACTION.REDUCE: # parse the relative dst rank to absolute dst rank for sharding dst = comm_group.ranks[dst] - var_groups = assign_group_by_size(parameter_list, group_size) + for group_idx, parameters in var_groups.items(): buffer = FusedCommBuffer( group_idx, @@ -458,8 +458,7 @@ def 
forward_backward_pipeline(self, data, scaler=None): ), "comm buffers should be created" for _, buffers in self._chunk_2_comm_buffers.items(): for buffer in buffers: - buffer.scale_and_split_grads() - + buffer.scale_grads() if self._enable_timer: self.timers("allreduce_shared_weight_gradients").start() self._layers.allreduce_shared_weight_gradients() @@ -879,7 +878,7 @@ def _sync_overlap_grads(self): for _, buffers in self._chunk_2_comm_buffers.items(): for buffer in buffers: - buffer.scale_and_split_grads() + buffer.scale_grads() def _backward_step_helper(self, micro_step): virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False) @@ -1266,7 +1265,7 @@ def _sync_overlap_grads(self): for buffers in self._chunk_2_comm_buffers.values(): for buffer in buffers: - buffer.scale_and_split_grads() + buffer.scale_grads() def forward_backward_pipeline( self, data, scaler, forward_only=False, compute_loss=True diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 2578d6bb20b68b..b30184c935d12b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -14,6 +14,8 @@ import paddle from paddle import _legacy_C_ops +__all__ = [] + FLOAT_TYPE_DICT = { paddle.float16: "float16", paddle.float32: "float32", diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py index 44c5995acc7963..0ee58dd4cbe3e2 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py @@ -30,6 +30,14 @@ from .group_sharded_utils import Type, cvt_to_device, device_guard +class BufferWarper(core.eager.Tensor): + def __init__(self): + super().__init__() + self.need_clip = True + self.is_distributed = False + self.trainable = True + + class InternalStorage: """ This is a basic class, which is responsible for consolidating the basic storage tensor. @@ -92,6 +100,12 @@ def to(self, device, dtype=None, keep_alignment=True): self.buffer = self.buffer.cast(dtype=dtype) self._dtype = dtype + def warp_buffer(self): + tmp_buffer = BufferWarper() + self._buffer = self.buffer + tmp_buffer.get_tensor()._share_data_with(self.buffer.get_tensor()) + self.buffer = tmp_buffer + class ParamStorage(InternalStorage): """ diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index b4041e8f0443c2..76b2373289b88b 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import itertools +import os import weakref from collections import OrderedDict @@ -24,6 +25,7 @@ alignment = { "gpu": 256, } + align = { paddle.float16.value: 2, paddle.bfloat16.value: 2, @@ -39,15 +41,22 @@ class HOOK_ACTION: REDUCE_SCATTER = 2 -def flatten_dense_tensors(parameters, use_main_grad=False, release_grad=False): +def flatten_dense_tensors( + parameters, + use_main_grad=False, + release_grads=False, + fuse_param=True, + warp_buffer=False, +): from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import ( GradStorage, + ParamStorage, ) _buffer_size = 0 _param2align = {} _param2offset = {} - dtype = paddle.float32 if use_main_grad else parameters[0].dtype + dtype = parameters[0].dtype for param in parameters: assert param.trainable, "param must be trainable..." @@ -59,13 +68,22 @@ def flatten_dense_tensors(parameters, use_main_grad=False, release_grad=False): _buffer_size += np.prod(param.shape) + align_ _param2align[param.name] = align_ - if release_grad: + if release_grads: + assert not fuse_param + assert not warp_buffer return None, _buffer_size, _param2offset + if fuse_param: + param_storage = ParamStorage( + size=_buffer_size, dtype=dtype, device="gpu" + ) + param_storage.add_rank_params(parameters, _param2align) + # process gradient + grad_dtype = paddle.float32 if use_main_grad else dtype grad_storage = GradStorage( size=_buffer_size, - dtype=dtype, + dtype=grad_dtype, device="gpu", destination="0", parm2align=_param2align, @@ -74,7 +92,21 @@ def flatten_dense_tensors(parameters, use_main_grad=False, release_grad=False): for param in parameters: grad_storage.add_grad(param, _param2align[param.name]) - return grad_storage.buffer, _buffer_size, _param2offset + if warp_buffer: + if fuse_param: + param_storage.warp_buffer() + grad_storage.warp_buffer() + + if fuse_param: + if not use_main_grad: + # param_storage --> grad_storage + param_storage.buffer._copy_gradient_from(grad_storage.buffer) + else: + param_storage.buffer.main_grad = grad_storage.buffer + param_storage.buffer.stop_gradient = False + return param_storage, grad_storage, _buffer_size, _param2offset + else: + return grad_storage, _buffer_size, _param2offset class ShardingGradView: @@ -257,15 +289,24 @@ def __init__( act=None, dst=-1, release_grads=False, + use_main_grad=None, + fuse_param=False, + scale_after_comm=True, ): + self._id = id self._params = params self._acc_steps = acc_steps self._comm_group = comm_group self._release_grads = release_grads - self.use_main_grad = hasattr(self._params[0], "main_grad") - + self.use_main_grad = ( + use_main_grad + if use_main_grad is not None + else hasattr(self._params[0], "main_grad") + ) + self._fuse_param = fuse_param + self._scale_after_comm = scale_after_comm self._task = None self._dtype = ( paddle.float32 if self.use_main_grad else self._params[0].dtype @@ -289,13 +330,37 @@ def __init__( self._init_step_dict() if self._act != HOOK_ACTION.REDUCE_SCATTER: - ( - self.grad_storage, - self.buffer_size, - self.param2offset, - ) = flatten_dense_tensors( - self._params, self.use_main_grad, self._release_grads - ) + if self._fuse_param: + ( + self.param_storage, + self.grad_storage, + self.buffer_size, + self.param2offset, + ) = flatten_dense_tensors( + self._params, + release_grads=release_grads, + use_main_grad=use_main_grad, + fuse_param=True, + warp_buffer=True, + ) + self.param_storage = self.param_storage.buffer + self.grad_storage = self.grad_storage.buffer + else: + self.param_storage = None + ( + self.grad_storage, + self.buffer_size, + 
self.param2offset, + ) = flatten_dense_tensors( + self._params, + release_grads=release_grads, + use_main_grad=self.use_main_grad, + fuse_param=False, + warp_buffer=False, + ) + if not self._release_grads: + self.grad_storage = self.grad_storage.buffer + else: ( self._sharding_param_grad_view, @@ -441,6 +506,11 @@ def comm_grads(self): @imperative_base.no_grad def _comm_grads(self): + + if not self._scale_after_comm: + scale_factor = 1.0 / self._comm_group.nranks + self.grad_storage.scale_(scale_factor) + if self._act == HOOK_ACTION.ALL_REDUCE: task = paddle.distributed.all_reduce( self.grad_storage, group=self._comm_group, sync_op=False @@ -467,9 +537,13 @@ def _comm_grads(self): self._task = task @imperative_base.no_grad - def scale_and_split_grads(self): + def scale_grads(self): assert self._task is not None, "Task is not initialized. " self._task.wait() + if self._scale_after_comm: + scale_factor = 1.0 / self._comm_group.nranks + self.grad_storage.scale_(scale_factor) + scale_factor = 1.0 / self._comm_group.nranks self.grad_storage.scale_(scale_factor) @@ -494,3 +568,192 @@ def assign_group_by_size(parameters, group_size=128 * 1024 * 1024): var_groups.setdefault(group_idx, []).append(var) return var_groups + + +def bw_hook_func(buffer, param): + @paddle.autograd.no_grad() + def fused_comm(*_): + buffer.add_grad(param) + + return fused_comm + + +def obtain_storage( + parameters, + use_main_grad=False, + clip=True, + dist=False, + fuse_param=True, + comm_overlap=False, + act=None, + comm_group=None, + dst=-1, + acc_steps=1, + scale_after_comm=False, +): + if len(parameters) < 1: + return [], [] + + var_groups = assign_group_by_size(parameters) + storage = [] + buffers = [] + for group_idx, parameters in var_groups.items(): + comm_buffer = FusedCommBuffer( + group_idx, + parameters, + comm_group=comm_group, + acc_steps=acc_steps, + act=act, + dst=dst, + use_main_grad=use_main_grad, + fuse_param=fuse_param, + scale_after_comm=scale_after_comm, + ) + if fuse_param: + param_buffer = comm_buffer.param_storage + param_buffer.need_clip = clip + param_buffer.is_distributed = dist + storage.append(param_buffer) + if comm_overlap: + for param in parameters: + param._register_backward_hook(bw_hook_func(comm_buffer, param)) + buffers.append(comm_buffer) + + return storage, buffers + + +def filter_params(params, is_fp32, is_distributed, need_clip): + params = list( + filter( + lambda x: x.is_distributed + if is_distributed + else (not x.is_distributed), + params, + ) + ) + params = list( + filter( + lambda x: getattr(x, 'need_clip', True) + if need_clip + else (not getattr(x, 'need_clip', True)), + params, + ) + ) + params = list( + filter( + lambda x: x.dtype == paddle.float32 + if is_fp32 + else x.dtype != paddle.float32, + params, + ) + ) + dtype = None + for p in params: + if dtype is None: + dtype = p.dtype + else: + assert dtype == p.dtype + + return params, dtype + + +def fused_parameters( + parameters, + use_main_grad=False, + fuse_param=True, + comm_overlap=False, + comm_group=None, + dst=-1, + acc_step=1, + scale_after_comm=False, +): + """ + Fuse gradients. Fuse parameters if be enabled. Prepare for comm overlap if be enabled. + :param parameters: all parameters to be fused. 
+ :param use_main_grad: does the gradient use main grad or not + :param comm_overlap: enable comm overlap or not + :param comm_group: the comm group for comm overlap + :param dst: the dst for comm overlap + :param acc_step: acc steps, using for comm overlap + :param fuse_param: fuse param or not + :param scale_after_comm: if enable comm overlap, specify the location of grad scale + :return: param storage if fused, comm buffers is comm overlap + """ + g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0)) + act = ( + HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE + ) + if comm_overlap: + assert comm_group is not None + if act == HOOK_ACTION.REDUCE: + assert dst != -1 + elif act == HOOK_ACTION.ALL_REDUCE: + dst = -1 + param_groups = [] + attrs = [] + + is_fp32 = [True, False] + is_distributed = [True, False] + need_clip = [True, False] + + no_fp32_dtype = None + for fp32, dist, clip in itertools.product( + is_fp32, is_distributed, need_clip + ): + params, dtype = filter_params(parameters, fp32, dist, clip) + if not fp32: + if no_fp32_dtype is None: + no_fp32_dtype = dtype + elif dtype is not None: + assert no_fp32_dtype == dtype + attrs.append([dtype, dist, clip]) + param_groups.append(params) + + decay_fused = [] + all_fused = [] + all_buffers = [] + for params, attr in zip(param_groups, attrs): + decay_params = [] + other_params = [] + + for param in params: + if not any(nd in param.name for nd in ["bias", "norm", "b_0"]): + decay_params.append(param) + else: + other_params.append(param) + + is_distributed = attr[1] + need_clip = attr[2] + decay, decay_buffers = obtain_storage( + decay_params, + use_main_grad=use_main_grad, + clip=need_clip, + dist=is_distributed, + fuse_param=fuse_param, + comm_overlap=comm_overlap, + act=act, + comm_group=comm_group, + dst=dst, + acc_steps=acc_step, + scale_after_comm=scale_after_comm, + ) + other, other_buffers = obtain_storage( + other_params, + fuse_param=fuse_param, + comm_overlap=comm_overlap, + use_main_grad=use_main_grad, + clip=need_clip, + dist=is_distributed, + act=act, + comm_group=comm_group, + dst=dst, + acc_steps=acc_step, + scale_after_comm=scale_after_comm, + ) + decay_fused += decay + all_fused += decay + all_fused += other + all_buffers += decay_buffers + all_buffers += other_buffers + + return decay_fused, all_fused, all_buffers diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py new file mode 100644 index 00000000000000..51e10e992364a9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, +) + +vocab_size = 20 +hidden_size = 10 +inner_size = 8 +output_size = 10 +seq_length = 2 +batch_size = 4 +STEPS = 10 + + +class SimpleDPNet(paddle.nn.Layer): + def __init__( + self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ): + super().__init__() + self.linear1 = paddle.nn.Linear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc1) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear2 = paddle.nn.Linear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc2) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=0.5), + ) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = paddle.matmul(x, self.embedding.weight, transpose_y=True) + return x + + +class TestDistSharding(unittest.TestCase): + def setUp(self): + random.seed(2021) + np.random.seed(2021) + paddle.seed(2021) + + self.strategy = fleet.DistributedStrategy() + self.strategy.hybrid_configs = { + "sharding_degree": 2, + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + } + self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True + self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True + self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1 + self.strategy.hybrid_configs["sharding_configs"].fuse_optimizer = False + fleet.init(is_collective=True, strategy=self.strategy) + self.data = np.random.randint( + 0, + vocab_size, + ( + batch_size, + seq_length, + ), + ) + + if paddle.distributed.get_rank() == 0: + self.batch_sharding = paddle.to_tensor(self.data[:2]) + else: + self.batch_sharding = paddle.to_tensor(self.data[2:]) + + self.batch_single = paddle.to_tensor(self.data) + + def train_batch(self, batch, model, optimizer): + output = model(batch) + loss = output.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + def build_optimizer(self, model): + clip = paddle.nn.ClipGradByGlobalNorm(0.5) + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.001, + weight_decay=0.001, + grad_clip=clip, + ) + return optimizer + + def build_model_optimizer(self): + np_fc1 = np.random.random_sample((hidden_size, inner_size)) + np_fc2 = np.random.random_sample((inner_size, hidden_size)) + + model_a = SimpleDPNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + optimizer_a = self.build_optimizer(model_a) + + model_b = SimpleDPNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + optimizer_b = self.build_optimizer(model_b) + + model_a = fleet.distributed_model(model_a) + optimizer_a = fleet.distributed_optimizer(optimizer_a) + + return 
model_a, optimizer_a, model_b, optimizer_b + + def sharding_model(self): + ( + model_a, + optimizer_a, + model_b, + optimizer_b, + ) = self.build_model_optimizer() + + self.assertTrue( + isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer) + ) + + for idx in range(STEPS): + loss_a = self.train_batch(self.batch_sharding, model_a, optimizer_a) + loss_b = self.train_batch(self.batch_single, model_b, optimizer_b) + np.testing.assert_allclose(loss_a, loss_b, rtol=1e-6, atol=1e-6) + + for j in range(len(model_a.parameters())): + np.testing.assert_allclose( + model_a.parameters()[j].numpy(), + model_b.parameters()[j].numpy(), + rtol=1e-6, + atol=1e-7, + ) + + def test_sharding_adam(self): + self.sharding_model() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py index 264c0e7e19e48f..68752cb3c7a65b 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py @@ -47,6 +47,10 @@ def test_hybrid_parallel_sharding_state_dict(self): os.environ["FLAGS_shard_split_param"] = "0" self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py') + def test_hybrid_parallel_sharding_model_with_fusion(self): + os.environ["FLAGS_shard_split_param"] = "0" + self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion.py') + if __name__ == "__main__": unittest.main() From 3faa9dfc6bc470a4ced255055bce74e5db16b8d2 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 15 Nov 2023 09:30:06 +0800 Subject: [PATCH 02/11] Revert "[cherry-pick]Stage1 no fuse opt frl (#58925)" (#59008) This reverts commit a7e725cac020e6fc0b8b4fe494918707a6d92079. 
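For context, the switches that patch 01 introduced (and that this revert removes again) are all driven through fleet's hybrid_configs. A minimal sketch of how a user would enable them, mirroring the two-GPU unit test added in patch 01 (build_model() and build_optimizer() are placeholder builders, not part of the patch):

    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "sharding_degree": 2,
        "dp_degree": 1,
        "mp_degree": 1,
        "pp_degree": 1,
    }
    # DygraphShardingConfig options added by patch 01:
    strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
    strategy.hybrid_configs["sharding_configs"].comm_overlap = True
    strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
    strategy.hybrid_configs["sharding_configs"].fuse_optimizer = False
    fleet.init(is_collective=True, strategy=strategy)

    model = fleet.distributed_model(build_model())
    optimizer = fleet.distributed_optimizer(build_optimizer(model))

With tensor_fusion enabled, DygraphShardingOptimizer swaps its per-parameter lists for the fused storages returned by fused_parameters(), building fused parameter groups for every sharding rank.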
--- .../framework/distributed_strategy.proto | 5 +- .../dygraph_sharding_optimizer.py | 157 ++-------- .../fleet/meta_parallel/pipeline_parallel.py | 19 +- .../fleet/meta_parallel/pp_utils/utils.py | 2 - .../sharding/group_sharded_storage.py | 14 - .../fleet/utils/tensor_fusion_helper.py | 295 +----------------- ...rid_parallel_sharding_model_with_fusion.py | 189 ----------- ...test_parallel_dygraph_sharding_parallel.py | 4 - 8 files changed, 54 insertions(+), 631 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 2e84b3fe1ceda2..5d03b58fad7f46 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -69,12 +69,9 @@ message PpConfig { message DygraphShardingConfig { optional bool split_param = 1 [ default = false ]; - optional bool tensor_fusion = 2 [ default = false ]; - optional int32 accumulate_steps = 3 [ default = 1 ]; - optional bool comm_overlap = 4 [ default = false ]; - optional bool fuse_optimizer = 5 [ default = true ]; } + message HybridConfig { optional int32 dp_degree = 1 [ default = -1 ]; optional int32 mp_degree = 2 [ default = 1 ]; diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 1d9a43bb90f976..da231f67395bfa 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -31,7 +31,6 @@ HOOK_ACTION, FusedCommBuffer, assign_group_by_size, - fused_parameters, ) g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1)) @@ -87,90 +86,39 @@ def __init__(self, optimizer, hcg): self._parameter_list.append(param) else: self._parameter_list = optimizer._parameter_list - self._origin_parameter_list = self._parameter_list self._inner_opt = optimizer self._hcg = hcg self._sharding_world_size = self._hcg.get_sharding_parallel_world_size() self._sharding_rank = self._hcg.get_sharding_parallel_rank() - strategy = fleet.fleet._user_defined_strategy - self.tensor_fusion = strategy.hybrid_configs[ - 'sharding_configs' - ].tensor_fusion - self.accumulate_steps = strategy.hybrid_configs[ - 'sharding_configs' - ].accumulate_steps - self.comm_overlap = strategy.hybrid_configs[ - 'sharding_configs' - ].comm_overlap - self.fuse_optimizer = strategy.hybrid_configs[ - 'sharding_configs' - ].fuse_optimizer - pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap - if self.tensor_fusion or self.comm_overlap: - assert ( - not pp_overlap - ), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time." 
- - self._use_main_grad = hasattr(self._parameter_list[0], "main_grad") - self._rank2decay = {} - self._rank2fused = {} - self._comm_buffers = [] - self._rank2params = self._partition_parameters() self._param2rank = self._map_param_to_rank() - if not self.tensor_fusion: - if self._using_param_groups: - param_groups = [ - {"params": []} for _ in range(len(optimizer._param_groups)) - ] - for idx, pg in enumerate(optimizer._param_groups): - param_groups[idx].update( - {k: v for k, v in pg.items() if k != 'params'} - ) - for param in self._rank2params[self._sharding_rank]: - group_id = self._param_2_group_id[id(param)] - param_groups[group_id]['params'].append(param) - - self._set_inner_opt_attr('_param_groups', param_groups) - self._set_inner_opt_attr( - '_parameter_list', self._rank2params[self._sharding_rank] - ) - self._param_groups = self._parameter_list - else: - self._set_inner_opt_attr( - '_param_groups', self._rank2params[self._sharding_rank] - ) - self._set_inner_opt_attr( - '_parameter_list', self._rank2params[self._sharding_rank] + if self._using_param_groups: + param_groups = [ + {"params": []} for _ in range(len(optimizer._param_groups)) + ] + for idx, pg in enumerate(optimizer._param_groups): + param_groups[idx].update( + {k: v for k, v in pg.items() if k != 'params'} ) + for param in self._rank2params[self._sharding_rank]: + group_id = self._param_2_group_id[id(param)] + param_groups[group_id]['params'].append(param) + self._set_inner_opt_attr('_param_groups', param_groups) + self._set_inner_opt_attr( + '_parameter_list', self._rank2params[self._sharding_rank] + ) + self._param_groups = self._parameter_list else: - self._tensor_fusion() - - decay_params = [ - p.name for p in self._rank2decay[self._sharding_rank] - ] - fused_params = self._rank2fused[self._sharding_rank] - apply_decay_param_fun = lambda x: x in decay_params - - all_fused_params = [] - for v in self._rank2fused.values(): - all_fused_params += v - self._parameter_list = all_fused_params - self._param_groups = all_fused_params - - self._set_inner_opt_attr('_parameter_list', fused_params) - self._set_inner_opt_attr('_param_groups', fused_params) - origin_decay_param_fun = getattr( - self._inner_opt, '_apply_decay_param_fun', None + self._set_inner_opt_attr( + '_param_groups', self._rank2params[self._sharding_rank] + ) + self._set_inner_opt_attr( + '_parameter_list', self._rank2params[self._sharding_rank] ) - if origin_decay_param_fun is not None: - self._set_inner_opt_attr( - '_apply_decay_param_fun', apply_decay_param_fun - ) def clear_grad(self, set_to_zero=True): """ @@ -185,37 +133,7 @@ def clear_grad(self, set_to_zero=True): p.main_grad._clear() p.main_grad = None elif not hasattr(p, "main_grad"): - if self.tensor_fusion: - if set_to_zero: - p.grad.zero_() - else: - p.grad._clear() - p.grad = None - else: - p.clear_gradient(set_to_zero) - - def _tensor_fusion(self): - comm_group = self._hcg.get_sharding_parallel_group() - for i in range(self._sharding_world_size): - params = self._rank2params[i] - dst = comm_group.ranks[i] - # TODO(sharding dev): make scale_after_comm a field to be configured by user - decay_fused, all_fused, all_buffer = fused_parameters( - params, - use_main_grad=self._use_main_grad, - fuse_param=True, - comm_overlap=self.comm_overlap, - comm_group=comm_group, - dst=dst, - acc_step=self.accumulate_steps, - scale_after_comm=False, - ) - if self.comm_overlap: - self._comm_buffers += all_buffer - self._rank2decay[i] = decay_fused - self._rank2fused[i] = all_fused - for p in all_fused: - 
self._param2rank[p.name] = i + p.clear_gradient(set_to_zero) def filter_parameters(self, parameter_list, hcg): sharding_parallel_rank = hcg.get_sharding_parallel_rank() @@ -279,12 +197,6 @@ def _map_param_to_rank(self): def reduce_gradients(self, parameter_list, hcg): # TODO merge grad / nrank with dp - logger.debug("sharding start gradients sync") - if self.comm_overlap: - for buffer in self._comm_buffers: - buffer.scale_grads() - return - with framework.no_grad(): sharding_nrank = hcg.get_sharding_parallel_group().nranks for param in parameter_list: @@ -324,12 +236,7 @@ def _sharding_sync_parameters(self): with framework.no_grad(): # TODO detach not need (?) - valid_rank_to_params = ( - self._rank2params - if not self.tensor_fusion - else self._rank2fused - ) - for rank, params in valid_rank_to_params.items(): + for rank, params in self._rank2params.items(): for param in params: paddle.distributed.broadcast( param, @@ -382,12 +289,7 @@ def step(self): origin_clip = self._inner_opt._grad_clip if not self._using_param_groups: params_grads = [] - target_param_list = ( - self._origin_parameter_list - if (not self.tensor_fusion or not self.fuse_optimizer) - else self._parameter_list - ) - for param in target_param_list: + for param in self._parameter_list: if ( hasattr(param, "regularizer") and param.regularizer is not None @@ -407,14 +309,9 @@ def step(self): params_grads = self._inner_opt._grad_clip(params_grads) # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize self._set_inner_opt_attr('_grad_clip', None) - - rank_params = ( - self._rank2params[self._sharding_rank] - if (not self.tensor_fusion or not self.fuse_optimizer) - else self._rank2fused[self._sharding_rank] - ) - update_param_names = [p.name for p in rank_params] - + update_param_names = [ + p.name for p in self._rank2params[self._sharding_rank] + ] update_params_grads = [ (p, g) for p, g in params_grads if p.name in update_param_names ] @@ -628,7 +525,7 @@ def reduce_gradients(self, parameter_list, hcg): comm_buffer._copy_grad_to_buffer(param) comm_buffer._comm_grads() - comm_buffer.scale_grads() + comm_buffer.scale_and_split_grads() def _sharding_sync_parameters(self): """ diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 5e9b11fbc4520f..44e72ad668415c 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -18,11 +18,6 @@ import paddle from paddle import framework -from paddle.distributed.fleet.utils.tensor_fusion_helper import ( - HOOK_ACTION, - FusedCommBuffer, - assign_group_by_size, -) from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer from ..utils import timer_helper as timer @@ -32,6 +27,11 @@ broadcast_sharding_parameters, ) from ..utils.log_util import logger +from ..utils.tensor_fusion_helper import ( + HOOK_ACTION, + FusedCommBuffer, + assign_group_by_size, +) from .meta_parallel_base import MetaParallelBase from .parallel_layers.pp_layers import PipelineLayer from .pp_utils import p2p_communication as p2p @@ -320,8 +320,8 @@ def fused_gradient( if act == HOOK_ACTION.REDUCE: # parse the relative dst rank to absolute dst rank for sharding dst = comm_group.ranks[dst] - var_groups = assign_group_by_size(parameter_list, group_size) + var_groups = assign_group_by_size(parameter_list, group_size) for group_idx, parameters in var_groups.items(): 
buffer = FusedCommBuffer( group_idx, @@ -458,7 +458,8 @@ def forward_backward_pipeline(self, data, scaler=None): ), "comm buffers should be created" for _, buffers in self._chunk_2_comm_buffers.items(): for buffer in buffers: - buffer.scale_grads() + buffer.scale_and_split_grads() + if self._enable_timer: self.timers("allreduce_shared_weight_gradients").start() self._layers.allreduce_shared_weight_gradients() @@ -878,7 +879,7 @@ def _sync_overlap_grads(self): for _, buffers in self._chunk_2_comm_buffers.items(): for buffer in buffers: - buffer.scale_grads() + buffer.scale_and_split_grads() def _backward_step_helper(self, micro_step): virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False) @@ -1265,7 +1266,7 @@ def _sync_overlap_grads(self): for buffers in self._chunk_2_comm_buffers.values(): for buffer in buffers: - buffer.scale_grads() + buffer.scale_and_split_grads() def forward_backward_pipeline( self, data, scaler, forward_only=False, compute_loss=True diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index b30184c935d12b..2578d6bb20b68b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -14,8 +14,6 @@ import paddle from paddle import _legacy_C_ops -__all__ = [] - FLOAT_TYPE_DICT = { paddle.float16: "float16", paddle.float32: "float32", diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py index 0ee58dd4cbe3e2..44c5995acc7963 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py @@ -30,14 +30,6 @@ from .group_sharded_utils import Type, cvt_to_device, device_guard -class BufferWarper(core.eager.Tensor): - def __init__(self): - super().__init__() - self.need_clip = True - self.is_distributed = False - self.trainable = True - - class InternalStorage: """ This is a basic class, which is responsible for consolidating the basic storage tensor. @@ -100,12 +92,6 @@ def to(self, device, dtype=None, keep_alignment=True): self.buffer = self.buffer.cast(dtype=dtype) self._dtype = dtype - def warp_buffer(self): - tmp_buffer = BufferWarper() - self._buffer = self.buffer - tmp_buffer.get_tensor()._share_data_with(self.buffer.get_tensor()) - self.buffer = tmp_buffer - class ParamStorage(InternalStorage): """ diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index 76b2373289b88b..b4041e8f0443c2 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import itertools -import os + import weakref from collections import OrderedDict @@ -25,7 +24,6 @@ alignment = { "gpu": 256, } - align = { paddle.float16.value: 2, paddle.bfloat16.value: 2, @@ -41,22 +39,15 @@ class HOOK_ACTION: REDUCE_SCATTER = 2 -def flatten_dense_tensors( - parameters, - use_main_grad=False, - release_grads=False, - fuse_param=True, - warp_buffer=False, -): +def flatten_dense_tensors(parameters, use_main_grad=False, release_grad=False): from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import ( GradStorage, - ParamStorage, ) _buffer_size = 0 _param2align = {} _param2offset = {} - dtype = parameters[0].dtype + dtype = paddle.float32 if use_main_grad else parameters[0].dtype for param in parameters: assert param.trainable, "param must be trainable..." @@ -68,22 +59,13 @@ def flatten_dense_tensors( _buffer_size += np.prod(param.shape) + align_ _param2align[param.name] = align_ - if release_grads: - assert not fuse_param - assert not warp_buffer + if release_grad: return None, _buffer_size, _param2offset - if fuse_param: - param_storage = ParamStorage( - size=_buffer_size, dtype=dtype, device="gpu" - ) - param_storage.add_rank_params(parameters, _param2align) - # process gradient - grad_dtype = paddle.float32 if use_main_grad else dtype grad_storage = GradStorage( size=_buffer_size, - dtype=grad_dtype, + dtype=dtype, device="gpu", destination="0", parm2align=_param2align, @@ -92,21 +74,7 @@ def flatten_dense_tensors( for param in parameters: grad_storage.add_grad(param, _param2align[param.name]) - if warp_buffer: - if fuse_param: - param_storage.warp_buffer() - grad_storage.warp_buffer() - - if fuse_param: - if not use_main_grad: - # param_storage --> grad_storage - param_storage.buffer._copy_gradient_from(grad_storage.buffer) - else: - param_storage.buffer.main_grad = grad_storage.buffer - param_storage.buffer.stop_gradient = False - return param_storage, grad_storage, _buffer_size, _param2offset - else: - return grad_storage, _buffer_size, _param2offset + return grad_storage.buffer, _buffer_size, _param2offset class ShardingGradView: @@ -289,24 +257,15 @@ def __init__( act=None, dst=-1, release_grads=False, - use_main_grad=None, - fuse_param=False, - scale_after_comm=True, ): - self._id = id self._params = params self._acc_steps = acc_steps self._comm_group = comm_group self._release_grads = release_grads - self.use_main_grad = ( - use_main_grad - if use_main_grad is not None - else hasattr(self._params[0], "main_grad") - ) - self._fuse_param = fuse_param - self._scale_after_comm = scale_after_comm + self.use_main_grad = hasattr(self._params[0], "main_grad") + self._task = None self._dtype = ( paddle.float32 if self.use_main_grad else self._params[0].dtype @@ -330,37 +289,13 @@ def __init__( self._init_step_dict() if self._act != HOOK_ACTION.REDUCE_SCATTER: - if self._fuse_param: - ( - self.param_storage, - self.grad_storage, - self.buffer_size, - self.param2offset, - ) = flatten_dense_tensors( - self._params, - release_grads=release_grads, - use_main_grad=use_main_grad, - fuse_param=True, - warp_buffer=True, - ) - self.param_storage = self.param_storage.buffer - self.grad_storage = self.grad_storage.buffer - else: - self.param_storage = None - ( - self.grad_storage, - self.buffer_size, - self.param2offset, - ) = flatten_dense_tensors( - self._params, - release_grads=release_grads, - use_main_grad=self.use_main_grad, - fuse_param=False, - warp_buffer=False, - ) - if not self._release_grads: - self.grad_storage = self.grad_storage.buffer - + ( + 
self.grad_storage, + self.buffer_size, + self.param2offset, + ) = flatten_dense_tensors( + self._params, self.use_main_grad, self._release_grads + ) else: ( self._sharding_param_grad_view, @@ -506,11 +441,6 @@ def comm_grads(self): @imperative_base.no_grad def _comm_grads(self): - - if not self._scale_after_comm: - scale_factor = 1.0 / self._comm_group.nranks - self.grad_storage.scale_(scale_factor) - if self._act == HOOK_ACTION.ALL_REDUCE: task = paddle.distributed.all_reduce( self.grad_storage, group=self._comm_group, sync_op=False @@ -537,13 +467,9 @@ def _comm_grads(self): self._task = task @imperative_base.no_grad - def scale_grads(self): + def scale_and_split_grads(self): assert self._task is not None, "Task is not initialized. " self._task.wait() - if self._scale_after_comm: - scale_factor = 1.0 / self._comm_group.nranks - self.grad_storage.scale_(scale_factor) - scale_factor = 1.0 / self._comm_group.nranks self.grad_storage.scale_(scale_factor) @@ -568,192 +494,3 @@ def assign_group_by_size(parameters, group_size=128 * 1024 * 1024): var_groups.setdefault(group_idx, []).append(var) return var_groups - - -def bw_hook_func(buffer, param): - @paddle.autograd.no_grad() - def fused_comm(*_): - buffer.add_grad(param) - - return fused_comm - - -def obtain_storage( - parameters, - use_main_grad=False, - clip=True, - dist=False, - fuse_param=True, - comm_overlap=False, - act=None, - comm_group=None, - dst=-1, - acc_steps=1, - scale_after_comm=False, -): - if len(parameters) < 1: - return [], [] - - var_groups = assign_group_by_size(parameters) - storage = [] - buffers = [] - for group_idx, parameters in var_groups.items(): - comm_buffer = FusedCommBuffer( - group_idx, - parameters, - comm_group=comm_group, - acc_steps=acc_steps, - act=act, - dst=dst, - use_main_grad=use_main_grad, - fuse_param=fuse_param, - scale_after_comm=scale_after_comm, - ) - if fuse_param: - param_buffer = comm_buffer.param_storage - param_buffer.need_clip = clip - param_buffer.is_distributed = dist - storage.append(param_buffer) - if comm_overlap: - for param in parameters: - param._register_backward_hook(bw_hook_func(comm_buffer, param)) - buffers.append(comm_buffer) - - return storage, buffers - - -def filter_params(params, is_fp32, is_distributed, need_clip): - params = list( - filter( - lambda x: x.is_distributed - if is_distributed - else (not x.is_distributed), - params, - ) - ) - params = list( - filter( - lambda x: getattr(x, 'need_clip', True) - if need_clip - else (not getattr(x, 'need_clip', True)), - params, - ) - ) - params = list( - filter( - lambda x: x.dtype == paddle.float32 - if is_fp32 - else x.dtype != paddle.float32, - params, - ) - ) - dtype = None - for p in params: - if dtype is None: - dtype = p.dtype - else: - assert dtype == p.dtype - - return params, dtype - - -def fused_parameters( - parameters, - use_main_grad=False, - fuse_param=True, - comm_overlap=False, - comm_group=None, - dst=-1, - acc_step=1, - scale_after_comm=False, -): - """ - Fuse gradients. Fuse parameters if be enabled. Prepare for comm overlap if be enabled. - :param parameters: all parameters to be fused. 
- :param use_main_grad: does the gradient use main grad or not - :param comm_overlap: enable comm overlap or not - :param comm_group: the comm group for comm overlap - :param dst: the dst for comm overlap - :param acc_step: acc steps, using for comm overlap - :param fuse_param: fuse param or not - :param scale_after_comm: if enable comm overlap, specify the location of grad scale - :return: param storage if fused, comm buffers is comm overlap - """ - g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0)) - act = ( - HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE - ) - if comm_overlap: - assert comm_group is not None - if act == HOOK_ACTION.REDUCE: - assert dst != -1 - elif act == HOOK_ACTION.ALL_REDUCE: - dst = -1 - param_groups = [] - attrs = [] - - is_fp32 = [True, False] - is_distributed = [True, False] - need_clip = [True, False] - - no_fp32_dtype = None - for fp32, dist, clip in itertools.product( - is_fp32, is_distributed, need_clip - ): - params, dtype = filter_params(parameters, fp32, dist, clip) - if not fp32: - if no_fp32_dtype is None: - no_fp32_dtype = dtype - elif dtype is not None: - assert no_fp32_dtype == dtype - attrs.append([dtype, dist, clip]) - param_groups.append(params) - - decay_fused = [] - all_fused = [] - all_buffers = [] - for params, attr in zip(param_groups, attrs): - decay_params = [] - other_params = [] - - for param in params: - if not any(nd in param.name for nd in ["bias", "norm", "b_0"]): - decay_params.append(param) - else: - other_params.append(param) - - is_distributed = attr[1] - need_clip = attr[2] - decay, decay_buffers = obtain_storage( - decay_params, - use_main_grad=use_main_grad, - clip=need_clip, - dist=is_distributed, - fuse_param=fuse_param, - comm_overlap=comm_overlap, - act=act, - comm_group=comm_group, - dst=dst, - acc_steps=acc_step, - scale_after_comm=scale_after_comm, - ) - other, other_buffers = obtain_storage( - other_params, - fuse_param=fuse_param, - comm_overlap=comm_overlap, - use_main_grad=use_main_grad, - clip=need_clip, - dist=is_distributed, - act=act, - comm_group=comm_group, - dst=dst, - acc_steps=acc_step, - scale_after_comm=scale_after_comm, - ) - decay_fused += decay - all_fused += decay - all_fused += other - all_buffers += decay_buffers - all_buffers += other_buffers - - return decay_fused, all_fused, all_buffers diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py deleted file mode 100644 index 51e10e992364a9..00000000000000 --- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import random -import unittest - -import numpy as np - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( - DygraphShardingOptimizer, -) - -vocab_size = 20 -hidden_size = 10 -inner_size = 8 -output_size = 10 -seq_length = 2 -batch_size = 4 -STEPS = 10 - - -class SimpleDPNet(paddle.nn.Layer): - def __init__( - self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 - ): - super().__init__() - self.linear1 = paddle.nn.Linear( - hidden_size, - inner_size, - weight_attr=paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Assign(np_fc1) - ), - bias_attr=paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Constant(0.0) - ), - ) - - self.linear2 = paddle.nn.Linear( - inner_size, - hidden_size, - weight_attr=paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Assign(np_fc2) - ), - bias_attr=paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Constant(0.0) - ), - ) - - self.linear3 = paddle.nn.Linear( - hidden_size, - output_size, - weight_attr=paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Constant(0.0) - ), - bias_attr=paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Constant(0.0) - ), - ) - - self.embedding = paddle.nn.Embedding( - vocab_size, - hidden_size, - weight_attr=paddle.nn.initializer.Constant(value=0.5), - ) - - def forward(self, x): - x = self.embedding(x) - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - x = paddle.matmul(x, self.embedding.weight, transpose_y=True) - return x - - -class TestDistSharding(unittest.TestCase): - def setUp(self): - random.seed(2021) - np.random.seed(2021) - paddle.seed(2021) - - self.strategy = fleet.DistributedStrategy() - self.strategy.hybrid_configs = { - "sharding_degree": 2, - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": 1, - } - self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True - self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True - self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1 - self.strategy.hybrid_configs["sharding_configs"].fuse_optimizer = False - fleet.init(is_collective=True, strategy=self.strategy) - self.data = np.random.randint( - 0, - vocab_size, - ( - batch_size, - seq_length, - ), - ) - - if paddle.distributed.get_rank() == 0: - self.batch_sharding = paddle.to_tensor(self.data[:2]) - else: - self.batch_sharding = paddle.to_tensor(self.data[2:]) - - self.batch_single = paddle.to_tensor(self.data) - - def train_batch(self, batch, model, optimizer): - output = model(batch) - loss = output.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - def build_optimizer(self, model): - clip = paddle.nn.ClipGradByGlobalNorm(0.5) - optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), - learning_rate=0.001, - weight_decay=0.001, - grad_clip=clip, - ) - return optimizer - - def build_model_optimizer(self): - np_fc1 = np.random.random_sample((hidden_size, inner_size)) - np_fc2 = np.random.random_sample((inner_size, hidden_size)) - - model_a = SimpleDPNet( - vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 - ) - optimizer_a = self.build_optimizer(model_a) - - model_b = SimpleDPNet( - vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 - ) - optimizer_b = self.build_optimizer(model_b) - - model_a = fleet.distributed_model(model_a) - optimizer_a = fleet.distributed_optimizer(optimizer_a) - - return 
model_a, optimizer_a, model_b, optimizer_b - - def sharding_model(self): - ( - model_a, - optimizer_a, - model_b, - optimizer_b, - ) = self.build_model_optimizer() - - self.assertTrue( - isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer) - ) - - for idx in range(STEPS): - loss_a = self.train_batch(self.batch_sharding, model_a, optimizer_a) - loss_b = self.train_batch(self.batch_single, model_b, optimizer_b) - np.testing.assert_allclose(loss_a, loss_b, rtol=1e-6, atol=1e-6) - - for j in range(len(model_a.parameters())): - np.testing.assert_allclose( - model_a.parameters()[j].numpy(), - model_b.parameters()[j].numpy(), - rtol=1e-6, - atol=1e-7, - ) - - def test_sharding_adam(self): - self.sharding_model() - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py index 68752cb3c7a65b..264c0e7e19e48f 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py @@ -47,10 +47,6 @@ def test_hybrid_parallel_sharding_state_dict(self): os.environ["FLAGS_shard_split_param"] = "0" self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py') - def test_hybrid_parallel_sharding_model_with_fusion(self): - os.environ["FLAGS_shard_split_param"] = "0" - self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion.py') - if __name__ == "__main__": unittest.main() From d7828b17ac248d196f24e8091e77d9af2d1e8a47 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Thu, 16 Nov 2023 11:36:57 +0800 Subject: [PATCH 03/11] [Distributed] Add sharding v1 & comm overlap (#59023) * add sharding v1 & comm overlap * add overlap comm in sharding * add overlap comm in sharding * add overlap comm in sharding --- .../framework/distributed_strategy.proto | 2 + .../dygraph_sharding_optimizer.py | 101 ++++++++++++++++++ .../fleet/utils/tensor_fusion_helper.py | 7 +- 3 files changed, 109 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 5d03b58fad7f46..3097bcad69ab75 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -69,6 +69,8 @@ message PpConfig { message DygraphShardingConfig { optional bool split_param = 1 [ default = false ]; + optional bool comm_overlap = 2 [ default = false ]; + optional int32 accumulate_steps = 3 [ default = -1 ]; } diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index da231f67395bfa..9b262bd1581f51 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -120,6 +120,97 @@ def __init__(self, optimizer, hcg): '_parameter_list', self._rank2params[self._sharding_rank] ) + strategy = fleet.fleet._user_defined_strategy + sharding_configs = strategy.hybrid_configs["sharding_configs"] + pp_configs = strategy.hybrid_configs["pp_configs"] + + self._pp_overlap = pp_configs.sharding_comm_overlap + acc_steps = sharding_configs.accumulate_steps + self.comm_overlap = 
sharding_configs.comm_overlap + comm_group = self._hcg.get_sharding_parallel_group() + + if not self._pp_overlap and self.comm_overlap: + assert ( + acc_steps > 0 + ), "acc_steps should be larger than 0 when using comm_overlap in sharding" + self.register_reduce_overlap_hook( + comm_group, acc_steps, use_comm=True + ) + + def _build_comm_buffers( + self, comm_group, acc_steps=1, group_size=256 * 1024 * 1024 + ): + parameter_list = list(self._parameter_list) + + if not parameter_list: + return [] + + # Using defaultdict for automatic list creation + fused_parameter_group = defaultdict(list) + + for p in parameter_list: + assert p.name in self._param2rank + dst_rank = self._param2rank[p.name] + fused_parameter_group[dst_rank].append(p) + + # Pre-compute absolute destination ranks + absolute_dst_ranks = { + rank: comm_group.ranks[rank] for rank in fused_parameter_group + } + + comm_buffers = [] + for dst, params in fused_parameter_group.items(): + var_groups = assign_group_by_size(params, group_size) + abs_dst = absolute_dst_ranks[dst] + + # Using list comprehension for buffer creation + buffers = [ + FusedCommBuffer( + group_idx, + parameters, + comm_group, + acc_steps, + HOOK_ACTION.REDUCE, + abs_dst, + release_grads=False, + ) + for group_idx, parameters in var_groups.items() + ] + comm_buffers.extend(buffers) + + return comm_buffers + + def register_reduce_overlap_hook( + self, + comm_group, + acc_steps, + use_comm=False, + group_size=128 * 1024 * 1024, + ): + # Build communication buffers once and store them + if not hasattr(self, 'comm_buffers'): + self.comm_buffers = self._build_comm_buffers( + comm_group, acc_steps, group_size + ) + + # Register backward hooks for each parameter in the buffer + for buffer in self.comm_buffers: + for param in buffer._params: + # Directly register the hook function with necessary parameters + param._register_backward_hook( + self._create_backward_hook(buffer, param, use_comm) + ) + + def _create_backward_hook(self, buffer, param, use_comm): + """Creates a backward hook function for autograd.""" + + @paddle.autograd.no_grad() + def fused_allreduce(*_): + # Directly add gradient to the buffer + buffer.add_grad(param, use_comm=use_comm) + + return fused_allreduce + def clear_grad(self, set_to_zero=True): """ should clear grad for all parameters in model @@ -196,6 +287,9 @@ def _map_param_to_rank(self): return mapping def reduce_gradients(self, parameter_list, hcg): + if self._pp_overlap or self.comm_overlap: + return + # TODO merge grad / nrank with dp with framework.no_grad(): sharding_nrank = hcg.get_sharding_parallel_group().nranks @@ -259,6 +353,10 @@ def minimize( # NOTE in dygraph mode, the only different between step and minimize is that minimize # allow user to customize the parameters for updating on each step + if self.comm_overlap: + for buffer in self.comm_buffers: + buffer.scale_and_split_grads() + assert ( not self._using_param_groups ), "minimize() is not support if using param_groups" @@ -282,6 +380,9 @@ def minimize( @framework.dygraph_only def step(self): # TODO Check whether the model trainable param changed and update state accordingly + if self.comm_overlap: + for buffer in self.comm_buffers: + buffer.scale_and_split_grads() # hack to grad_clip all parameters, # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index b4041e8f0443c2..0dc1dcff0e5897 
100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -385,7 +385,12 @@ def _copy_grad_to_buffer(self, param): param._copy_gradient_from(tmp_var) def add_grad(self, param, use_comm=True): - assert param.name in self._params_step_dict + assert ( + param.name in self._params_step_dict + ), "{} not in params_dict, please check accumulation_steps".format( + param.name + ) + if not self._release_grads: current_ptr = get_grad_address(param, self.use_main_grad) if self._params_to_addr[param.name] != current_ptr: From 5b87ce0811d60804c13df895cb7057508797795e Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Thu, 16 Nov 2023 21:06:13 +0800 Subject: [PATCH 04/11] [BugFix]Fix bug in dygraph shardingv1 overlap (#59067) * fix bug in scale grad * fix bug in scale grad --- .../dygraph_optimizer/dygraph_sharding_optimizer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 9b262bd1581f51..bb0af75ad6e6f4 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -287,7 +287,12 @@ def _map_param_to_rank(self): return mapping def reduce_gradients(self, parameter_list, hcg): - if self._pp_overlap or self.comm_overlap: + if self._pp_overlap: + return + + if self.comm_overlap: + for buffer in self.comm_buffers: + buffer.scale_and_split_grads() return # TODO merge grad / nrank with dp @@ -380,10 +385,6 @@ def minimize( @framework.dygraph_only def step(self): # TODO Check whether the model trainable param changed and update state accordingly - if self.comm_overlap: - for buffer in self.comm_buffers: - buffer.scale_and_split_grads() - # hack to grad_clip all parameters, # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp. 
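For reference, a minimal usage sketch of the dygraph sharding comm_overlap path touched by the two patches above. This is an illustrative configuration only: it assumes two ranks launched via paddle.distributed.launch, the toy Linear model, AdamW settings and tensor shapes are made up, and later patches in this series additionally gate this overlap path behind FLAGS_shard_fused_gradient=1.

    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "sharding_degree": 2,
        "dp_degree": 1,
        "mp_degree": 1,
        "pp_degree": 1,
    }
    # Overlap the sharding gradient reduce with backward computation.
    strategy.hybrid_configs["sharding_configs"].comm_overlap = True
    # Must be > 0 when comm_overlap is enabled (see the assert added above).
    strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1

    fleet.init(is_collective=True, strategy=strategy)

    model = paddle.nn.Linear(8, 8)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=0.001, parameters=model.parameters()
    )
    model = fleet.distributed_model(model)
    # fleet wraps the inner optimizer with DygraphShardingOptimizer.
    optimizer = fleet.distributed_optimizer(optimizer)

    data = paddle.randn([4, 8])
    loss = model(data).mean()
    loss.backward()        # backward hooks push grads into the fused comm buffers
    optimizer.step()       # grads are scaled in reduce_gradients, then the sharded update runs
    optimizer.clear_grad()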
From bcf9676365952b7e07695e47a77fb833f783190b Mon Sep 17 00:00:00 2001 From: wentao yu Date: Sat, 18 Nov 2023 16:11:41 +0800 Subject: [PATCH 05/11] optimize trace hang && fix event leak (#58707) --- .../distributed/collective/process_group.cc | 6 + .../distributed/collective/process_group.h | 1 + .../collective/process_group_nccl.cc | 64 ++++++++- .../collective/process_group_nccl.h | 2 + .../core/distributed/comm_context_manager.cc | 20 +++ .../core/distributed/comm_context_manager.h | 12 ++ paddle/phi/core/distributed/comm_task.h | 26 +++- .../phi/core/distributed/comm_task_manager.cc | 126 ++++++++++++++++-- .../phi/core/distributed/comm_task_manager.h | 18 ++- paddle/phi/core/distributed/nccl_comm_task.cc | 63 +++++++-- paddle/phi/core/distributed/nccl_comm_task.h | 4 + paddle/phi/core/distributed/utils.h | 61 +++++++++ 12 files changed, 375 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/distributed/collective/process_group.cc b/paddle/fluid/distributed/collective/process_group.cc index 4963fe0453ac2d..0e1e9ae07e9b1e 100644 --- a/paddle/fluid/distributed/collective/process_group.cc +++ b/paddle/fluid/distributed/collective/process_group.cc @@ -28,6 +28,12 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid) auto map = ProcessGroupMapFromGid::getInstance(); map->insert(gid_, this); } + const char* global_rank = std::getenv("PADDLE_TRAINER_ID"); + PADDLE_ENFORCE_NOT_NULL( + global_rank, + phi::errors::NotFound( + "The environment variable 'PADDLE_TRAINER_ID' cannot be found.")); + global_rank_ = std::atoi(global_rank); } // TODO(sunyilun): methods below will be removed later diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index 19234945ccc5e2..451616596acb07 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -484,6 +484,7 @@ class ProcessGroup { } protected: + int global_rank_{-1}; int rank_; int size_; int gid_; diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 128732c3c2370a..c0333e8c5015a2 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -21,6 +21,7 @@ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" #include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/comm_task_manager.h" #include "paddle/phi/core/distributed/nccl_comm_task.h" #include "paddle/phi/core/distributed/nccl_tools.h" @@ -819,6 +820,7 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id, const auto& nccl_id_wrapper = store_->get(store_key); std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size()); } + place_to_group_key_[p2p_key] = store_key; } void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, @@ -860,6 +862,48 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, auto comm_ctx = std::make_unique(place); comm_ctx->set_nccl_comm(nccl_comm); + if (FLAGS_enable_async_trace) { + // gather global ranks in current group + int* gpu_global_rank = nullptr; + size_t gpu_global_rank_size = sizeof(int); + CUDA_CHECK(cudaMalloc(&gpu_global_rank, gpu_global_rank_size)); + + CUDA_CHECK(cudaMemcpy(gpu_global_rank, + &global_rank_, + gpu_global_rank_size, + cudaMemcpyHostToDevice)); + + 
int* gpu_global_ranks = nullptr; + size_t gpu_global_ranks_size = num_ranks * sizeof(int); + CUDA_CHECK(cudaMalloc(&gpu_global_ranks, gpu_global_ranks_size)); + + NCCL_CHECK(phi::dynload::ncclAllGather(gpu_global_rank, + gpu_global_ranks, + 1, + ncclInt, + nccl_comm, + comm_ctx->stream())); + + std::vector global_ranks(num_ranks); + CUDA_CHECK(cudaMemcpy(global_ranks.data(), + gpu_global_ranks, + gpu_global_ranks_size, + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaFree(gpu_global_rank)); + CUDA_CHECK(cudaFree(gpu_global_ranks)); + + // store global_ranks in current group_key + std::once_flag flag; + std::call_once(flag, [this]() { + phi::distributed::CommContextManager::GetInstance().SetStore(store_); + phi::distributed::CommTaskManager::GetInstance().SetTimeout(pg_timeout_); + }); + + std::string group_key = place_to_group_key_.at(place_key); + phi::distributed::CommContextManager::GetInstance().AddGroupRanks( + group_key, global_ranks); + } + auto* calc_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place)); place_to_calc_event_.emplace( @@ -913,8 +957,10 @@ std::shared_ptr ProcessGroupNCCL::Collective( if (!FLAGS_enable_async_trace) { fn(nccl_comm, nccl_stream); } else { + std::string group_key = place_to_group_key_.at(key); auto comm_task = std::make_shared(place, + group_key, rank_, size_, gid_, @@ -973,22 +1019,29 @@ std::shared_ptr ProcessGroupNCCL::Point2Point( bool is_batch_p2p = s_group_call_counter > 0; std::string key = ""; + int p2p_nrank = 0; if (is_batch_p2p) { key = GetKeyFromPlace(place); p2p_rank = rank_; p2p_target_rank = peer; + p2p_nrank = GetSize(); } else { int low_rank = rank_ < peer ? rank_ : peer; int high_rank = rank_ < peer ? peer : rank_; key = std::to_string(low_rank) + "->" + std::to_string(high_rank); p2p_rank = rank_ < peer ? 0 : 1; p2p_target_rank = 1 - p2p_rank; + p2p_nrank = 2; } platform::CUDADeviceGuard cuda_guard(place); if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLEnvCache(place, key, comm_type, p2p_rank); } + if (p2p_comm_seq_.find(key) == p2p_comm_seq_.end()) { + p2p_comm_seq_[key] = 0; + } + p2p_comm_seq_[key]++; if (!use_calc_stream) { SyncCalcStream(place, key); @@ -1002,18 +1055,21 @@ std::shared_ptr ProcessGroupNCCL::Point2Point( auto nccl_comm = comm_ctx->nccl_comm(); auto nccl_stream = use_calc_stream ? 
calc_ctx->stream() : comm_ctx->stream(); + std::string group_key = place_to_group_key_.at(key); auto comm_task = std::make_shared(place, - rank_, - size_, + group_key, + p2p_rank, + p2p_nrank, gid_, - comm_seq_, + p2p_comm_seq_[key], tensor.numel(), sync_op, use_calc_stream, nccl_comm, nccl_stream, - comm_type); + comm_type, + pg_timeout_); if (!FLAGS_enable_async_trace) { fn(nccl_comm, nccl_stream, p2p_target_rank); diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index da9283d2032484..ab49c3dc82a5ea 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -234,6 +234,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { place_to_comm_ctx_; uint64_t comm_seq_{0}; + std::unordered_map p2p_comm_seq_; + std::unordered_map place_to_group_key_; // TODO(sunyilun): attrs below will be removed later std::mutex mutex_; diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 719294655588ec..bd9c2e358ece47 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -109,5 +109,25 @@ bool CommContextManager::Has(int ring_id) const { return id_to_comm_context_.find(ring_id) != id_to_comm_context_.end(); } +void CommContextManager::SetGroupSize(const std::string& pg_key, int size) { + pg_key_size_[pg_key] = size; +} + +void CommContextManager::AddGroupRanks(const std::string& pg_key, + std::vector global_ranks) { + if (pg_key_ranks_.find(pg_key) == pg_key_ranks_.end()) { + pg_key_ranks_[pg_key] = global_ranks; + } +} + +std::vector CommContextManager::GetGroupRanks( + const std::string& pg_key) const { + PADDLE_ENFORCE_NE( + pg_key_ranks_.find(pg_key), + pg_key_ranks_.end(), + errors::NotFound("Can not find pg_key %d in GroupRanks.", pg_key)); + return pg_key_ranks_.at(pg_key); +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index ed77c5ac9e5d35..280af041226afd 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "paddle/phi/core/distributed/comm_context.h" @@ -44,6 +45,12 @@ class CommContextManager { bool Has(int ring_id) const; + void SetGroupSize(const std::string& pg_key, int size); + + void AddGroupRanks(const std::string& pg_key, std::vector global_ranks); + + std::vector GetGroupRanks(const std::string& pg_key) const; + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) static void CreateNCCLCommContext(const std::shared_ptr& store, int dev_id, @@ -64,6 +71,11 @@ class CommContextManager { std::unordered_map> id_to_comm_context_; std::shared_ptr store_; + + // process group key to global ranks map + std::unordered_map> pg_key_ranks_; + // process group key to group size map + std::unordered_map pg_key_size_; }; } // namespace distributed diff --git a/paddle/phi/core/distributed/comm_task.h b/paddle/phi/core/distributed/comm_task.h index 5766c9407fe2df..9f5e0167e58b35 100644 --- a/paddle/phi/core/distributed/comm_task.h +++ b/paddle/phi/core/distributed/comm_task.h @@ -37,6 +37,7 @@ class CommTask { public: CommTask(const std::string& backend = "", const phi::Place& place = phi::Place(), + const std::string& group_key = "", int rank = -1, 
int size = 0, int gid = 0, @@ -47,6 +48,7 @@ class CommTask { CommType comm_type = CommType::UNKNOWN) : backend_(backend), place_(place), + group_key_(group_key), rank_(rank), size_(size), gid_(gid), @@ -65,9 +67,10 @@ class CommTask { virtual ~CommTask() = default; std::string UniqueKey() { - return "op:" + CommTypeToString(comm_type_) + + return "group_key:" + group_key_ + ",op:" + CommTypeToString(comm_type_) + ",gid:" + std::to_string(gid_) + ",seq:" + std::to_string(seq_); } + std::string GroupKey() { return group_key_; } std::string GetBackend() { return backend_; } phi::Place GetPlace() { return place_; } int GetGlobalRank() { return global_rank_; } @@ -104,6 +107,12 @@ class CommTask { return; } + virtual void ClearRecord() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + virtual std::string GetCommErrors() { PADDLE_THROW( phi::errors::Unimplemented("%s is not implemented.", __func__)); @@ -124,6 +133,16 @@ class CommTask { phi::errors::Unimplemented("%s is not implemented.", __func__)); return false; } + virtual void SetUpdated(bool updated) { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + virtual bool IsUpdated() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } virtual void AbortComm() { PADDLE_THROW( phi::errors::Unimplemented("%s is not implemented.", __func__)); @@ -133,6 +152,7 @@ class CommTask { protected: std::string backend_; phi::Place place_; + std::string group_key_; int global_rank_; int rank_; int size_; @@ -144,7 +164,11 @@ class CommTask { CommType comm_type_; bool start_trace_updated_{false}; + // task status + bool started_ = false; bool completed_ = false; + // task status changed + bool updated_ = true; bool aborted_{false}; std::chrono::time_point start_time_; std::shared_ptr store_; diff --git a/paddle/phi/core/distributed/comm_task_manager.cc b/paddle/phi/core/distributed/comm_task_manager.cc index 427c8ead360c46..cc206c41d244ce 100644 --- a/paddle/phi/core/distributed/comm_task_manager.cc +++ b/paddle/phi/core/distributed/comm_task_manager.cc @@ -23,6 +23,7 @@ #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/comm_task_manager.h" +#include #include #include @@ -41,28 +42,45 @@ namespace phi { namespace distributed { std::thread CommTaskManager::comm_task_loop_thread_; +std::thread CommTaskManager::comm_task_clear_loop_thread_; const int64_t CommTaskManager::loop_thread_sleep_millis = 10000; std::atomic CommTaskManager::terminated_; std::mutex CommTaskManager::comm_task_list_mutex_; std::condition_variable CommTaskManager::comm_task_list_cv_; std::list> CommTaskManager::comm_task_list_; + +std::mutex CommTaskManager::comm_task_clear_list_mutex_; +std::condition_variable CommTaskManager::comm_task_clear_list_cv_; +std::list> CommTaskManager::comm_task_clear_list_; + std::unordered_map> CommTaskManager::init_comm_task_map_; std::unordered_map> CommTaskManager::start_comm_task_map_; +std::unordered_map> + CommTaskManager::group_last_comm_task_; +std::chrono::time_point + CommTaskManager::last_update_time_ = std::chrono::steady_clock::now(); CommTaskManager::CommTaskManager() { terminated_.store(false); comm_task_loop_thread_ = std::thread(&CommTaskManager::CommTaskLoop, this); + comm_task_clear_loop_thread_ = + std::thread(&CommTaskManager::CommTaskClearLoop, this); LOG(INFO) << "CommTaskManager init success."; } CommTaskManager::~CommTaskManager() { 
terminated_.store(true); if (comm_task_loop_thread_.joinable()) { - comm_task_loop_thread_.join(); comm_task_list_cv_.notify_one(); + comm_task_loop_thread_.join(); + } + + if (comm_task_clear_loop_thread_.joinable()) { + comm_task_clear_list_cv_.notify_one(); + comm_task_clear_loop_thread_.join(); } LOG(INFO) << "CommTaskManager destruct success."; } @@ -74,14 +92,28 @@ void CommTaskManager::CommTaskEnqueue(std::shared_ptr comm_task) { } } +void CommTaskManager::CommTaskClearEnqueue( + std::shared_ptr comm_task) { + if (!terminated_.load()) { + std::lock_guard lock(comm_task_clear_list_mutex_); + comm_task_clear_list_.emplace_back(comm_task); + } +} + void CommTaskManager::Stop() { terminated_.store(true); LOG(INFO) << "CommTaskManager stopped begin."; if (comm_task_loop_thread_.joinable()) { - comm_task_loop_thread_.join(); comm_task_list_cv_.notify_one(); + comm_task_loop_thread_.join(); + } + + if (comm_task_clear_loop_thread_.joinable()) { + comm_task_clear_list_cv_.notify_one(); + comm_task_clear_loop_thread_.join(); } + LOG(INFO) << "CommTaskManager stopped."; } @@ -89,30 +121,55 @@ void CommTaskManager::CommTaskLoop() { bool done = false; while (!terminated_.load() || !done) { std::unique_lock lock(comm_task_list_mutex_); + VLOG(3) << "IsTimeout: " << IsTimeout() + << ", comm_task_list_ size: " << comm_task_list_.size() + << ", init_comm_task_map_ size: " << init_comm_task_map_.size() + << ", start_comm_task_map_ size: " << start_comm_task_map_.size() + << ", logged_ " << logged_; + comm_task_list_cv_.wait_for( lock, std::chrono::milliseconds(loop_thread_sleep_millis), [&]() -> bool { return terminated_.load(); }); + + if (IsTimeout() && !logged_) { + // case 1: all group is empty, has no task + // report error immediately + if (group_last_comm_task_.empty()) { + LOG(WARNING) << "Find no task started in all group"; + } else { + // case 2: all group is not empty, but all last task is completed + // case 3: all group is not empty, some group task started but not + for (auto iter : group_last_comm_task_) { + LogLongStr("Find last group comm task:", iter.second->GetTraceMsg()); + } + } + logged_ = true; + } for (auto iter = comm_task_list_.begin(); iter != comm_task_list_.end();) { auto task = *iter; if (task->IsTimeout()) { if (!task->IsStarted()) { - LOG(ERROR) << "Find timeout init but not start task: " - << task->GetTraceMsg() << ",comm:" << task->nccl_comm() - << ",stream:" << task->nccl_stream(); + LOG(WARNING) << "Find timeout init but not start task:" + << task->GetTraceMsg(); std::string task_key = task->UniqueKey(); init_comm_task_map_[task_key] = task; } else if (!task->IsCompleted()) { - LOG(ERROR) << "Find timeout start but not finish task: " - << task->GetTraceMsg() << ",comm:" << task->nccl_comm() - << ",stream:" << task->nccl_stream(); + LOG(WARNING) << "Find timeout start but not finish task:" + << task->GetTraceMsg(); std::string task_key = task->UniqueKey(); start_comm_task_map_[task_key] = task; } iter = comm_task_list_.erase(iter); } else { - if (task->IsStarted() && task->IsCompleted()) { - iter = comm_task_list_.erase(iter); + if (task->IsStarted()) { + if (task->IsCompleted()) { + CommTaskClearEnqueue(task); + iter = comm_task_list_.erase(iter); + } else { + ++iter; + } + UpdateLastCommTask(task); } else { ++iter; } @@ -136,6 +193,8 @@ void CommTaskManager::CommTaskLoop() { iter != start_comm_task_map_.end();) { auto task = iter->second; if (task->IsCompleted()) { + CommTaskClearEnqueue(task); + UpdateLastCommTask(task); iter = start_comm_task_map_.erase(iter); 
LOG(INFO) << "Finish timeout task: " << task->GetTraceMsg(); } else { @@ -152,5 +211,52 @@ void CommTaskManager::CommTaskLoop() { } } +void CommTaskManager::CommTaskClearLoop() { + std::future future; + while (!terminated_.load()) { + if (future.valid()) { + future.wait(); + } + std::unique_lock lock(comm_task_clear_list_mutex_); + comm_task_clear_list_cv_.wait_for( + lock, + std::chrono::milliseconds(loop_thread_sleep_millis), + [&]() -> bool { return terminated_.load(); }); + + VLOG(3) << "comm_task_clear_list_ size: " << comm_task_clear_list_.size(); + for (auto iter = comm_task_clear_list_.begin(); + iter != comm_task_clear_list_.end();) { + auto task = *iter; + VLOG(3) << "start clear task: " << task->GetTraceMsg(); + future = std::async(std::launch::async, [&]() { task->ClearRecord(); }); + if (future.wait_for(std::chrono::seconds(30)) == + std::future_status::timeout) { + VLOG(0) << "clear task timeout, detail: " << task->GetTraceMsg(); + break; + } + VLOG(3) << "end clear task: " << task->GetTraceMsg(); + iter = comm_task_clear_list_.erase(iter); + } + } +} + +void CommTaskManager::UpdateLastCommTask(std::shared_ptr task) { + if (!task->IsUpdated()) { + return; + } + group_last_comm_task_[task->GroupKey()] = task; + last_update_time_ = std::chrono::steady_clock::now(); + task->SetUpdated(false); +} + +void CommTaskManager::SetTimeout(int64_t timeout) { + timeout_ = std::chrono::milliseconds(timeout); +} + +bool CommTaskManager::IsTimeout() { + auto current_timepoint = std::chrono::steady_clock::now(); + return std::chrono::duration_cast( + current_timepoint - last_update_time_) >= timeout_; +} } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/comm_task_manager.h b/paddle/phi/core/distributed/comm_task_manager.h index d53f7f9dd4ca8c..cf5a8701b153d2 100644 --- a/paddle/phi/core/distributed/comm_task_manager.h +++ b/paddle/phi/core/distributed/comm_task_manager.h @@ -47,12 +47,18 @@ class CommTaskManager { } void CommTaskEnqueue(std::shared_ptr comm_task); + void CommTaskClearEnqueue(std::shared_ptr comm_task); void Stop(); + void UpdateLastCommTask(std::shared_ptr comm_task); + void SetTimeout(int64_t timeout); private: void CommTaskLoop(); + void CommTaskClearLoop(); + bool IsTimeout(); static std::thread comm_task_loop_thread_; + static std::thread comm_task_clear_loop_thread_; static const int64_t loop_thread_sleep_millis; static std::atomic terminated_; @@ -60,6 +66,11 @@ class CommTaskManager { static std::mutex comm_task_list_mutex_; static std::condition_variable comm_task_list_cv_; static std::list> comm_task_list_; + + static std::mutex comm_task_clear_list_mutex_; + static std::condition_variable comm_task_clear_list_cv_; + static std::list> comm_task_clear_list_; + // not start task static std::unordered_map> init_comm_task_map_; @@ -67,7 +78,12 @@ class CommTaskManager { static std::unordered_map> start_comm_task_map_; std::shared_ptr store_; - bool store_error_{false}; + // record last comm task in current group, eg: group_key->comm_task + static std::unordered_map> + group_last_comm_task_; + static std::chrono::time_point last_update_time_; + std::chrono::milliseconds timeout_; + bool logged_ = false; }; } // namespace distributed diff --git a/paddle/phi/core/distributed/nccl_comm_task.cc b/paddle/phi/core/distributed/nccl_comm_task.cc index 470ba5d24f5e42..84530d3fcd23ab 100644 --- a/paddle/phi/core/distributed/nccl_comm_task.cc +++ b/paddle/phi/core/distributed/nccl_comm_task.cc @@ -18,6 +18,7 @@ #include "glog/logging.h" #include 
"paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_tools.h" #include "paddle/phi/core/distributed/trace_utils.h" #include "paddle/phi/core/utils/data_type.h" @@ -26,6 +27,7 @@ namespace phi { namespace distributed { NCCLCommTask::NCCLCommTask(const phi::Place& place, + const std::string& group_key, int rank, int size, int gid, @@ -39,6 +41,7 @@ NCCLCommTask::NCCLCommTask(const phi::Place& place, int64_t timeout) : CommTask("NCCL", place, + group_key, rank, size, gid, @@ -73,6 +76,19 @@ void NCCLCommTask::EndRecord() { CUDA_CHECK(cudaEventRecord(nccl_end_event_, nccl_stream_)); } +void NCCLCommTask::ClearRecord() { + if (start_event_created_) { + backends::gpu::GPUDeviceGuard guard(place_.device); + CUDA_CHECK(cudaEventDestroy(nccl_start_event_)); + start_event_created_ = false; + } + if (end_event_created_) { + backends::gpu::GPUDeviceGuard guard(place_.device); + CUDA_CHECK(cudaEventDestroy(nccl_end_event_)); + end_event_created_ = false; + } +} + bool NCCLCommTask::CudaEventQuery(cudaEvent_t event) { cudaError_t ret = cudaEventQuery(event); if (ret == cudaSuccess) { @@ -147,9 +163,31 @@ std::string NCCLCommTask::GetCommErrors() { return comm_error_; } -bool NCCLCommTask::IsStarted() { return CudaEventQuery(nccl_start_event_); } +bool NCCLCommTask::IsStarted() { + if (started_) { + return true; + } + if (start_event_created_ && CudaEventQuery(nccl_start_event_)) { + started_ = true; + updated_ = true; + } + return started_; +} + +bool NCCLCommTask::IsCompleted() { + if (completed_) { + return true; + } + if (end_event_created_ && CudaEventQuery(nccl_end_event_)) { + completed_ = true; + updated_ = true; + } + return completed_; +} + +void NCCLCommTask::SetUpdated(bool updated) { updated_ = updated; } -bool NCCLCommTask::IsCompleted() { return CudaEventQuery(nccl_end_event_); } +bool NCCLCommTask::IsUpdated() { return updated_; } bool NCCLCommTask::IsTimeout() { auto current_timepoint = std::chrono::steady_clock::now(); @@ -173,18 +211,19 @@ std::string NCCLCommTask::GetTraceMsg() { auto current_timepoint = std::chrono::steady_clock::now(); auto time_elapsed = std::chrono::duration_cast( current_timepoint - start_time_); - return "op:" + CommTypeToString(comm_type_) + ",gid:" + std::to_string(gid_) + - ",seq:" + std::to_string(seq_) + - ",started:" + std::to_string(IsStarted()) + - ",completed:" + std::to_string(IsCompleted()) + + auto global_ranks = + phi::distributed::CommContextManager::GetInstance().GetGroupRanks( + group_key_); + return "group_key:" + group_key_ + + ",group_ranks:" + VectorToString(global_ranks) + ",global_rank:" + std::to_string(global_rank_) + ",local_rank:" + std::to_string(rank_) + - ",size:" + std::to_string(size_) + ",numel:" + std::to_string(numel_) + - ",sync_op:" + std::to_string(sync_op_) + - ",use_calc_stream:" + std::to_string(use_calc_stream_) + - ",timeout:" + std::to_string(timeout_.count()) + - ",is_timeout:" + std::to_string(IsTimeout()) + - ",time_elapsed:" + std::to_string(time_elapsed.count()); + ",comm_count:" + std::to_string(seq_) + + ",op:" + CommTypeToString(comm_type_) + + ",started:" + std::to_string(started_) + + ",completed:" + std::to_string(completed_) + + ",numel:" + std::to_string(numel_) + + ",nranks:" + std::to_string(size_); } } // namespace distributed diff --git a/paddle/phi/core/distributed/nccl_comm_task.h b/paddle/phi/core/distributed/nccl_comm_task.h index f6725e9adc4457..69bee44a4c2d03 100644 --- 
a/paddle/phi/core/distributed/nccl_comm_task.h +++ b/paddle/phi/core/distributed/nccl_comm_task.h @@ -34,6 +34,7 @@ static int64_t DefaultTimeout = 30 * 60 * 1000; class NCCLCommTask : public CommTask { public: NCCLCommTask(const phi::Place& place = phi::Place(), + const std::string& group_key = "", int rank = -1, int size = 0, int gid = 0, @@ -51,6 +52,8 @@ class NCCLCommTask : public CommTask { bool IsStarted() override; bool IsTimeout() override; bool IsCompleted() override; + void SetUpdated(bool updated) override; + bool IsUpdated() override; std::string GetTraceMsg() override; std::string GetCommErrors() override; @@ -58,6 +61,7 @@ class NCCLCommTask : public CommTask { void StartRecord(); void EndRecord(); + void ClearRecord() override; bool CudaEventQuery(cudaEvent_t event); diff --git a/paddle/phi/core/distributed/utils.h b/paddle/phi/core/distributed/utils.h index 9b37be71b1f5a9..aa9d4f3161f295 100644 --- a/paddle/phi/core/distributed/utils.h +++ b/paddle/phi/core/distributed/utils.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -85,5 +86,65 @@ inline std::string CommTypeToString(CommType CommType) { return "Unknown"; } +// convert vector to string, concatenate continuous intervals with `~`, +// concatenate discontinuous intervals with `#` eg: [1,2,3,4,5,7,8,9] => +// 1~3#4#5#7~9 +inline std::string VectorToString(const std::vector& vec) { + if (vec.empty()) { + return ""; + } + if (vec.size() == 1) { + return std::to_string(vec[0]); + } + + std::stringstream ss; + size_t i = 0; + int start_rank = vec[i]; + for (; i < vec.size() - 1; ++i) { + if (vec[i] + 1 == vec[i + 1]) { + continue; + } + if (ss.rdbuf()->in_avail() != 0) { + ss << "#"; + } + ss << start_rank; + if (start_rank != vec[i]) { + ss << "~"; + ss << vec[i]; + } + start_rank = vec[i + 1]; + } + + if (ss.rdbuf()->in_avail() != 0) { + ss << "#"; + } + ss << start_rank; + if (start_rank != vec[i]) { + ss << "~"; + ss << vec[i]; + } + + return ss.str(); +} + +inline void LogLongStr(const std::string prefix, const std::string& log) { + size_t max_log_size = 20000; + if (log.size() >= max_log_size) { + int log_count = log.size() / max_log_size + 1; + int index = 0; + int part = 0; + while (index + max_log_size < log.size()) { + LOG(INFO) << prefix << "part:" << part << "/" << log_count << "," + << log.substr(index, max_log_size) << std::endl; + index += max_log_size; + part++; + } + LOG(INFO) << prefix << "part:" << part << "/" << log_count << "," + << log.substr(index) << std::endl; + } else { + LOG(INFO) << prefix << "part:0/1," << log << std::endl; + } +} + } // namespace distributed } // namespace phi From ae651577d2bcc36a2884872950c30c23aa2a565d Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 20 Nov 2023 14:37:51 +0800 Subject: [PATCH 06/11] [Distributed]Add tensorfuse and non-overlap for shardingv1 (#59117) * sort comm buffer for v1 * sort comm buffer for v1 * sort comm buffer for v1 * sort comm buffer for v1 * support non-overlap and fuse * revert test_seq2seq utest bug --- .../dygraph_sharding_optimizer.py | 116 +++++++++--------- .../fleet/utils/tensor_fusion_helper.py | 7 +- test/dygraph_to_static/CMakeLists.txt | 2 - tools/final_ut_parallel_rule.py | 1 - tools/parallel_UT_rule.py | 2 - 5 files changed, 67 insertions(+), 61 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py 
b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index bb0af75ad6e6f4..e959074361d8df 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -36,6 +36,12 @@ g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1)) g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0)) +g_shard_sort_reduce_root = int( + os.environ.get("FLAGS_shard_sort_reduce_root", 1) +) # it will remove in the future + +g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0)) + if g_shard_norm_align_dp: assert ( @@ -128,18 +134,33 @@ def __init__(self, optimizer, hcg): acc_steps = sharding_configs.accumulate_steps self.comm_overlap = sharding_configs.comm_overlap comm_group = self._hcg.get_sharding_parallel_group() + self._use_fuse_gradients = g_shard_fused_gradient + + assert ( + not self.comm_overlap or self._use_fuse_gradients + ), "If you use comm overlap in sharding, you should set g_shard_fused_gradient to True" + + if self._use_fuse_gradients: + # Build communication buffers once and store them + if not hasattr(self, 'comm_buffers'): + self.comm_buffers = self._build_comm_buffers( + comm_group, + acc_steps, + group_size=128 * 1024 * 1024, + ) + # NOTE(shenliang03): Sort the comm_buffers by dst rank, + # it will improve the performance in reduce communicate. Default + # g_shard_sort_reduce_root is True. + if g_shard_sort_reduce_root: + self.comm_buffers.sort(key=lambda x: x._dst) if not self._pp_overlap and self.comm_overlap: assert ( acc_steps > 0 ), "acc_steps should be larger than 0 when using comm_overlap in sharding" - self.register_reduce_overlap_hook( - comm_group, acc_steps, use_comm=True - ) + self.register_reduce_overlap_hook(use_comm=True) - def _build_comm_buffers( - self, comm_group, acc_steps=1, group_size=256 * 1024 * 1024 - ): + def _build_comm_buffers(self, comm_group, acc_steps, group_size): parameter_list = list(self._parameter_list) if not parameter_list: @@ -180,19 +201,7 @@ def _build_comm_buffers( return comm_buffers - def register_reduce_overlap_hook( - self, - comm_group, - acc_steps, - use_comm=False, - group_size=128 * 1024 * 1024, - ): - # Build communication buffers once and store them - if not hasattr(self, 'comm_buffers'): - self.comm_buffers = self._build_comm_buffers( - comm_group, acc_steps, group_size - ) - + def register_reduce_overlap_hook(self, use_comm): # Register backward hooks for each parameter in the buffer for buffer in self.comm_buffers: for param in buffer._params: @@ -295,37 +304,39 @@ def reduce_gradients(self, parameter_list, hcg): buffer.scale_and_split_grads() return - # TODO merge grad / nrank with dp - with framework.no_grad(): - sharding_nrank = hcg.get_sharding_parallel_group().nranks - for param in parameter_list: - g_var = None - if param.trainable and (param._grad_ivar() is not None): - g_var = param._grad_ivar() - if param.trainable and hasattr(param, "main_grad"): - assert ( - param._grad_ivar() is None - ), "param.grad should be None when using main_grad" - g_var = param.main_grad - if g_var is not None: - g_var.scale_(1.0 / sharding_nrank) - param_rank = self._param2rank[param.name] - if not g_shard_use_reduce: - paddle.distributed.all_reduce( - g_var, - group=hcg.get_sharding_parallel_group(), - sync_op=True, - ) - else: - # TODO(pangengzheng): change to reduce operation when there is no diff in 
calculating global norm values in HybridParallelClipGrad compared to dp. - paddle.distributed.reduce( - g_var, - dst=hcg.get_sharding_parallel_group().ranks[ - param_rank - ], - group=hcg.get_sharding_parallel_group(), - sync_op=True, - ) + if self._use_fuse_gradients: + for buffer in self.comm_buffers: + buffer._comm_grads() + buffer.scale_and_split_grads() + return + + sharding_nrank = hcg.get_sharding_parallel_group().nranks + for param in parameter_list: + g_var = None + if param.trainable and (param._grad_ivar() is not None): + g_var = param._grad_ivar() + if param.trainable and hasattr(param, "main_grad"): + assert ( + param._grad_ivar() is None + ), "param.grad should be None when using main_grad" + g_var = param.main_grad + if g_var is not None: + g_var.scale_(1.0 / sharding_nrank) + param_rank = self._param2rank[param.name] + if not g_shard_use_reduce: + paddle.distributed.all_reduce( + g_var, + group=hcg.get_sharding_parallel_group(), + sync_op=True, + ) + else: + # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp. + paddle.distributed.reduce( + g_var, + dst=hcg.get_sharding_parallel_group().ranks[param_rank], + group=hcg.get_sharding_parallel_group(), + sync_op=True, + ) def _sharding_sync_parameters(self): """ @@ -357,11 +368,6 @@ def minimize( ): # NOTE in dygraph mode, the only different between step and minimize is that minimize # allow user to customize the parameters for updating on each step - - if self.comm_overlap: - for buffer in self.comm_buffers: - buffer.scale_and_split_grads() - assert ( not self._using_param_groups ), "minimize() is not support if using param_groups" diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index 0dc1dcff0e5897..cfd74f3328a38a 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -264,7 +264,12 @@ def __init__( self._comm_group = comm_group self._release_grads = release_grads - self.use_main_grad = hasattr(self._params[0], "main_grad") + self.use_main_grad = all( + hasattr(param, "main_grad") for param in self._params + ) + assert ( + self.use_main_grad + ), "All parameters in FusedCommBuffer must have main_grad." 
self._task = None self._dtype = ( diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index 3b1d7f726b1e83..2a03e4c3d31f60 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -22,7 +22,6 @@ set(TEST_EAGER_OPS test_resnet_pure_fp16 test_se_resnet test_sentiment - test_seq2seq test_tsm test_word2vec test_yolov3 @@ -64,7 +63,6 @@ set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900) set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) -set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150) set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150) set_tests_properties(test_bert PROPERTIES TIMEOUT 120) set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120) diff --git a/tools/final_ut_parallel_rule.py b/tools/final_ut_parallel_rule.py index e0fc86c19a8cc4..eba7240c0d0761 100644 --- a/tools/final_ut_parallel_rule.py +++ b/tools/final_ut_parallel_rule.py @@ -62,7 +62,6 @@ def classify_cases_by_mem(rootPath): case_always_timeout = [ 'test_quant2_int8_resnet50_channelwise_mkldnn', 'test_parallel_dygraph_unused_variables_gloo', - 'test_seq2seq', 'test_pool3d_op', 'test_trilinear_interp_op', 'test_trilinear_interp_v2_op', diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index e9f8f99a0086d1..b8c4537114afdc 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -1636,7 +1636,6 @@ 'test_trt_scale_op', 'test_trt_slice_plugin', 'test_trt_gather_op', - 'test_seq2seq', 'test_bilinear_interp_v2_op', 'test_conv2d_transpose_op', 'test_conv3d_op', @@ -3015,7 +3014,6 @@ 'test_inplace', 'test_einsum', 'test_cosine_similarity_api', - 'test_seq2seq', 'test_word2vec', 'test_scale_mkldnn_op', 'test_asp_pruning_2d_best', From 58a43ce17814b18bf8a771f92226ce4f999bdd04 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 20 Nov 2023 18:37:24 +0800 Subject: [PATCH 07/11] rm seq2seq (#59149) --- test/dygraph_to_static/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index 2a03e4c3d31f60..fe672c986093bc 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -31,6 +31,8 @@ set(TEST_EAGER_OPS test_simnet test_transformer) list(REMOVE_ITEM TEST_OPS test_lac) +list(REMOVE_ITEM TEST_OPS test_seq2seq) + # NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope # will be removed and will cause some random failed in multi-thread. 
if(WITH_PYTHON) From d1f2194ec6f59ae02c81d606aebb949753ed9ba1 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Tue, 21 Nov 2023 10:30:39 +0800 Subject: [PATCH 08/11] [Distributed]Add shardingv2 overlap in non-PP (#59119) * add sharding v2 + overlap * fix shardingv2 * add paddlenlp for sharidng overlap in stage v2 --- .../dygraph_sharding_optimizer.py | 65 ++++++++++++++++--- 1 file changed, 55 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index e959074361d8df..f363ae42f80ced 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -569,21 +569,65 @@ def __init__(self, optimizer, hcg): 'pp_configs' ].release_gradients - self._build_comm_buffers() self._set_inner_opt_attr('_parameter_list', self._local_parameter_list) self._set_inner_opt_attr('_param_groups', self._local_parameter_list) - def _build_comm_buffers(self, group_size=256 * 1024 * 1024): + sharding_configs = strategy.hybrid_configs["sharding_configs"] + acc_steps = sharding_configs.accumulate_steps + comm_group = self._hcg.get_sharding_parallel_group() + + self.comm_overlap = sharding_configs.comm_overlap + + self._build_comm_buffers( + comm_group, acc_steps, group_size=128 * 1024 * 1024 + ) + # NOTE(shenliang03): Sort the comm_buffers by dst rank, + # it will improve the performance in reduce communicate. Default + # g_shard_sort_reduce_root is True. + if g_shard_sort_reduce_root: + self._comm_buffer_list.sort(key=lambda x: x._dst) + + assert ( + not self.comm_overlap or acc_steps > 0 + ), "acc_steps should be larger than 0 when using comm_overlap in sharding" + + assert ( + not self.pp_overlap or not self.comm_overlap + ), "pp_overlap and comm_overlap should not be True at the same time" + + if not self.pp_overlap and self.comm_overlap: + self.register_reduce_overlap_hook(use_comm=True) + + def register_reduce_overlap_hook(self, use_comm): + # Register backward hooks for each parameter in the buffer + for buffer in self._comm_buffer_list: + for param in buffer._params: + # Directly register the hook function with necessary parameters + param._register_backward_hook( + self._create_backward_hook(buffer, param, use_comm) + ) + + def _create_backward_hook(self, buffer, param, use_comm): + """Creates a backward hook function for autograd.""" + + @paddle.autograd.no_grad() + def fused_allreduce(*_): + # Directly add gradient to the buffer + buffer.add_grad(param, use_comm=use_comm) + + return fused_allreduce + + def _build_comm_buffers(self, comm_group, acc_steps, group_size): if self.pp_overlap: return - comm_group = self._hcg.get_sharding_parallel_group() var_groups = assign_group_by_size(self._parameter_list, group_size) for group_idx, parameters in var_groups.items(): buffer = FusedCommBuffer( group_idx, parameters, comm_group, + acc_steps, act=HOOK_ACTION.REDUCE_SCATTER, release_grads=self.pp_release_grads, ) @@ -625,15 +669,16 @@ def filter_parameters(self, parameter_list, hcg): return parameter_list def reduce_gradients(self, parameter_list, hcg): - # TODO merge grad / nrank with dp with framework.no_grad(): - for comm_buffer in self._comm_buffer_list: - if self.pp_release_grads and comm_buffer.grad_storage is None: - for param in comm_buffer.params: - 
comm_buffer._copy_grad_to_buffer(param) + for buffer in self._comm_buffer_list: + if self.pp_release_grads and buffer.grad_storage is None: + for param in buffer.params: + buffer._copy_grad_to_buffer(param) - comm_buffer._comm_grads() - comm_buffer.scale_and_split_grads() + if not self.comm_overlap: + buffer._comm_grads() + + buffer.scale_and_split_grads() def _sharding_sync_parameters(self): """ From 94511148283dea826e2368d078c1f6ed3c77bb33 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Tue, 21 Nov 2023 10:50:10 +0800 Subject: [PATCH 09/11] [Distributed]Fix group size in fusion in sharding v2 (#59184) * fix group size in sharding v2 * fix group size in sharding v2 --- .../dygraph_optimizer/dygraph_sharding_optimizer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index f363ae42f80ced..db827a35ee16c6 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -138,7 +138,7 @@ def __init__(self, optimizer, hcg): assert ( not self.comm_overlap or self._use_fuse_gradients - ), "If you use comm overlap in sharding, you should set g_shard_fused_gradient to True" + ), "If you use comm overlap in sharding, you should set FLAGS_shard_fused_gradient to True" if self._use_fuse_gradients: # Build communication buffers once and store them @@ -578,8 +578,11 @@ def __init__(self, optimizer, hcg): self.comm_overlap = sharding_configs.comm_overlap + # NOTE(shenliang03): `group_size` will affect the result of the parameter fuse, + # which in turn affects save/load. Therefore, it is best not to modify 256MB + # to prevent compatibility issues. self._build_comm_buffers( - comm_group, acc_steps, group_size=128 * 1024 * 1024 + comm_group, acc_steps, group_size=256 * 1024 * 1024 ) # NOTE(shenliang03): Sort the comm_buffers by dst rank, # it will improve the performance in reduce communicate. 
Default From adbbcf5ad33e7f2e2286683e242cce84dc10e4b5 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Wed, 22 Nov 2023 19:45:49 +0800 Subject: [PATCH 10/11] fix bug in sharding v1 (#59267) --- .../dygraph_optimizer/dygraph_sharding_optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index db827a35ee16c6..577a1e04394ef7 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -295,6 +295,7 @@ def _map_param_to_rank(self): mapping[param.name] = rank return mapping + @paddle.autograd.no_grad() def reduce_gradients(self, parameter_list, hcg): if self._pp_overlap: return From c940ea71b5ad216e2395d6b4802a7fec4f9a21c5 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 27 Nov 2023 11:28:22 +0800 Subject: [PATCH 11/11] [Distributed] Add assert for sharding + pp (#59221) * add assert for sharding + pp * add assert for sharding + pp --- .../dygraph_optimizer/dygraph_sharding_optimizer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 577a1e04394ef7..9eb6d6e500c2fd 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -135,11 +135,16 @@ def __init__(self, optimizer, hcg): self.comm_overlap = sharding_configs.comm_overlap comm_group = self._hcg.get_sharding_parallel_group() self._use_fuse_gradients = g_shard_fused_gradient + self._use_pipeline_parallel = strategy.hybrid_configs["pp_degree"] > 1 assert ( not self.comm_overlap or self._use_fuse_gradients ), "If you use comm overlap in sharding, you should set FLAGS_shard_fused_gradient to True" + assert not ( + self._use_pipeline_parallel and self._use_fuse_gradients + ), "You cannot use pipeline parallel and fused gradient at the same time" + if self._use_fuse_gradients: # Build communication buffers once and store them if not hasattr(self, 'comm_buffers'): @@ -599,6 +604,11 @@ def __init__(self, optimizer, hcg): not self.pp_overlap or not self.comm_overlap ), "pp_overlap and comm_overlap should not be True at the same time" + self._use_pipeline_parallel = strategy.hybrid_configs["pp_degree"] > 1 + assert not ( + self._use_pipeline_parallel and self.comm_overlap + ), "You should not use pipeline parallel and comm_overlap at the same time" + if not self.pp_overlap and self.comm_overlap: self.register_reduce_overlap_hook(use_comm=True)
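Taken together, the asserts added across this series define which combinations of the sharding options are legal. The helper below is a hypothetical recap of those rules, not part of any patch: its name and signature are made up, and it collapses the sharding v1 and v2 (split_param) variants into a single check for brevity.

    def check_sharding_overlap_config(
        pp_degree, pp_overlap, comm_overlap, accumulate_steps, use_fuse_gradients
    ):
        # Mirrors the asserts introduced above; raises on an invalid combination.
        if comm_overlap:
            # Sharding v1 overlap is gated behind FLAGS_shard_fused_gradient=1.
            assert use_fuse_gradients, (
                "If you use comm overlap in sharding, you should set "
                "FLAGS_shard_fused_gradient to True"
            )
            # A positive accumulation step count is required.
            assert accumulate_steps > 0, (
                "acc_steps should be larger than 0 when using comm_overlap in sharding"
            )
            # pp's sharding_comm_overlap and sharding's comm_overlap are exclusive.
            assert not pp_overlap, (
                "pp_overlap and comm_overlap should not be True at the same time"
            )
        if pp_degree > 1:
            # Pipeline parallel excludes fused gradients and comm overlap entirely.
            assert not (use_fuse_gradients or comm_overlap), (
                "cannot use pipeline parallel together with fused gradients or comm_overlap"
            )

    # e.g. this passes, while any pp_degree > 1 plus comm_overlap would raise:
    check_sharding_overlap_config(
        pp_degree=1,
        pp_overlap=False,
        comm_overlap=True,
        accumulate_steps=4,
        use_fuse_gradients=True,
    )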