Trainable Tokens: Support for Weight Tying #2399

Open
wants to merge 12 commits into `main`
2 changes: 2 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -315,6 +315,8 @@ peft_model = get_peft_model(base_model, lora_config)
The token weights are part of your adapter state dict and saved alongside the LoRA weights.
If we had used full fine-tuning with `modules_to_save=['embed_tokens']`, we would have stored the full embedding matrix in the checkpoint, leading to a much bigger file.

To give an indication of how much VRAM can be saved, a rudimentary comparison of the above example was made between fully training the embedding matrix (`modules_to_save=["embed_tokens"]`), using a LoRA for the embedding matrix (`target_modules=[..., "embed_tokens"]`, rank 32) and trainable tokens (`trainable_token_indices=[...]`, 6 tokens). Trainable tokens used about as much VRAM as LoRA (15,562MB vs. 15,581MB) while being specific to the targeted tokens, and saved ~1GB of VRAM over fully training the embedding matrix.
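
For reference, the three setups compared above could be configured roughly as follows. This is only a sketch: the attention module names, the rank and the token indices are placeholders, not the exact benchmark configuration.

```python
from peft import LoraConfig

# 1) fully fine-tune the embedding matrix alongside LoRA
config_full_embedding = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # placeholder attention modules
    modules_to_save=["embed_tokens"],
)

# 2) put a LoRA (rank 32) on the embedding matrix as well
config_lora_embedding = LoraConfig(
    target_modules=["q_proj", "v_proj", "embed_tokens"],
    r=32,
)

# 3) train only the listed token indices of the embedding matrix
config_trainable_tokens = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices=[32000, 32001, 32002, 32003, 32004, 32005],  # 6 placeholder indices
)
```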


## Merge LoRA weights into the base model

7 changes: 4 additions & 3 deletions docs/source/package_reference/trainable_tokens.md
@@ -24,9 +24,10 @@ The method only targets specific tokens and selectively trains the token indices
required RAM will be lower and disk memory is also significantly lower than storing the full fine-tuned embedding matrix.

Some preliminary benchmarks acquired with [this script](https://github.com/huggingface/peft/blob/main/scripts/train_memory.py)
suggest that for `gemma-2-2b` (which has a rather large embedding matrix) you can save 4.8GiB VRAM with Trainable Tokens
over fully fine-tuning the embedding matrix. While LoRA will use even less memory (-6.3GiB total over fine-tuning) it might also target
tokens you don't want to be changed. With less extreme embedding matrixes the difference might come out shorter as well.
suggest that for `gemma-2-2b` (which has a rather large embedding matrix) you can save ~4 GiB VRAM with Trainable Tokens
over fully fine-tuning the embedding matrix. While LoRA will use comparable amounts of VRAM it might also target
tokens you don't want to be changed. Note that these are just indications and varying embedding matrix sizes might skew
these numbers a bit.

Note that this method does not add tokens for you; you have to add tokens to the tokenizer yourself and resize the
embedding matrix of the model accordingly. This method will only re-train the embeddings for the tokens you specify.
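
As a rough illustration of that workflow (not part of this PR; the model name and the added token are placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import TrainableTokensConfig, get_peft_model

model_id = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# add the new token yourself and resize the embedding matrix accordingly
tokenizer.add_tokens(["<|metadata|>"])
model.resize_token_embeddings(len(tokenizer))

# only the embedding rows of the specified token indices will be re-trained
token_indices = tokenizer.convert_tokens_to_ids(["<|metadata|>"])
peft_config = TrainableTokensConfig(token_indices=token_indices)
peft_model = get_peft_model(model, peft_config)
```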
31 changes: 23 additions & 8 deletions src/peft/peft_model.py
@@ -53,6 +53,7 @@
PeftType,
TaskType,
_get_batch_size,
_get_input_embeddings_name,
_prepare_prompt_learning_config,
_set_adapter,
_set_trainable,
@@ -958,7 +959,8 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
if isinstance(peft_config.trainable_token_indices, dict):
target_layers = peft_config.trainable_token_indices
else:
target_layers = {"embed_tokens": peft_config.trainable_token_indices}
layer_name = _get_input_embeddings_name(self.model) or "embed_tokens"
target_layers = {layer_name: peft_config.trainable_token_indices}

if self.modules_to_save:
for target_layer in target_layers:
@@ -973,7 +975,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
# `ModulesToSaveWrapper`.

for target_layer, token_indices in target_layers.items():
new_training_modules = _set_trainable(
_set_trainable(
self,
adapter_name,
module_names=[target_layer],
@@ -982,14 +984,27 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
token_indices=token_indices,
)

# Handle weight-tying of output and input embeddings. Currently this only consists of failing.
# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
model_config = BaseTuner.get_model_config(self)
if model_config.get("tie_word_embeddings", False) and isinstance(
self.model.get_input_embeddings(), TrainableTokensWrapper
if (
model_config.get("tie_word_embeddings", False)
and self.model._tied_weights_keys is not None
> Member: How about a comment why this check is required?

and isinstance(self.model.get_input_embeddings(), TrainableTokensWrapper)
):
raise ValueError(
"The model uses weight-tying which is currently not supported with `trainable_token_indices`. "
"You can try disabling weight-tying but you must expect an increased memory usage."
# the embedding layer is modified and we want weight tying.
module_keys = [".".join(n.split(".")[:-1]) for n in self.model._tied_weights_keys]

token_adapter = self.model.get_input_embeddings().token_adapter
_set_trainable(
self,
adapter_name,
module_names=module_keys,
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
token_indices=token_adapter.token_indices[adapter_name],
tied_adapter=self.model.get_input_embeddings().token_adapter,
)

def get_layer_status(self) -> list[TunerLayerStatus]:
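For context on the tied-weights handling above, here is a small standalone illustration (not part of the PR) of what weight tying looks like in a typical transformers causal LM; `gpt2` is simply an example of a model that ties its input embedding and LM head:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # gpt2 sets tie_word_embeddings=True

embed = model.get_input_embeddings()
lm_head = model.get_output_embeddings()

# tied weights: both modules point at the very same parameter tensor
assert embed.weight.data_ptr() == lm_head.weight.data_ptr()

# consequently, a delta applied to selected embedding rows must also be visible through
# the LM head, which is why the code above attaches the same token adapter (shared deltas
# via `tied_adapter=...`) to every module listed in `_tied_weights_keys`
```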
21 changes: 11 additions & 10 deletions src/peft/tuners/lora/config.py
@@ -275,12 +275,12 @@ class LoraConfig(PeftConfig):
The core module from Megatron to use, defaults to `"megatron.core"`.
trainable_token_indices (`Optional[Union[List[int], dict[str, List[int]]]]`)
Lets you specify which token indices to selectively fine-tune without requiring to re-train the whole
embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list of indices
which will then target the `embed_tokens` layer, or, if your model is using a different layer for
embedding, you can specify a dictionary where the key is the name of the embedding module and the values
are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note that training with FSDP/DeepSpeed
might not yet be fully supported with this option enabled. Also note that models using weight-tying are
currently not supported.
embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices in two ways.
Either you specify a list of indices which will then target the model's input embedding layer (or, if not
found, `embed_tokens`). Alternatively, you can specify a dictionary where the key is the name of the
embedding module and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note
that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. Also note that
models using weight-tying are currently not supported.
> Member: Adjust/delete?

loftq_config (`Optional[LoftQConfig]`):
The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
@@ -444,10 +444,11 @@ class LoraConfig(PeftConfig):
metadata={
"help": (
"Lets you specify which token indices to selectively fine-tune without requiring to re-train the "
"whole embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list "
"of indices which will then target the `embed_tokens` layer, or, if your model is using a different "
"layer for embedding, you can specify a dictionary where the key is the name of the embedding module "
"and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. "
"whole embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices "
"in two ways. Either you specify a list of indices which will then target the model's input embedding "
"layer (or, if not found, `embed_tokens`). Alternatively, you can specify a dictionary where the key "
"is the name of the embedding module and the values are the list of token indices, e.g. "
"`{'embed_tokens': [0, 1, ...]}`. "
"Note that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. "
"Also note that models using weight-tying are currently not supported."
)
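To make the two accepted forms of `trainable_token_indices` concrete, a short sketch (token indices and module name are placeholders):

```python
from peft import LoraConfig

# 1) a plain list targets the model's input embedding layer (falling back to `embed_tokens`)
config_list = LoraConfig(trainable_token_indices=[32000, 32001, 32002])

# 2) a dict names the embedding module(s) to target explicitly
config_dict = LoraConfig(trainable_token_indices={"embed_tokens": [32000, 32001, 32002]})
```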
14 changes: 8 additions & 6 deletions src/peft/tuners/trainable_tokens/config.py
@@ -40,9 +40,10 @@ class TrainableTokensConfig(PeftConfig):
token with a tokenizer, you can tokenize the string and look at the returned `input_ids`. The closer the
amount of indices is to the total amount of tokens, the less efficient this method gets.
target_modules (`Optional[Union[list[str], str]]`):
List of module names or regex expression of the module names to replace with our `TrainableTokensLayer`.
This is by default the `embed_tokens` layer. But could be multiple embedding-like layers, such as
`embedding`, `encoder.embeddings` or `decoder.embeddings`.
List of module names or regex expression of the module names to replace with our `TrainableTokensLayer`. If
not defined, it will attempt to get the model's input embedding layer if the model has a
`get_input_embeddings` method (transformer models usually do); if that fails, the default is 'embed_tokens'.
Other example targets are `embedding`, `encoder.embeddings` or `decoder.embeddings`.
init_weights (`bool`):
By default the new token weights are initialized to be the same as the respective token embeddings. This
makes TrainableTokens a no-op when not trained. If set to `False` the weights will be random values. Do not
Expand All @@ -61,12 +62,13 @@ class TrainableTokensConfig(PeftConfig):
},
)
target_modules: Optional[Union[list[str], str]] = field(
default_factory=lambda: ["embed_tokens"],
default=None,
metadata={
"help": (
"List of module names or regex expression of the module names to replace with our "
"`TrainableTokensLayer`. This is by default the `embed_tokens` layer. "
"But could be multiple embedding-like layers, such as `embedding`, `encoder.embeddings` or "
"`TrainableTokensLayer`. If not defined, it will default to the model's input embedding layer if "
"the model has a `get_input_embeddings` method (transformer models usually do), if that fails the "
"default is 'embed_tokens'. Other example targets could be `embedding`, `encoder.embeddings` or "
"`decoder.embeddings`."
),
},
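A brief sketch of the new default behaviour described in the docstring above (token indices and module names are illustrative):

```python
from peft import TrainableTokensConfig

# target_modules=None (the new default): the embedding layer is resolved via the model's
# get_input_embeddings() method, with 'embed_tokens' as the fallback name
config_default = TrainableTokensConfig(token_indices=[0, 1, 2])

# explicit targeting still works, e.g. for models with differently named embedding layers
config_explicit = TrainableTokensConfig(
    token_indices=[0, 1, 2],
    target_modules=["decoder.embeddings"],
)
```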
73 changes: 54 additions & 19 deletions src/peft/tuners/trainable_tokens/layer.py
@@ -37,24 +37,40 @@ def __init__(
base_layer: nn.Module,
adapter_name: str,
token_indices: list[int],
tied_adapter: Optional[TrainableTokensLayer] = None,
**kwargs,
) -> None:
super().__init__()

self.base_layer = base_layer
self._active_adapter = adapter_name
self.token_indices = {}
self.kwargs = kwargs

self.tied_adapter = tied_adapter

# we store the updated weights of particular tokens and their originals. we assume
# that the count of new tokens is far smaller than the number of total tokens.
self.trainable_tokens_delta = nn.ParameterDict({})
self.trainable_tokens_original = BufferDict({})
#
# In case we have weight tying with another token adapter, we'll have no actual
# references on our own but use everything from the tied adapter.
if not self.tied_adapter:
self.trainable_tokens_delta = nn.ParameterDict({})
self.trainable_tokens_original = BufferDict({})
self.token_indices = {}
else:
self.trainable_tokens_delta = self.tied_adapter.trainable_tokens_delta
> Member: IIRC, you mentioned that stuff like model.state_dict() already correctly works for this case thanks to transformers. Could you add a comment for that? When the model is not a transformers model, would we get the same param twice?

self.trainable_tokens_original = self.tied_adapter.trainable_tokens_original
self.token_indices = self.tied_adapter.token_indices

# Mark the weight as unmerged
self.merged_adapters = []

def update_layer(self, adapter_name, **kwargs):
if kwargs.get("tied_adapter", None):
# in this case we don't have any say, we're just following whatever the tied
> Member — suggested change:
> - # in this case we don't have any say, we're just following whatever the tied
> + # in this case we don't have any because we're just following whatever the tied
> ?

# adapter does, so we'll just return here.
return

self.token_indices[adapter_name] = kwargs["token_indices"]
init_weights = kwargs.get("init_weights", True)

@@ -130,6 +146,16 @@ def unmerge(self) -> None:
originals = self.trainable_tokens_original[adapter_name].to(self.base_layer.weight)
self.base_layer.weight.data.index_copy_(dim=0, index=index, source=originals)

def get_merged_weights(self, active_adapters):
W = self.base_layer.weight

for adapter_name in active_adapters:
index = torch.tensor(self.token_indices[adapter_name]).to(W.device)
deltas = self.trainable_tokens_delta[adapter_name].to(W)
W = W.index_copy(dim=0, index=index, source=deltas)

return W

def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters or not active_adapters:
if self.merged:
@@ -140,22 +166,31 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
else:
self._check_overlapping_tokens(active_adapters)

W = self.base_layer.weight

for adapter_name in active_adapters:
index = torch.tensor(self.token_indices[adapter_name]).to(W.device)
deltas = self.trainable_tokens_delta[adapter_name].to(W)
W = W.index_copy(dim=0, index=index, source=deltas)

result = F.embedding(
input=x,
weight=W,
padding_idx=self.base_layer.padding_idx,
max_norm=self.base_layer.max_norm,
norm_type=self.base_layer.norm_type,
scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
W = self.get_merged_weights(active_adapters)

# Normally it should be very clear that we're wrapping Embedding layers but there are cases, such as
# tying weights with an LM head where the layer we wrap is a Linear layer. Therefore we must choose
# accordingly.
if isinstance(self.base_layer, torch.nn.Embedding):
result = F.embedding(
input=x,
weight=W,
padding_idx=self.base_layer.padding_idx,
max_norm=self.base_layer.max_norm,
norm_type=self.base_layer.norm_type,
scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
elif isinstance(self.base_layer, torch.nn.Linear):
> Member: This would not necessarily work with quantized models, right? I wonder if we can find a more robust way of handling this, but I'm not sure how exactly.

# Probably a tied adapter that wraps an LM head.
result = F.linear(
input=x,
weight=W,
)
else:
raise ValueError(
"TrainableTokensLayer wraps an unknown layer type, maybe you are targeting the wrong layer?"
)

return result

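As a side note on the mechanics used by `get_merged_weights` above: the out-of-place `Tensor.index_copy` leaves the base embedding weight untouched and returns a new tensor in which only the selected rows are replaced. A tiny standalone sketch with toy sizes and made-up values:

```python
import torch

W = torch.zeros(6, 4)                 # toy "embedding matrix": 6 tokens, dim 4
token_indices = torch.tensor([1, 3])  # rows owned by the token adapter
deltas = torch.ones(2, 4)             # trained replacement rows

# out-of-place: W stays unchanged, W_merged has rows 1 and 3 replaced by the deltas
W_merged = W.index_copy(dim=0, index=token_indices, source=deltas)

assert torch.all(W[1] == 0) and torch.all(W_merged[1] == 1)
```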