Fixing bug with param count without embeddings (#12461)
* fixing bug with param count without embeddings

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
TevenLeScao and sgugger authored Jul 1, 2021
1 parent d5b8fe3 commit 7f0027d
Showing 1 changed file with 10 additions and 5 deletions.
src/transformers/modeling_utils.py (15 changes: 10 additions & 5 deletions)
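The change below replaces a filter that never actually excluded anything: `Module.parameters()` yields `nn.Parameter` objects rather than the modules that own them, so the old `isinstance(x, nn.Embedding)` check could never be true, and the filter was only applied at all when `only_trainable=True`. A minimal sketch (not part of the commit) illustrating the first point:

```python
# Minimal illustration, not from the commit: Module.parameters() yields
# nn.Parameter objects, never the nn.Embedding modules themselves, so an
# isinstance(..., nn.Embedding) check against them can never succeed.
import torch.nn as nn

embedding = nn.Embedding(10, 4)
print(any(isinstance(p, nn.Embedding) for p in embedding.parameters()))  # False
print(all(isinstance(p, nn.Parameter) for p in embedding.parameters()))  # True
```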
```diff
@@ -352,11 +352,16 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
             :obj:`int`: The number of parameters.
         """
 
-        def parameter_filter(x):
-            return (x.requires_grad or not only_trainable) and not (isinstance(x, nn.Embedding) and exclude_embeddings)
-
-        params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
-        return sum(p.numel() for p in params)
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
+            ]
+            non_embedding_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+        else:
+            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
 
     def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
         """
```
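With the fix, `exclude_embeddings=True` subtracts exactly the weights held by `nn.Embedding` modules. A rough usage sketch, not part of the commit; the checkpoint name is only an example:

```python
# Rough usage sketch; "bert-base-uncased" is just an example checkpoint.
import torch.nn as nn
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

total = model.num_parameters()
non_embedding = model.num_parameters(exclude_embeddings=True)

# After the fix, the difference should match the parameters stored in
# nn.Embedding modules (word, position and token-type embeddings for BERT).
embedding_weights = sum(
    module.weight.numel() for module in model.modules() if isinstance(module, nn.Embedding)
)
print(total - non_embedding == embedding_weights)  # expected: True
```

The fix identifies embedding weights by name via `named_modules()`, so the exclusion now applies whether or not `only_trainable` is set.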
