
Commit 8672488

committed
Super elegant and efficient init for submodels
1 parent 795ee07 commit 8672488

File tree

1 file changed: +20 -21 lines changed


src/transformers/modeling_utils.py

Lines changed: 20 additions & 21 deletions
@@ -2476,30 +2476,29 @@ def _initialize_weights(self, module):
     def initialize_weights(self):
         """
         This is equivalent to calling `self.apply(self._initialize_weights)`, but instead of full depth-first recursion,
-        it handles correctly composite models. Indeed, depth-first recursion fails with composite models as it will usually
+        it correctly handles composite models. Indeed, depth-first recursion fails with composite models as it will usually
         initialize the basic blocks (e.g. nn.Linear, nn.Embedding, etc) first, which will cause them to be initialized according
         to the `_init_weights` of the outer-most model instead of the given sub-model.
-        This function first searches for sub-models, initialize them, then initialize only remaining modules.
+        This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
+        module graph along the recursion. It can handle an arbitrary number of sub-models.
         """
-        sub_models = []
-        for module_name, module in self.named_modules():
-            # self is of course not a sub-model
-            if module is self:
-                continue
-            if hasattr(module, "_init_weights"):
-                sub_models.append(module_name)
-
-        # sort according to depth, in order to initialize the dept-most sub-models first (to avoid issue mentionned in docstring)
-        # Note that the ordering of similar depth modules is not important, as they cannot have common modules
-        sub_models = sorted(sub_models, key=lambda x: len(x.split(".")), reverse=True)
-
-        for module_name in sub_models:
-            module = self.get_submodule(module_name)
-            # This will set the `_is_hf_initialized` flag everywhere, making future calls on the same module to be skipped
-            module.apply(module._initialize_weights)
-
-        # Finally, apply it to self as well to finalize missing modules
-        self.apply(self._initialize_weights)
+        if not hasattr(torch.nn.Module, "smart_apply"):
+            # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function
+            # to apply as we go down the graph
+            def smart_apply(self, fn):
+                for module in self.children():
+                    # We found a sub-model: recursively dispatch its own init function now!
+                    if hasattr(module, "_init_weights"):
+                        module.smart_apply(module._initialize_weights)
+                    else:
+                        module.smart_apply(fn)
+                fn(self)
+                return self
+
+            torch.nn.Module.smart_apply = smart_apply
+
+        # Let the magic happen with this simple call
+        self.smart_apply(self._initialize_weights)
 
     def tie_weights(self):
         """
