diff --git a/docs/source/conceptual_guides/lora.mdx b/docs/source/conceptual_guides/lora.mdx
index 5b18303b9e..9156375b85 100644
--- a/docs/source/conceptual_guides/lora.mdx
+++ b/docs/source/conceptual_guides/lora.mdx
@@ -49,6 +49,8 @@ As with other methods supported by PEFT, to fine-tune a model using LoRA, you ne
 - `alpha`: LoRA scaling factor.
 - `bias`: Specifies if the `bias` parameters should be trained. Can be `'none'`, `'all'` or `'lora_only'`.
 - `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task.
+- `layers_to_transform`: List of layer indexes to be transformed by LoRA. If not specified, all layers matching `target_modules` are transformed.
+- `layers_pattern`: Name of the layer container used to match layer indexes when `layers_to_transform` is specified. By default, `PeftModel` looks for common layer patterns (`layers`, `h`, `blocks`, etc.); set it explicitly for exotic and custom models.
 
 ## LoRA examples
 
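To make the two new options concrete, here is a minimal usage sketch that is not part of the diff: the OPT checkpoint, the `q_proj`/`v_proj` target modules, and the explicit `layers_pattern` value are illustrative assumptions only.

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# Hypothetical example: apply LoRA only to the attention projections of
# decoder layers 0 and 1. The checkpoint and target_modules below are
# assumptions chosen for the sketch, not values taken from this change.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    layers_to_transform=[0, 1],
    layers_pattern="layers",  # optional here: "layers" is already a common pattern
)

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```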
diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index e3a2e6d4e7..5540dbbe01 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -26,6 +26,7 @@
 from ..import_utils import is_bnb_4bit_available, is_bnb_available
 from ..utils import (
+    COMMON_LAYERS_PATTERN,
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
     ModulesToSaveWrapper,
     PeftConfig,
@@ -55,6 +56,13 @@ class LoraConfig(PeftConfig):
         bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
         modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
             and saved in the final checkpoint.
+        layers_to_transform (`Union[List[int], int]`):
+            The layer indexes to transform. If this argument is specified, the LoRA transformations will only be
+            applied to the layer indexes in this list. If a single integer is passed, only the layer at this index
+            will be transformed.
+        layers_pattern (`str`):
+            The layer pattern name, used only if `layers_to_transform` is not `None` and the layer pattern is not
+            one of the common layer patterns.
     """
 
     r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -84,6 +92,18 @@ class LoraConfig(PeftConfig):
         default=True,
         metadata={"help": "Whether to initialize the weights of the Lora layers."},
     )
+    layers_to_transform: Optional[Union[List[int], int]] = field(
+        default=None,
+        metadata={
+            "help": "The layer indexes to transform. If this argument is specified, PEFT will transform only the layer indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
+        },
+    )
+    layers_pattern: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The layer pattern name, used only if `layers_to_transform` is not `None` and the layer pattern is not one of the common layer patterns."
+        },
+    )
 
     def __post_init__(self):
         self.peft_type = PeftType.LORA
@@ -185,11 +205,32 @@ def _find_and_replace(self, adapter_name):
             "init_lora_weights": lora_config.init_lora_weights,
         }
         key_list = [key for key, _ in self.model.named_modules()]
+        is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None
+        layer_indexing_pattern = getattr(lora_config, "layers_pattern", None)
+
         for key in key_list:
             if isinstance(lora_config.target_modules, str):
                 target_module_found = re.fullmatch(lora_config.target_modules, key)
             else:
                 target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
+
+            if is_using_layer_indexes and target_module_found:
+                layers_pattern = COMMON_LAYERS_PATTERN if layer_indexing_pattern is None else layer_indexing_pattern
+                layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern
+
+                for pattern in layers_pattern:
+                    layer_index = re.match(rf".*.{pattern}\.(\d+)\.*", key)
+                    if layer_index is not None:
+                        layer_index = int(layer_index.group(1))
+                        if isinstance(lora_config.layers_to_transform, int):
+                            target_module_found = layer_index == lora_config.layers_to_transform
+                        else:
+                            target_module_found = layer_index in lora_config.layers_to_transform
+
+                        break
+                    else:
+                        target_module_found = False
+
             if target_module_found:
                 if not is_target_modules_in_base_model:
                     is_target_modules_in_base_model = True
diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py
index 2057c51838..8bc1937a17 100644
--- a/src/peft/utils/__init__.py
+++ b/src/peft/utils/__init__.py
@@ -22,6 +22,7 @@
     TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
     TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
+    COMMON_LAYERS_PATTERN,
     CONFIG_NAME,
     WEIGHTS_NAME,
     _set_trainable,
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index e9b36e3f91..f8db2b8a94 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -213,6 +213,8 @@ def transpose(weight, fan_in_fan_out):
     "chatglm": ["query_key_value"],
 }
 
+COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]
+
 TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
     "t5": ["q", "k", "v", "o", "wi", "wo"],
     "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index f52eaf7afb..f95263078c 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -92,6 +92,10 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs
     def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs):
         self._test_training(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_training_layer_indexing(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
         self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
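As a sanity check on the matching logic added to `_find_and_replace`, the standalone sketch below runs the same index-extraction pattern against a few made-up module names; it only mirrors the logic and is not code from this change.

```python
import re

# COMMON_LAYERS_PATTERN is repeated here so the snippet is self-contained;
# layers_to_transform=[0] mimics a config that keeps only the first block.
COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]
layers_to_transform = [0]

keys = [
    "model.decoder.layers.0.self_attn.q_proj",  # OPT-style name, index 0 -> kept
    "model.decoder.layers.3.self_attn.q_proj",  # index 3 -> skipped
    "transformer.h.0.attn.c_attn",              # GPT-2-style name, matched via "h"
]

for key in keys:
    target_module_found = False
    for pattern in COMMON_LAYERS_PATTERN:
        match = re.match(rf".*.{pattern}\.(\d+)\.*", key)
        if match is not None:
            target_module_found = int(match.group(1)) in layers_to_transform
            break
    print(f"{key} -> {target_module_found}")
```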
diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index bb761ddd4d..b0ce467abc 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -95,6 +95,10 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs
     def test_training_encoder_decoders(self, test_name, model_id, config_cls, config_kwargs):
         self._test_training(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_training_encoder_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_training_layer_indexing(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_training_encoder_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
         self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
diff --git a/tests/testing_common.py b/tests/testing_common.py
index b8cbf59494..96ae2ae183 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -345,6 +345,60 @@ def _test_training(self, model_id, config_cls, config_kwargs):
             else:
                 self.assertIsNone(param.grad)
 
+    def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
+        if config_cls not in (LoraConfig,):
+            return
+
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            layers_to_transform=[0],
+            **config_kwargs,
+        )
+        model = self.transformers_class.from_pretrained(model_id)
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        inputs = self.prepare_inputs_for_testing()
+
+        # check if `training` works
+        output = model(**inputs)[0]
+        logits = output[0]
+
+        loss = output.sum()
+        loss.backward()
+
+        nb_trainable = 0
+
+        for n, param in model.named_parameters():
+            if "lora" in n:
+                self.assertIsNotNone(param.grad)
+                nb_trainable += 1
+            else:
+                self.assertIsNone(param.grad)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+
+            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
+            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
+
+            logits_from_pretrained = model_from_pretrained(**inputs)[0][0]
+            self.assertTrue(torch.allclose(logits, logits_from_pretrained, atol=1e-4, rtol=1e-4))
+
+        model = self.transformers_class.from_pretrained(model_id)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        nb_trainable_all = 0
+
+        for n, param in model.named_parameters():
+            if "lora" in n:
+                nb_trainable_all += 1
+
+        self.assertLess(nb_trainable, nb_trainable_all)
+
     def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
         if config_cls not in (LoraConfig,):
             return
@@ -367,6 +421,7 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
 
         # check if `training` works
         output = model(**inputs)[0]
+
         loss = output.sum()
         loss.backward()
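The check performed by the new `_test_training_layer_indexing` helper can also be reproduced outside the test suite. The sketch below assumes a hypothetical tiny GPT-2 checkpoint and simply counts LoRA parameter tensors with and without `layers_to_transform`, expecting strictly fewer of them when the indexes are restricted.

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model


def count_lora_tensors(model):
    # Count parameter tensors injected by LoRA, mirroring the loop over
    # named_parameters() in _test_training_layer_indexing.
    return sum(1 for name, _ in model.named_parameters() if "lora" in name)


model_id = "hf-internal-testing/tiny-random-gpt2"  # assumption: any tiny causal LM works

restricted = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_id),
    LoraConfig(target_modules=["c_attn"], layers_to_transform=[0]),
)
unrestricted = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_id),
    LoraConfig(target_modules=["c_attn"]),
)

# Only the first transformer block carries LoRA weights in the restricted model.
assert count_lora_tensors(restricted) < count_lora_tensors(unrestricted)
```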