
🔴 🚨 Resizing tokens embeddings: initialize from old embeddings' normal distribution. #33325

Merged: 32 commits, Oct 4, 2024

Commits
25c92e1
Initialize new embeddings from normal distribution
abuelnasr0 Sep 5, 2024
a95639c
Fix typo in comments
abuelnasr0 Sep 5, 2024
d850b99
Fix typo in comments
abuelnasr0 Sep 5, 2024
3f44507
Fix style
abuelnasr0 Sep 5, 2024
5ea5f82
Fix variables naming
abuelnasr0 Sep 5, 2024
d1d81d5
Add tests
abuelnasr0 Sep 5, 2024
f3aaf0a
Fix style
abuelnasr0 Sep 5, 2024
bdef61a
code consistency nit
abuelnasr0 Sep 6, 2024
15a7b5a
Add deepspeed support
abuelnasr0 Sep 6, 2024
6e40b4f
Add deepspeed support
abuelnasr0 Sep 6, 2024
aba7d8c
Convert embeddings weights to float32 before computations
abuelnasr0 Sep 6, 2024
4f1b0fa
Add deepspeed tests
abuelnasr0 Sep 7, 2024
dea8e28
Cover when vocab_size is smaller than embedding_size
abuelnasr0 Sep 8, 2024
84f8cfa
Style fix
abuelnasr0 Sep 8, 2024
2923e85
Add tests for vocab_size smaller than hidden_size
abuelnasr0 Sep 8, 2024
188ba1b
Style fix
abuelnasr0 Sep 8, 2024
22ac85c
Nits in tests
abuelnasr0 Sep 8, 2024
3e42f66
Nits in tests
abuelnasr0 Sep 8, 2024
226f31c
Check for deepspeed before importing it
abuelnasr0 Sep 9, 2024
cef744f
Increase vocab_size for positive definite covariance matrix test
abuelnasr0 Sep 9, 2024
6583cd5
Add warning
abuelnasr0 Sep 15, 2024
7577cd4
Add multivariate_resizing flag and implement resizing for lm_heads
abuelnasr0 Sep 27, 2024
0472bac
Fix typo
abuelnasr0 Sep 27, 2024
fd4ad00
Fix wrong bias indexing
abuelnasr0 Sep 27, 2024
6ff2bca
Fix bias is zero check
abuelnasr0 Sep 27, 2024
12e61c6
remove multivariate_resizing flag from tests
abuelnasr0 Sep 27, 2024
eb80c33
Initialize bias from old bias normal distribution
abuelnasr0 Sep 27, 2024
ef6bdbc
Fixup
abuelnasr0 Sep 27, 2024
5cdce5f
Code usability
abuelnasr0 Oct 1, 2024
f4a9cf4
Use mean_resizing instead of multivariate_resizing
abuelnasr0 Oct 1, 2024
fc436d7
Fix up
abuelnasr0 Oct 1, 2024
8e60a36
Fix comments and docs
abuelnasr0 Oct 3, 2024
158 changes: 147 additions & 11 deletions src/transformers/modeling_utils.py
@@ -2048,7 +2048,10 @@ def _get_no_split_modules(self, device_map: str):
return list(_no_split_modules)

def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> nn.Embedding:
"""
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -2068,11 +2071,19 @@ def resize_token_embeddings(
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
details about this, or help on choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
mean_resizing (`bool`):
Whether to initialize the added embeddings from a multivariate normal distribution whose mean and
covariance are estimated from the old embeddings, or to initialize them from a normal distribution with a mean of zero and a std equal to `config.initializer_range`.

Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
because initializing the new embeddings around the old embeddings' mean keeps the KL divergence between the
next-token distributions before and after the resize small, so the probabilities of the generated tokens are barely affected by the added embeddings.
Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

Return:
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds
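
For orientation, a minimal usage sketch of the new flag might look like the following; the checkpoint name and the added tokens are placeholders, not part of this PR, and only the `mean_resizing` argument is new here.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any causal LM is resized the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Add new tokens, then grow the embedding matrix to match the tokenizer.
tokenizer.add_special_tokens({"additional_special_tokens": ["<tool_call>", "</tool_call>"]})

# mean_resizing=True (the new default) samples the added rows from a multivariate normal
# fitted to the old embeddings; mean_resizing=False keeps the previous behavior of
# initializing them with mean 0 and std `config.initializer_range`.
model.resize_token_embeddings(len(tokenizer), mean_resizing=True)
```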

@@ -2095,9 +2106,11 @@ def resize_token_embeddings(

return model_embeds

def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
new_embeddings = self._get_resized_embeddings(
old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
)
if hasattr(old_embeddings, "_hf_hook"):
hook = old_embeddings._hf_hook
add_hook_to_module(new_embeddings, hook)
@@ -2120,9 +2133,9 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
old_lm_head = self.get_output_embeddings()
if isinstance(old_lm_head, torch.nn.Embedding):
new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens)
new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
else:
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
if hasattr(old_lm_head, "_hf_hook"):
hook = old_lm_head._hf_hook
add_hook_to_module(new_lm_head, hook)
@@ -2137,6 +2150,7 @@ def _get_resized_embeddings(
old_embeddings: nn.Embedding,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> nn.Embedding:
"""
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
@@ -2159,6 +2173,14 @@ def _get_resized_embeddings(
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
details about this, or help on choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
mean_resizing (`bool`):
Whether to initialize the added embeddings from a multivariate normal distribution whose mean and
covariance are estimated from the old embeddings, or to initialize them from a normal distribution with a mean of zero and a std equal to `config.initializer_range`.

Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
because initializing the new embeddings around the old embeddings' mean keeps the KL divergence between the
next-token distributions before and after the resize small, so the probabilities of the generated tokens are barely affected by the added embeddings.
Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html


Return:
@@ -2217,8 +2239,32 @@ def _get_resized_embeddings(
dtype=old_embeddings.weight.dtype,
)

# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
if new_num_tokens > old_num_tokens and not mean_resizing:
# initialize new embeddings (in particular added tokens) with a mean of 0 and a std equal to `config.initializer_range`.
self._init_weights(new_embeddings)

elif new_num_tokens > old_num_tokens and mean_resizing:
# initialize new embeddings (in particular added tokens). The new embeddings will be initialized
# from a multivariate normal distribution that has old embeddings' mean and covariance.
# as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
logger.warning_once(
"The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
"As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
"To disable this, use `mean_resizing=False`"
)

added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
Collaborator:
all of this can be re-used no? As a "self.init_tensor" which checks if deepspeed is available, computes the covariance if not given, uses None otherwise

abuelnasr0 (Contributor, Author), Oct 1, 2024:
I have introduced three functions:

  • self._init_added_embeddings_weights_with_mean()
  • self._init_added_lm_head_weights_with_mean(), which uses self._init_added_embeddings_weights_with_mean()
  • self._init_added_lm_head_bias_with_mean()

This will improve code usability for our case. What do you think? I am open to any other change.

Also, I think that mean_resizing is more user-friendly and explains the whole point of the new resizing technique.

import deepspeed

with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
self._init_added_embeddings_weights_with_mean(
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
)
else:
self._init_added_embeddings_weights_with_mean(
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
)

# Copy token embeddings from the previous weights

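The branch above delegates to `_init_added_embeddings_weights_with_mean`, added further down in this diff. As a standalone illustration of the technique from the linked vocab-expansion article, the following sketch reproduces the idea with made-up shapes; it is not the library code, and `old_weight` stands in for `old_embeddings.weight.data`.

```python
import torch

# Illustrative shapes only.
old_num_tokens, hidden_size, added_num_tokens = 32000, 768, 16
old_weight = torch.randn(old_num_tokens, hidden_size)

w = old_weight.to(torch.float32)
mean = w.mean(dim=0)                          # per-dimension mean of the old embeddings
centered = w - mean
cov = centered.T @ centered / old_num_tokens  # empirical covariance, (hidden_size, hidden_size)
if hidden_size >= old_num_tokens:
    # Fewer rows than dimensions makes the covariance singular; add a small
    # ridge so the multivariate normal is well defined (the PR uses 1e-3 * I).
    cov = cov + 1e-3 * torch.eye(hidden_size)

# Scaling the covariance by 1e-5 keeps the sampled rows close to the mean,
# which is what limits the shift in the next-token distribution.
dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=1e-5 * cov)
new_rows = dist.sample((added_num_tokens,))   # shape: (added_num_tokens, hidden_size)
```
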
@@ -2258,7 +2304,11 @@ def _get_resized_embeddings(
return old_embeddings

def _get_resized_lm_head(
self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
self,
old_lm_head: nn.Linear,
new_num_tokens: Optional[int] = None,
transposed: Optional[bool] = False,
mean_resizing: bool = True,
) -> nn.Linear:
"""
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
@@ -2275,6 +2325,14 @@ def _get_resized_lm_head(
`torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults
to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
vocab_size` else `vocab_size, lm_head_dim`.
mean_resizing (`bool`):
Whether to initialize the added embeddings from a multivariate normal distribution whose mean and
covariance are estimated from the old embeddings, or to initialize them from a normal distribution with a mean of zero and a std equal to `config.initializer_range`.

Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
because initializing the new embeddings around the old embeddings' mean keeps the KL divergence between the
next-token distributions before and after the resize small, so the probabilities of the generated tokens are barely affected by the added embeddings.
Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

Return:
`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
@@ -2321,8 +2379,40 @@ def _get_resized_lm_head(
dtype=old_lm_head.weight.dtype,
)

# initialize new lm head (in particular added tokens)
self._init_weights(new_lm_head)
if new_num_tokens > old_num_tokens and not mean_resizing:
# initialize new lm_head weights (in particular added tokens) with a mean of 0 and a std equal to `config.initializer_range`.
self._init_weights(new_lm_head)

elif new_num_tokens > old_num_tokens and mean_resizing:
# initialize new lm_head weights (in particular added tokens). The new lm_head weights
# will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance.
# as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
logger.warning_once(
"The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
"As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
"To disable this, use `mean_resizing=False`"
)

added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_lm_head.weight]
if has_new_lm_head_bias:
params += [old_lm_head.bias]
with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
self._init_added_lm_head_weights_with_mean(
old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
)
if has_new_lm_head_bias:
self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)

else:
self._init_added_lm_head_weights_with_mean(
old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
)
if has_new_lm_head_bias:
self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)

num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

@@ -2341,6 +2431,52 @@ def _get_resized_lm_head(

return new_lm_head

def _init_added_embeddings_weights_with_mean(
self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
):
old_embeddings_weight = old_embeddings.weight.data.to(torch.float32)
mean_embeddings = torch.mean(old_embeddings_weight, axis=0)
old_centered_embeddings = old_embeddings_weight - mean_embeddings
covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
if old_embedding_dim >= old_num_tokens:
# The covariance matrix must be positive definite. For edge cases, when `vocab_size` is
# smaller than `hidden_size`, the covariance matrix won't be positive definite, so we
# add a scaled identity matrix to the covariance matrix to make it positive definite.
covariance = covariance + torch.eye(old_embedding_dim, device=old_embeddings.weight.device) * 1e-3
distribution = torch.distributions.multivariate_normal.MultivariateNormal(
mean_embeddings, covariance_matrix=1e-5 * covariance
)
new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
sample_shape=(added_num_tokens,)
).to(old_embeddings.weight.dtype)

def _init_added_lm_head_weights_with_mean(
self,
old_lm_head,
new_lm_head,
old_lm_head_dim,
old_num_tokens,
added_num_tokens,
transposed=False,
):
if transposed:
# Transpose to the desired shape for the function.
new_lm_head.weight.data = new_lm_head.weight.data.T

# The same initialization logic as for the embeddings.
self._init_added_embeddings_weights_with_mean(
old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens
)

if transposed:
# Transpose again to the correct shape.
new_lm_head.weight.data = new_lm_head.weight.data.T

def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens):
bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32)
bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32)
new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=bias_std * 1e-5)

def _copy_lm_head_original_to_resized(
self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
):
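
Since the stated goal of mean resizing is to leave the next-token distribution over the existing vocabulary nearly unchanged, a rough sanity check is to compare the model's next-token probabilities before and after resizing. This is an illustrative check under assumptions (placeholder checkpoint, KL divergence restricted to the original vocabulary), not a test from this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def next_token_probs(model, tokenizer, prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]
    return torch.softmax(logits, dim=-1)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = "The capital of France is"

model = AutoModelForCausalLM.from_pretrained("gpt2")
p_before = next_token_probs(model, tokenizer, prompt)

# Grow the vocabulary by 16 dummy rows using the new mean resizing.
model.resize_token_embeddings(len(tokenizer) + 16, mean_resizing=True)
p_after = next_token_probs(model, tokenizer, prompt)

# Approximate KL divergence restricted to the original vocabulary; with
# mean_resizing=True this should stay small, while mean_resizing=False is
# typically noticeably larger.
old_vocab = p_before.shape[-1]
eps = 1e-12
kl = torch.sum(p_before * (torch.log(p_before + eps) - torch.log(p_after[:old_vocab] + eps)))
print(f"KL(before || after), old vocab only: {kl.item():.6f}")
```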