Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions deepspeed/runtime/pipe/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,26 +443,34 @@ def _partition_layers(self, method='uniform'):

self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])

@staticmethod
def _recursive_getattr(module: torch.nn.Module, attr_name: str) -> torch.Tensor:
'''Allow getting an attribute like "linear.weight"'''
weight = module
for item in attr_name.split("."):
weight = getattr(weight, item)
return weight

def allreduce_tied_weight_gradients(self):
'''All reduce the gradients of the tied weights between tied stages'''
for key, comm in self.tied_comms.items():
for attr_name in comm['weight_attr']:
weight = getattr(self.tied_modules[key], attr_name)
weight = self._recursive_getattr(self.tied_modules[key], attr_name)
dist.all_reduce(weight.grad, group=comm['group'])

def get_tied_weights_and_groups(self):
weight_group_list = []
for key, comm in self.tied_comms.items():
for attr_name in comm['weight_attr']:
weight = getattr(self.tied_modules[key], attr_name)
weight = self._recursive_getattr(self.tied_modules[key], attr_name)
weight_group_list.append((weight, comm['group']))
return weight_group_list

def _synchronize_tied_weights(self):
for key, comm in self.tied_comms.items():
for attr_name in comm['weight_attr']:
dist.broadcast(
getattr(comm['module'], attr_name),
self._recursive_getattr(comm['module'], attr_name),
src=min(comm['ranks']),
group=comm['group'],
)
Expand All @@ -475,7 +483,10 @@ def _index_tied_modules(self):

specs = self._layer_specs
tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))
for key in tie_keys:
# Since Python 3.7, "Dictionary order is guaranteed to be insertion order."
# Sort tie_keys here so that orders of self.tied_comms.items() are consistent
# among ranks.
for key in sorted(tie_keys):
# Find the layers that the tied module appears in
tied_layers = []
for idx, layer in enumerate(specs):
Expand Down