[LoRA] Allow applying LoRA at different stages #429

Merged: 7 commits, Jun 1, 2023
2 changes: 2 additions & 0 deletions docs/source/conceptual_guides/lora.mdx
@@ -49,6 +49,8 @@ As with other methods supported by PEFT, to fine-tune a model using LoRA, you ne
- `alpha`: LoRA scaling factor.
- `bias`: Specifies if the `bias` parameters should be trained. Can be `'none'`, `'all'` or `'lora_only'`.
- `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task.
- `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed.
Member: Nice, this is super clean! ✨

Should we also add these two parameters to the docstring here?

Contributor (Author): Awesome! Yes, will add it now.

Contributor (Author): Added it in eb16fc3!

Member: Very nice, thank you!

- `layers_pattern`: Pattern to match layer names in `target_modules`, used only if `layers_to_transform` is specified. By default, `PeftModel` looks for common layer patterns (`layers`, `h`, `blocks`, etc.); set this for exotic or custom models.
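
The snippet below is a minimal usage sketch of the two new options; it is not part of the diff above, and the base model, target modules, and hyperparameters are illustrative assumptions.

```python
# Sketch: apply LoRA only to the first two decoder blocks of an OPT model.
# The model choice and target modules are examples, not prescribed by this PR.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    layers_to_transform=[0, 1],           # inject LoRA only into layer indexes 0 and 1
    layers_pattern="layers",              # OPT nests its blocks under `model.decoder.layers`
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
model.print_trainable_parameters()  # fewer trainable parameters than without layer indexing
```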

## LoRA examples

41 changes: 41 additions & 0 deletions src/peft/tuners/lora.py
@@ -26,6 +26,7 @@

from ..import_utils import is_bnb_4bit_available, is_bnb_available
from ..utils import (
COMMON_LAYERS_PATTERN,
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
PeftConfig,
@@ -55,6 +56,13 @@ class LoraConfig(PeftConfig):
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
and saved in the final checkpoint.
layers_to_transform (`Union[List[int],int]`):
The layer indexes to transform. If this argument is specified, the LoRA transformations are applied only
to the layer indexes listed here. If a single integer is passed, the LoRA transformations are applied only
to the layer at that index.
layers_pattern (`str`):
The layer pattern name, used only if `layers_to_transform` is not `None` and the layer pattern is not in
the common layers pattern.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -84,6 +92,18 @@ class LoraConfig(PeftConfig):
default=True,
metadata={"help": "Whether to initialize the weights of the Lora layers."},
)
layers_to_transform: Optional[Union[List, int]] = field(
default=None,
metadata={
"help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
},
)
layers_pattern: Optional[str] = field(
default=None,
metadata={
"help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
},
)

def __post_init__(self):
self.peft_type = PeftType.LORA
@@ -185,11 +205,32 @@ def _find_and_replace(self, adapter_name):
"init_lora_weights": lora_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None
layer_indexing_pattern = getattr(lora_config, "layers_pattern", None)

for key in key_list:
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)

if is_using_layer_indexes and target_module_found:
layers_pattern = COMMON_LAYERS_PATTERN if layer_indexing_pattern is None else layer_indexing_pattern
layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern

for pattern in layers_pattern:
layer_index = re.match(f".*.{pattern}\.(\d+)\.*", key)
if layer_index is not None:
layer_index = int(layer_index.group(1))
if isinstance(lora_config.layers_to_transform, int):
target_module_found = layer_index == lora_config.layers_to_transform
else:
target_module_found = layer_index in lora_config.layers_to_transform

break
else:
target_module_found = False

if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
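
For readers skimming the diff, here is a standalone sketch of the index-matching step in `_find_and_replace` above. It is not the PEFT source (the regex is slightly tidied), but it shows how a module key is tested against `layers_to_transform` and `layers_pattern`.

```python
# Standalone sketch of the layer-index matching performed in `_find_and_replace`.
# Simplified and slightly tidied; not a verbatim copy of the PEFT implementation.
import re

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]

def matches_layer_index(key, layers_to_transform, layers_pattern=None):
    patterns = [layers_pattern] if isinstance(layers_pattern, str) else (layers_pattern or COMMON_LAYERS_PATTERN)
    for pattern in patterns:
        match = re.match(rf".*\.{pattern}\.(\d+)\.", key)
        if match is None:
            continue
        layer_index = int(match.group(1))
        if isinstance(layers_to_transform, int):
            return layer_index == layers_to_transform
        return layer_index in layers_to_transform
    return False

# Typical decoder module keys:
print(matches_layer_index("model.decoder.layers.0.self_attn.q_proj", [0]))  # True
print(matches_layer_index("model.decoder.layers.5.self_attn.q_proj", [0]))  # False
```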
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -22,6 +22,7 @@
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
COMMON_LAYERS_PATTERN,
CONFIG_NAME,
WEIGHTS_NAME,
_set_trainable,
2 changes: 2 additions & 0 deletions src/peft/utils/other.py
@@ -213,6 +213,8 @@ def transpose(weight, fan_in_fan_out):
"chatglm": ["query_key_value"],
}

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]

TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "k", "v", "o", "wi", "wo"],
"mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
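
As a usage note: when a custom architecture keeps its blocks under a container name that is not in `COMMON_LAYERS_PATTERN`, `layers_pattern` has to be passed explicitly. The snippet below is a hypothetical example; the container name and target modules are invented for illustration.

```python
# Hypothetical custom model whose blocks live under `transformer.decoder_layers.<idx>`,
# a container name not covered by COMMON_LAYERS_PATTERN above.
from peft import LoraConfig

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],  # illustrative target modules
    layers_to_transform=[0, 1, 2],        # only the first three blocks
    layers_pattern="decoder_layers",      # explicit pattern so the index regex can find the layer number
)
```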
4 changes: 4 additions & 0 deletions tests/test_decoder_models.py
@@ -92,6 +92,10 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs
def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs):
self._test_training(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs):
self._test_training_layer_indexing(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
4 changes: 4 additions & 0 deletions tests/test_encoder_decoder_models.py
@@ -95,6 +95,10 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs
def test_training_encoder_decoders(self, test_name, model_id, config_cls, config_kwargs):
self._test_training(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_training_encoder_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs):
self._test_training_layer_indexing(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_training_encoder_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
55 changes: 55 additions & 0 deletions tests/testing_common.py
@@ -345,6 +345,60 @@ def _test_training(self, model_id, config_cls, config_kwargs):
else:
self.assertIsNone(param.grad)

def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig,):
return

config = config_cls(
base_model_name_or_path=model_id,
layers_to_transform=[0],
**config_kwargs,
)
model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config)
model = model.to(self.torch_device)

inputs = self.prepare_inputs_for_testing()

# check if `training` works
output = model(**inputs)[0]
logits = output[0]

loss = output.sum()
loss.backward()

nb_trainable = 0

for n, param in model.named_parameters():
if "lora" in n:
self.assertIsNotNone(param.grad)
nb_trainable += 1
else:
self.assertIsNone(param.grad)

with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)

model_from_pretrained = self.transformers_class.from_pretrained(model_id)
model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

logits_from_pretrained = model_from_pretrained(**inputs)[0][0]
self.assertTrue(torch.allclose(logits, logits_from_pretrained, atol=1e-4, rtol=1e-4))

model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
nb_trainable_all = 0

for n, param in model.named_parameters():
if "lora" in n:
nb_trainable_all += 1

self.assertLess(nb_trainable, nb_trainable_all)

def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig,):
return
@@ -367,6 +421,7 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa

# check if `training` works
output = model(**inputs)[0]

loss = output.sum()
loss.backward()
