FIX Don't target cls head with all-linear
Fixes huggingface#2027

When using a transformers sequence classification model,
target_modules="all-linear" should not wrap the classification head with
LoRA. This is because it is already wrapped with ModulesToSave, i.e. it
will be fully fine-tuned, which is the generally desired behavior.

Before this bug fix, the classification head would be double-wrapped.
With huggingface#2028, this now raises an error. With this PR, it is avoided
completely. Still, keeping huggingface#2028 is good because it helps prevent other
situations where double-wrapping might occur due to misconfiguration.

Note that there is no foolproof method to detect the classification
head; we have to rely on transformers conventions.
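
For illustration only (not part of the commit), a minimal sketch of the behavior described above, reusing the tiny checkpoint from the new test and assuming recent peft/transformers releases:

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model
from peft.utils import ModulesToSaveWrapper

# Tiny seq-cls model whose classification head is the linear module named "score".
model = AutoModelForSequenceClassification.from_pretrained(
    "HuggingFaceH4/tiny-random-LlamaForCausalLM", num_labels=10
)
config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
model = get_peft_model(model, config)

# The head is fully fine-tuned via ModulesToSave instead of being wrapped with LoRA.
assert isinstance(model.base_model.score, ModulesToSaveWrapper)
assert "score" not in model.peft_config["default"].target_modules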
BenjaminBossan committed Aug 22, 2024
1 parent 6c832c1 commit d28f087
Showing 2 changed files with 40 additions and 5 deletions.
18 changes: 15 additions & 3 deletions src/peft/tuners/tuners_utils.py
@@ -31,7 +31,7 @@

 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
 from peft.utils.constants import DUMMY_TARGET_MODULES
-from peft.utils.peft_types import PeftType
+from peft.utils.peft_types import PeftType, TaskType

 from ..config import PeftConfig
 from ..utils import ModulesToSaveWrapper, _get_submodules
@@ -814,11 +814,23 @@ def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module)
             names = name.rsplit(".", 1)[-1]  # get the base name
             linear_module_names.add(names)

-    # ignore the last classification head for text generation models
+    # Try to remove linear layers that should not be targeted as best as possible. We have to rely on convention as
+    # there are no hard rules to detect these modules.
+    module_names_to_exclude = set()
     output_emb = model.get_output_embeddings()
     if output_emb is not None:
+        # ignore the last classification head for text generation models
         last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]
-        linear_module_names -= {last_module_name}
+        module_names_to_exclude.add(last_module_name)
+    elif peft_config.task_type == TaskType.SEQ_CLS:
+        # ignore classifier head for classification models (issue 2027)
+        # there is no fix name for the classifier head, "score" and "classifier" are most common
+        cls_head = getattr(model, "score", None) or getattr(model, "classifier", None)
+        if cls_head is not None:
+            last_module_name = [name for name, module in model.named_modules() if module is cls_head][0]
+            module_names_to_exclude.add(last_module_name)
+
+    linear_module_names -= module_names_to_exclude
     peft_config.target_modules = linear_module_names
     return peft_config
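
To see what the helper now computes, a hedged sketch (the function is a private utility in src/peft/tuners/tuners_utils.py, so importing it like this is for demonstration only and may break across versions):

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig
from peft.tuners.tuners_utils import _maybe_include_all_linear_layers

model = AutoModelForSequenceClassification.from_pretrained(
    "HuggingFaceH4/tiny-random-LlamaForCausalLM", num_labels=10
)
config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
config = _maybe_include_all_linear_layers(config, model)

# "all-linear" is expanded to the concrete linear module names; with this fix the
# classification head ("score" on Llama-style models) is excluded from the set.
print(config.target_modules)
assert "score" not in config.target_modules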

27 changes: 25 additions & 2 deletions tests/test_tuners_utils.py
@@ -23,7 +23,13 @@
 from diffusers import StableDiffusionPipeline
 from parameterized import parameterized
 from torch import nn
-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    BitsAndBytesConfig,
+)

 from peft import (
     AdaptionPromptConfig,
@@ -42,7 +48,7 @@
     check_target_module_exists,
     inspect_matched_modules,
 )
-from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, infer_device
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device

 from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu

@@ -330,6 +336,23 @@ def test_maybe_include_all_linear_layers_diffusion(self):
         ):
             model.unet = get_peft_model(model.unet, config)

+    def test_maybe_include_all_linear_does_not_target_classifier_head(self):
+        # See issue 2027
+        # Ensure that if a SEQ_CLS model is being used with target_modules="all-linear", the classification head is not
+        # targeted by the adapter layer.
+        model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=10)
+        # sanity check
+        assert isinstance(model.score, nn.Linear)
+
+        config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
+        model = get_peft_model(model, config)
+        assert isinstance(model.base_model.score, ModulesToSaveWrapper)
+
+        # the bug was that these were lora.Linear instances
+        assert isinstance(model.base_model.score.original_module, nn.Linear)
+        assert isinstance(model.base_model.score.modules_to_save["default"], nn.Linear)
+

 class MLP(nn.Module):
     def __init__(self, bias=True):
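
To run just the new test locally (assuming a development install of peft with pytest available):

python -m pytest tests/test_tuners_utils.py -k test_maybe_include_all_linear_does_not_target_classifier_head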
