
Commit 1a5d0f8

FIX: Don't target the classification head when using target_modules="all-linear" (#2033)
Fixes #2027. When using a transformers sequence classification model, target_modules="all-linear" should not wrap the classification head with LoRA, because the head is already wrapped with ModulesToSave, i.e. it will be fully fine-tuned, which is generally the desired behavior. Before this fix, the classification head would be double-wrapped. With #2028, double-wrapping now raises an error; with this PR, it is avoided entirely. Keeping #2028 is still worthwhile, since it helps catch other situations where double-wrapping might occur due to misconfiguration. Note that there is no fool-proof way to detect the classification head; we have to rely on transformers naming conventions.
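As an illustration of the fixed behavior, here is a minimal sketch adapted from the test added in this commit (it assumes the tiny HuggingFaceH4/tiny-random-LlamaForCausalLM checkpoint used by the test suite is available):

from torch import nn
from transformers import AutoModelForSequenceClassification

from peft import LoraConfig, get_peft_model
from peft.utils import ModulesToSaveWrapper

model = AutoModelForSequenceClassification.from_pretrained(
    "HuggingFaceH4/tiny-random-LlamaForCausalLM", num_labels=10
)
config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
model = get_peft_model(model, config)

# The classification head is fully fine-tuned via ModulesToSave; it is not a LoRA layer.
assert isinstance(model.base_model.score, ModulesToSaveWrapper)
assert isinstance(model.base_model.score.original_module, nn.Linear)

Before this fix, model.base_model.score would additionally have been targeted by LoRA on top of the ModulesToSave wrapper; with #2028, that double-wrapping raises an error instead.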
Parent commit: f3c7c6e

3 files changed (+44 −6)

src/peft/tuners/tuners_utils.py (+18 −4)

@@ -30,8 +30,8 @@
 from transformers.pytorch_utils import Conv1D

 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
-from peft.utils.constants import DUMMY_TARGET_MODULES
-from peft.utils.peft_types import PeftType
+from peft.utils.constants import DUMMY_TARGET_MODULES, SEQ_CLS_HEAD_NAMES
+from peft.utils.peft_types import PeftType, TaskType

 from ..config import PeftConfig
 from ..utils import ModulesToSaveWrapper, _get_submodules
@@ -812,11 +812,25 @@ def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module)
             names = name.rsplit(".", 1)[-1]  # get the base name
             linear_module_names.add(names)

-    # ignore the last classification head for text generation models
+    # Try to remove linear layers that should not be targeted as best as possible. We have to rely on convention as
+    # there are no hard rules to detect these modules.
+    module_names_to_exclude = set()
     output_emb = model.get_output_embeddings()
     if output_emb is not None:
+        # ignore the last classification head for text generation models
         last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]
-        linear_module_names -= {last_module_name}
+        module_names_to_exclude.add(last_module_name)
+    elif peft_config.task_type == TaskType.SEQ_CLS:
+        # ignore classifier head for classification models (issue 2027)
+        # there is no fix name for the classifier head, so check the common ones
+        for name in SEQ_CLS_HEAD_NAMES:
+            cls_head = getattr(model, name, None)
+            if cls_head is not None:
+                last_module_name = [name for name, module in model.named_modules() if module is cls_head][0]
+                module_names_to_exclude.add(last_module_name)
+                break
+
+    linear_module_names -= module_names_to_exclude
     peft_config.target_modules = linear_module_names
     return peft_config
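To make the hunk above easier to follow: the head module is located by object identity in model.named_modules(), and its qualified name is then excluded from the "all-linear" target set. A standalone sketch of that lookup step follows; the helper name find_head_module_name is illustrative only, not part of PEFT, and the import assumes a PEFT version that includes this commit.

from peft.utils.constants import SEQ_CLS_HEAD_NAMES  # ["score", "classifier"], added in this commit


def find_head_module_name(model):
    # Check the conventional attribute names used by transformers classification heads.
    for attr in SEQ_CLS_HEAD_NAMES:
        cls_head = getattr(model, attr, None)
        if cls_head is not None:
            # Match by object identity to recover the module's fully qualified name.
            return next(name for name, module in model.named_modules() if module is cls_head)
    return None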

src/peft/utils/constants.py (+1)

@@ -257,6 +257,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
 SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
 CONFIG_NAME = "adapter_config.json"
 EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]
+SEQ_CLS_HEAD_NAMES = ["score", "classifier"]
 INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear"
 TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
 DUMMY_TARGET_MODULES = "dummy-target-modules"

tests/test_tuners_utils.py (+25 −2)
@@ -23,7 +23,13 @@
 from diffusers import StableDiffusionPipeline
 from parameterized import parameterized
 from torch import nn
-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    BitsAndBytesConfig,
+)

 from peft import (
     AdaptionPromptConfig,
@@ -42,7 +48,7 @@
     check_target_module_exists,
     inspect_matched_modules,
 )
-from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, infer_device
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device

 from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu
@@ -330,6 +336,23 @@ def test_maybe_include_all_linear_layers_diffusion(self):
         ):
             model.unet = get_peft_model(model.unet, config)

+    def test_maybe_include_all_linear_does_not_target_classifier_head(self):
+        # See issue 2027
+        # Ensure that if a SEQ_CLS model is being used with target_modules="all-linear", the classification head is not
+        # targeted by the adapter layer.
+        model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=10)
+        # sanity check
+        assert isinstance(model.score, nn.Linear)
+
+        config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
+        model = get_peft_model(model, config)
+        assert isinstance(model.base_model.score, ModulesToSaveWrapper)
+
+        # the bug was that these were lora.Linear instances
+        assert isinstance(model.base_model.score.original_module, nn.Linear)
+        assert isinstance(model.base_model.score.modules_to_save["default"], nn.Linear)
+

 class MLP(nn.Module):
     def __init__(self, bias=True):
