Add an option 'ALL' to include all linear layers as target modules #1295

Merged · 28 commits · Jan 9, 2024

Changes from 21 commits

Commits
709422c
added helper function to get list of all linear layers; added tests a…
SumanthRH Dec 23, 2023
ff5d202
added bnb tests
SumanthRH Dec 23, 2023
7fa54fd
fixed issues with t5
SumanthRH Dec 24, 2023
135ae4e
style issues
SumanthRH Dec 24, 2023
0813f9a
improved lora and ia3 docs
SumanthRH Dec 24, 2023
9ce4293
fixed code to work for any output embedding layer name
SumanthRH Dec 26, 2023
9328095
style changes
SumanthRH Dec 26, 2023
764b02b
added a test for a base model without lm head
SumanthRH Dec 26, 2023
3e5a29b
added comments
SumanthRH Dec 26, 2023
d59a101
address review comments
SumanthRH Dec 29, 2023
ebf36ac
update tests
SumanthRH Dec 29, 2023
0066fb6
update tests
SumanthRH Dec 29, 2023
4fc1034
minor simplification
SumanthRH Dec 29, 2023
8fcaaf2
changed argument to all_linear
SumanthRH Dec 29, 2023
aeb4f96
minor fix to configs
SumanthRH Dec 29, 2023
2483ce5
minor edit
SumanthRH Dec 31, 2023
bd5c12d
Apply suggestions from code review
SumanthRH Jan 2, 2024
3c5e9b2
address review comments
SumanthRH Jan 2, 2024
2e462b6
added test for diffusion models
SumanthRH Jan 2, 2024
7770df9
minor edits to configs
SumanthRH Jan 2, 2024
eefce2f
spelling correction
SumanthRH Jan 2, 2024
bbc1619
Update tests/test_tuners_utils.py
SumanthRH Jan 3, 2024
6d4d873
Update src/peft/tuners/tuners_utils.py
SumanthRH Jan 3, 2024
929b090
Update src/peft/tuners/tuners_utils.py
SumanthRH Jan 3, 2024
bcc2342
address review comments
SumanthRH Jan 3, 2024
e40250c
revert back to older decorator order
SumanthRH Jan 3, 2024
f77f2a0
style changes
SumanthRH Jan 3, 2024
0287c3f
simplify logic for bnb layers
SumanthRH Jan 3, 2024
10 changes: 6 additions & 4 deletions src/peft/tuners/ia3/config.py
@@ -28,9 +28,10 @@ class IA3Config(PeftConfig):
Args:
target_modules (`Union[List[str],str]`):
The names of the modules to apply (IA)³ to. If this is specified, only the modules with the specified names
will be replaced. If this is not specified, modules will be chosen according to the model architecture. If
the architecture is not known, an error will be raised -- in this case, you should specify the target
modules manually.
will be replaced. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen,
excluding the output layer. If this is not specified, modules will be chosen according to the model
architecture. If the architecture is not known, an error will be raised -- in this case, you should specify
the target modules manually.
feedforward_modules (`Union[List[str],str]`):
The names of the modules to be treated as feedforward modules, as in the original paper. These modules will
have (IA)^3 vectors multiplied to the input, instead of the output. feedforward_modules must be a name or a
@@ -49,7 +50,8 @@ class IA3Config(PeftConfig):
metadata={
"help": (
"List of module names or regex expression of the module names to replace with (IA)³."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'."
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
"If not specified, modules will be chosen according to the model architecture, If the architecture is "
"not known, an error will be raised -- in this case, you shoud specify the target modules manually."
),
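To make the new docstring concrete, here is a minimal usage sketch for the (IA)³ case, assuming a small causal LM checkpoint (the checkpoint name is illustrative and mirrors the ones used in the tests later in this PR):

from transformers import AutoModelForCausalLM

from peft import IA3Config, get_peft_model

# Illustrative tiny checkpoint; any transformers PreTrainedModel should work.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/tiny-random-LlamaForCausalLM")

# 'all-linear' targets every linear/Conv1D module except the output embedding layer.
config = IA3Config(target_modules="all-linear")
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()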
10 changes: 6 additions & 4 deletions src/peft/tuners/lora/config.py
@@ -49,9 +49,10 @@ class LoraConfig(PeftConfig):
Args:
r (`int`): Lora attention dimension.
target_modules (`Optional[Union[List[str], str]]`): The names of the modules to apply LoRA to. If this is
specified, only the modules with the specified names will be replaced. If this is not specified, modules
will be chosen according to the model architecture. If the architecture is not known, an error will be
raised -- in this case, you should specify the target modules manually.
specified, only the modules with the specified names will be replaced. If this is specified as
'all-linear', then all linear/Conv1D modules are chosen, excluding the output layer. If this is not
specified, modules will be chosen according to the model architecture. If the architecture is not known, an
error will be raised -- in this case, you should specify the target modules manually.
lora_alpha (`int`): The alpha parameter for Lora scaling.
lora_dropout (`float`): The dropout probability for Lora layers.
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
@@ -87,7 +88,8 @@ class LoraConfig(PeftConfig):
metadata={
"help": (
"List of module names or regex expression of the module names to replace with LoRA."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'."
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
"If not specified, modules will be chosen according to the model architecture, If the architecture is "
"not known, an error will be raised -- in this case, you shoud specify the target modules manually."
),
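The LoRA shorthand behaves the same way. A small sketch, assuming a GPT-2-style checkpoint so the Conv1D case is exercised (again, the checkpoint name mirrors one used in the tests below):

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# For GPT-2 this should resolve to the Conv1D projections (c_attn, c_proj, c_fc)
# while skipping the output lm_head, as described in the docstring above.
config = LoraConfig(r=8, lora_alpha=16, target_modules="all-linear")
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()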
53 changes: 52 additions & 1 deletion src/peft/tuners/tuners_utils.py
@@ -22,10 +22,13 @@

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.pytorch_utils import Conv1D

from peft.utils import COMMON_LAYERS_PATTERN
from peft.utils import COMMON_LAYERS_PATTERN, INCLUDE_LINEAR_LAYERS_SHORTHAND

from ..config import PeftConfig
from ..import_utils import is_bnb_available
from ..utils import ModulesToSaveWrapper, _get_submodules


@@ -213,6 +216,9 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):

peft_config = self._prepare_adapter_config(peft_config, model_config)

# update peft_config.target_modules if required
peft_config = _maybe_include_all_linear_layers(peft_config, model)

for key in key_list:
# Check for modules_to_save in case
if _check_for_modules_to_save and any(
@@ -523,3 +529,48 @@ def inspect_matched_modules(tuner: BaseTuner, adapter_name: str = "default") ->
else:
module_dict["unmatched"].append(key)
return module_dict


def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module) -> PeftConfig:
"""
Helper function to update `target_modules` to all linear/Conv1D layers if provided as 'all-linear'. Adapted from
the QLoRA repository: https://github.com/artidoro/qlora/blob/main/qlora.py
"""

# if `target_modules` is a string, convert to lower case and check if it matches "all-linear"
if not (
isinstance(peft_config.target_modules, str)
and peft_config.target_modules.lower() == INCLUDE_LINEAR_LAYERS_SHORTHAND
):
return peft_config

if not isinstance(model, PreTrainedModel):
raise ValueError(
f"Only instances of PreTrainedModel are supported for the '{INCLUDE_LINEAR_LAYERS_SHORTHAND}' flag"
)

is_loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
is_loaded_in_4bit = getattr(model, "is_loaded_in_4bit", False)
# match with a list of linear layer classes. this is needed as sometimes you can
# have a mix, e.g. T5 in 8-bit has instances of both torch.nn.Linear and bnb.nn.Linear8bitLt
linear_classes = (torch.nn.Linear, Conv1D)
if is_bnb_available():
import bitsandbytes as bnb

linear_classes = (bnb.nn.Linear8bitLt,) + linear_classes if is_loaded_in_8bit else linear_classes
linear_classes = (bnb.nn.Linear4bit,) + linear_classes if is_loaded_in_4bit else linear_classes

linear_module_names = set()
for name, module in model.named_modules():
# match with all linear classes.
if isinstance(module, linear_classes):
names = name.rsplit(".", 1)[-1] # get the base name
linear_module_names.add(names)

# ignore the last classification head for text generation models
output_emb = model.get_output_embeddings()
if output_emb is not None:
last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]
linear_module_names -= {last_module_name}
peft_config.target_modules = list(linear_module_names)
return peft_config
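To see the helper's control flow in isolation, here is a short sketch that calls it directly, mirroring what the new tests in this PR exercise (the checkpoint name is the tiny Llama model used there):

from transformers import AutoModelForCausalLM

from peft import LoraConfig
from peft.tuners.tuners_utils import _maybe_include_all_linear_layers

model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/tiny-random-LlamaForCausalLM")

# The 'all-linear' shorthand is expanded to a list of base module names,
# e.g. q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj; lm_head is filtered out.
expanded = _maybe_include_all_linear_layers(LoraConfig(target_modules="all-linear"), model)
print(sorted(expanded.target_modules))

# Anything other than the shorthand (None, an explicit list, a regex string) passes through unchanged.
unchanged = _maybe_include_all_linear_layers(LoraConfig(target_modules=".*(q_proj|v_proj)$"), model)
print(unchanged.target_modules)  # '.*(q_proj|v_proj)$'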
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -29,6 +29,7 @@
CONFIG_NAME,
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
INCLUDE_LINEAR_LAYERS_SHORTHAND,
_set_trainable,
bloom_model_postprocess_past_key_value,
prepare_model_for_int8_training,
1 change: 1 addition & 0 deletions src/peft/utils/constants.py
@@ -146,3 +146,4 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
CONFIG_NAME = "adapter_config.json"
EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]
INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear"
2 changes: 2 additions & 0 deletions src/peft/utils/other.py
@@ -28,6 +28,7 @@
COMMON_LAYERS_PATTERN,
CONFIG_NAME,
EMBEDDING_LAYER_NAMES,
INCLUDE_LINEAR_LAYERS_SHORTHAND,
SAFETENSORS_WEIGHTS_NAME,
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
@@ -51,6 +52,7 @@
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
"WEIGHTS_NAME",
"INCLUDE_LINEAR_LAYERS_SHORTHAND",
"bloom_model_postprocess_past_key_value",
"starcoder_model_postprocess_past_key_value",
]
153 changes: 143 additions & 10 deletions tests/test_tuners_utils.py
@@ -16,17 +16,26 @@
# limitations under the License.
import unittest

from diffusers import StableDiffusionPipeline
from parameterized import parameterized
from transformers import AutoModel
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM

from peft import IA3Config, LoraConfig, get_peft_model
from peft.tuners.tuners_utils import check_target_module_exists, inspect_matched_modules
from peft import IA3Config, LoHaConfig, LoraConfig, get_peft_model
from peft.tuners.tuners_utils import (
INCLUDE_LINEAR_LAYERS_SHORTHAND,
_maybe_include_all_linear_layers,
check_target_module_exists,
inspect_matched_modules,
)

from .testing_utils import require_bitsandbytes, require_torch_gpu

# Implements tests for regex matching logic common for all BaseTuner subclasses, and also
# tests for correct behaviour with different config kwargs for BaseTuners (Ex: feedforward for IA3, etc)

TEST_CASES = [
# Implements tests for regex matching logic common for all BaseTuner subclasses, and
# tests for correct behaviour with different config kwargs for BaseTuners (Ex: feedforward for IA3, etc) and
# tests for the utility function to include all linear layers

REGEX_TEST_CASES = [
# tuple of
# 1. key
# 2. target_modules
@@ -92,6 +101,53 @@
("mlp.blocks.1.bias", ["weight"], [1], ["blocks"], False),
]

MAYBE_INCLUDE_ALL_LINEAR_LAYERS_TEST_CASES = [
# model_name, model_type, initial_target_modules, expected_target_modules
# test for a causal Llama model
(
"HuggingFaceH4/tiny-random-LlamaForCausalLM",
"causal",
INCLUDE_LINEAR_LAYERS_SHORTHAND,
["k_proj", "v_proj", "q_proj", "o_proj", "down_proj", "up_proj", "gate_proj"],
),
# test for a Llama model without the LM head
(
"HuggingFaceH4/tiny-random-LlamaForCausalLM",
"base",
INCLUDE_LINEAR_LAYERS_SHORTHAND,
["k_proj", "v_proj", "q_proj", "o_proj", "down_proj", "up_proj", "gate_proj"],
),
# test for gpt2 with Conv1D layers
("hf-internal-testing/tiny-random-gpt2", "causal", INCLUDE_LINEAR_LAYERS_SHORTHAND, ["c_attn", "c_proj", "c_fc"]),
# test for T5 model
(
"hf-internal-testing/tiny-random-t5",
"seq2seq",
INCLUDE_LINEAR_LAYERS_SHORTHAND,
["k", "q", "v", "o", "wi", "wo"],
),
# test for GPTNeoX. The output module list should exclude the classification head, which is named "embed_out" instead of the usual "lm_head" for GPTNeoX
(
"hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"causal",
INCLUDE_LINEAR_LAYERS_SHORTHAND,
["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
),
]

# tests for a few args that should remain unchanged
MAYBE_INCLUDE_ALL_LINEAR_LAYERS_TEST_INTERNALS = [
# initial_target_modules, expected_target_modules
(["k_proj"], ["k_proj"]),
# test with target_modules as None
(None, None),
# test with target_modules as a regex expression
(".*(q_proj|v_proj)$", ".*(q_proj|v_proj)$"),
]

BNB_QUANTIZATIONS = [("4bit",), ("8bit",)]
BNB_TEST_CASES = [(x + y) for x in MAYBE_INCLUDE_ALL_LINEAR_LAYERS_TEST_CASES for y in BNB_QUANTIZATIONS]


class PeftCustomKwargsTester(unittest.TestCase):
r"""
@@ -101,9 +157,9 @@ class PeftCustomKwargsTester(unittest.TestCase):

"""

transformers_class = AutoModel
transformers_class_map = {"causal": AutoModelForCausalLM, "seq2seq": AutoModelForSeq2SeqLM, "base": AutoModel}

@parameterized.expand(TEST_CASES)
@parameterized.expand(REGEX_TEST_CASES)
def test_regex_matching_valid(self, key, target_modules, layers_to_transform, layers_pattern, expected_result):
# We use a LoRA Config for testing, but the regex matching function is common for all BaseTuner subclasses.
# example model_id for config initialization. key is matched only against the target_modules given, so this can be any model
@@ -123,7 +179,7 @@ def test_module_matching_lora(self):
# configs that could exist. This is okay as the method calls `check_target_module_exists` internally, which
# has been extensively tested above.
model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
model = self.transformers_class.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
# by default, this model matches query_key_value
config = LoraConfig()
peft_model = get_peft_model(model, config)
@@ -146,7 +202,7 @@ def test_module_matching_lora(self):

def test_feedforward_matching_ia3(self):
model_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration"
model = self.transformers_class.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
# simple example for just one t5 block for testing
config_kwargs = {
"target_modules": ".*encoder.*block.0.*(SelfAttention|EncDecAttention|DenseReluDense).(k|q|v|wo|wi)$",
@@ -175,3 +231,80 @@ def test_feedforward_matching_ia3(self):
self.assertTrue(module.is_feedforward)
else: # other IA3 modules should not be marked as feedforward
self.assertFalse(module.is_feedforward)

@parameterized.expand(MAYBE_INCLUDE_ALL_LINEAR_LAYERS_TEST_CASES)
def test_maybe_include_all_linear_layers_lora(
self, model_id, model_type, initial_target_modules, expected_target_modules
):
model = self.transformers_class_map[model_type].from_pretrained(model_id)
config_cls = LoraConfig
self._check_match_with_expected_target_modules(
model_id, model, config_cls, initial_target_modules, expected_target_modules
)

@parameterized.expand(BNB_TEST_CASES)
@require_torch_gpu
@require_bitsandbytes
def test_maybe_include_all_linear_layers_lora_bnb(
self, model_id, model_type, initial_target_modules, expected_target_modules, quantization
):
if quantization == "4bit":
config_kwargs = {"load_in_4bit": True}
elif quantization == "8bit":
config_kwargs = {"load_in_8bit": True}
model = self.transformers_class_map[model_type].from_pretrained(model_id, device_map="auto", **config_kwargs)
config_cls = LoraConfig
self._check_match_with_expected_target_modules(
model_id, model, config_cls, initial_target_modules, expected_target_modules
)

def _check_match_with_expected_target_modules(
self, model_id, model, config_cls, initial_target_modules, expected_target_modules
):
"""
Helper function for the test for `_maybe_include_all_linear_layers`
"""
actual_config = config_cls(base_model_name_or_path=model_id, target_modules=initial_target_modules)
expected_config = config_cls(base_model_name_or_path=model_id, target_modules=expected_target_modules)
actual_model = get_peft_model(model, peft_config=actual_config)
expected_model = get_peft_model(model, peft_config=expected_config)
expected_model_module_dict = dict(expected_model.named_modules())
# compare the two models and assert that all layers are of the same type
for name, actual_module in actual_model.named_modules():
expected_module = expected_model_module_dict[name]
self.assertEqual(type(actual_module), type(expected_module))

def test_maybe_include_all_linear_layers_ia3_loha(self):
model_id, initial_target_modules, expected_target_modules = (
"HuggingFaceH4/tiny-random-LlamaForCausalLM",
INCLUDE_LINEAR_LAYERS_SHORTHAND,
["k_proj", "v_proj", "q_proj", "o_proj", "down_proj", "up_proj", "gate_proj"],
)
model = AutoModelForCausalLM.from_pretrained(model_id)
config_cls = IA3Config
for config_cls in [IA3Config, LoHaConfig]:
self._check_match_with_expected_target_modules(
model_id, model, config_cls, initial_target_modules, expected_target_modules
)

@parameterized.expand(MAYBE_INCLUDE_ALL_LINEAR_LAYERS_TEST_INTERNALS)
def test_maybe_include_all_linear_layers_internals(self, initial_target_modules, expected_target_modules):
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
config = LoraConfig(base_model_name_or_path=model_id, target_modules=initial_target_modules)
new_config = _maybe_include_all_linear_layers(config, model)
if isinstance(expected_target_modules, list):
# assert that expected and actual target_modules have the same items
self.assertCountEqual(new_config.target_modules, expected_target_modules)
else:
self.assertEqual(new_config.target_modules, expected_target_modules)

def test_maybe_include_all_linear_layers_diffusion(self):
model_id = "hf-internal-testing/tiny-stable-diffusion-torch"
model = StableDiffusionPipeline.from_pretrained(model_id)
config = LoraConfig(base_model_name_or_path=model_id, target_modules="all-linear")
with self.assertRaisesRegex(
ValueError,
f"Only instances of PreTrainedModel are supported for the '{INCLUDE_LINEAR_LAYERS_SHORTHAND}' flag",
):
model.unet = get_peft_model(model.unet, config)