Skip to content

Commit c45910c

Browse files
committed
Update modeling_utils.py
1 parent 8672488 commit c45910c

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/transformers/modeling_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,6 +2481,9 @@ def initialize_weights(self):
24812481
to the `_init_weights` of the outer-most model instead of the given sub-model.
24822482
This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
24832483
module graph along the recursion. It can handle an arbitrary number of sub-models.
2484+
2485+
Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
2486+
`torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as `module.weight.data.zero_()`.
24842487
"""
24852488
if not hasattr(torch.nn.Module, "smart_apply"):
24862489
# This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function

0 commit comments

Comments
 (0)