@@ -2476,30 +2476,29 @@ def _initialize_weights(self, module):
     def initialize_weights(self):
         """
         This is equivalent to calling `self.apply(self._initialize_weights)`, but instead of full depth-first recursion,
-        it handles correctly composite models. Indeed, depth-first recursion fails with composite models as it will usually
+        it correctly handles composite models. Indeed, depth-first recursion fails with composite models as it will usually
         initialize the basic blocks (e.g. nn.Linear, nn.Embedding, etc) first, which will cause them to be initialized according
         to the `_init_weights` of the outer-most model instead of the given sub-model.
-        This function first searches for sub-models, initialize them, then initialize only remaining modules.
+        This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
+        module graph along the recursion. It can handle an arbitrary number of sub-models.
         """
-        sub_models = []
-        for module_name, module in self.named_modules():
-            # self is of course not a sub-model
-            if module is self:
-                continue
-            if hasattr(module, "_init_weights"):
-                sub_models.append(module_name)
-
-        # sort according to depth, in order to initialize the dept-most sub-models first (to avoid issue mentionned in docstring)
-        # Note that the ordering of similar depth modules is not important, as they cannot have common modules
-        sub_models = sorted(sub_models, key=lambda x: len(x.split(".")), reverse=True)
-
-        for module_name in sub_models:
-            module = self.get_submodule(module_name)
-            # This will set the `_is_hf_initialized` flag everywhere, making future calls on the same module to be skipped
-            module.apply(module._initialize_weights)
-
-        # Finally, apply it to self as well to finalize missing modules
-        self.apply(self._initialize_weights)
+        if not hasattr(torch.nn.Module, "smart_apply"):
+            # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjusts the function
+            # to apply as we go down the graph
+            def smart_apply(self, fn):
+                for module in self.children():
+                    # We found a sub-model: recursively dispatch its own init function now!
+                    if hasattr(module, "_init_weights"):
+                        module.smart_apply(module._initialize_weights)
+                    else:
+                        module.smart_apply(fn)
+                fn(self)
+                return self
+
+            torch.nn.Module.smart_apply = smart_apply
+
+        # Let the magic happen with this simple call
+        self.smart_apply(self._initialize_weights)
 
     def tie_weights(self):
         """
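For context, here is a minimal, self-contained sketch of the problem this hunk addresses and of the traversal it introduces. It is not taken from the PR: the `InnerModel`/`OuterModel` classes are illustrative, `smart_apply` is written as a free function rather than a monkey-patch on `torch.nn.Module`, and it dispatches `_init_weights` directly instead of the `_initialize_weights` wrapper used in the real code.

```python
import torch.nn as nn


# Hypothetical sub-model with its own init convention (illustrative only).
class InnerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def _init_weights(self, module):
        # Inner convention: constant 1.0 weights for linear layers
        if isinstance(module, nn.Linear):
            nn.init.constant_(module.weight, 1.0)


# Hypothetical composite model wrapping the sub-model (illustrative only).
class OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = InnerModel()
        self.head = nn.Linear(4, 2)

    def _init_weights(self, module):
        # Outer convention: zero weights for linear layers
        if isinstance(module, nn.Linear):
            nn.init.constant_(module.weight, 0.0)


def smart_apply(module, fn):
    # Same traversal as the patched `torch.nn.Module.smart_apply` above, as a free
    # function: as soon as a child defines its own `_init_weights`, switch to it.
    for child in module.children():
        if hasattr(child, "_init_weights"):
            smart_apply(child, child._init_weights)
        else:
            smart_apply(child, fn)
    fn(module)
    return module


model = OuterModel()

# Plain depth-first `apply` runs the outer-most `_init_weights` on every basic block,
# so the inner model's convention is ignored:
model.apply(model._init_weights)
print(model.inner.linear.weight[0, 0].item())  # 0.0

# The dynamic dispatch reaches `inner.linear` through `InnerModel._init_weights` instead:
smart_apply(model, model._init_weights)
print(model.inner.linear.weight[0, 0].item())  # 1.0
```

The actual change attaches `smart_apply` to `torch.nn.Module` (guarded by the `hasattr` check so the patch is only applied once) so that every submodule can recurse through the same method; the sketch above keeps it as a standalone function only for readability.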