From 68e2fd19c07af558f88d0b2db6d3992c3ed3f4ed Mon Sep 17 00:00:00 2001 From: Mingjie Li Date: Tue, 8 Apr 2025 12:37:32 +0800 Subject: [PATCH] Support complicated use cases with TiedLayerSpec Extend the builtin `getattr` to a recursive version `PipelineModule._recursive_getattr` for nested tied weights, e.g., "linear.weight". Meanwhile, sort tie_keys in `PipelineModule._index_tied_modules` to avoid hanging. Signed-off-by: Mingjie Li --- deepspeed/runtime/pipe/module.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 49fa2807c355..2bc0c37bffb7 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -443,18 +443,26 @@ def _partition_layers(self, method='uniform'): self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1]) + @staticmethod + def _recursive_getattr(module: torch.nn.Module, attr_name: str) -> torch.Tensor: + '''Allow getting an attribute like "linear.weight"''' + weight = module + for item in attr_name.split("."): + weight = getattr(weight, item) + return weight + def allreduce_tied_weight_gradients(self): '''All reduce the gradients of the tied weights between tied stages''' for key, comm in self.tied_comms.items(): for attr_name in comm['weight_attr']: - weight = getattr(self.tied_modules[key], attr_name) + weight = self._recursive_getattr(self.tied_modules[key], attr_name) dist.all_reduce(weight.grad, group=comm['group']) def get_tied_weights_and_groups(self): weight_group_list = [] for key, comm in self.tied_comms.items(): for attr_name in comm['weight_attr']: - weight = getattr(self.tied_modules[key], attr_name) + weight = self._recursive_getattr(self.tied_modules[key], attr_name) weight_group_list.append((weight, comm['group'])) return weight_group_list @@ -462,7 +470,7 @@ def _synchronize_tied_weights(self): for key, comm in self.tied_comms.items(): for attr_name in comm['weight_attr']: dist.broadcast( - getattr(comm['module'], attr_name), + self._recursive_getattr(comm['module'], attr_name), src=min(comm['ranks']), group=comm['group'], ) @@ -475,7 +483,10 @@ def _index_tied_modules(self): specs = self._layer_specs tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec)) - for key in tie_keys: + # Since Python 3.7, "Dictionary order is guaranteed to be insertion order." + # Sort tie_keys here so that orders of self.tied_comms.items() are consistent + # among ranks. + for key in sorted(tie_keys): # Find the layers that the tied module appears in tied_layers = [] for idx, layer in enumerate(specs):