FIX: Don't target the classification head when using target_modules="all-linear" #2033

Merged
18 changes: 15 additions & 3 deletions src/peft/tuners/tuners_utils.py
@@ -31,7 +31,7 @@

 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
 from peft.utils.constants import DUMMY_TARGET_MODULES
-from peft.utils.peft_types import PeftType
+from peft.utils.peft_types import PeftType, TaskType

 from ..config import PeftConfig
 from ..utils import ModulesToSaveWrapper, _get_submodules
@@ -814,11 +814,23 @@ def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module)
             names = name.rsplit(".", 1)[-1]  # get the base name
             linear_module_names.add(names)

-    # ignore the last classification head for text generation models
+    # Try to remove linear layers that should not be targeted as best as possible. We have to rely on convention as
+    # there are no hard rules to detect these modules.
+    module_names_to_exclude = set()
     output_emb = model.get_output_embeddings()
     if output_emb is not None:
+        # ignore the last classification head for text generation models
         last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]
-        linear_module_names -= {last_module_name}
+        module_names_to_exclude.add(last_module_name)
+    elif peft_config.task_type == TaskType.SEQ_CLS:
+        # ignore classifier head for classification models (issue 2027)
+        # there is no fix name for the classifier head, "score" and "classifier" are most common
+        cls_head = getattr(model, "score", None) or getattr(model, "classifier", None)
+        if cls_head is not None:
+            last_module_name = [name for name, module in model.named_modules() if module is cls_head][0]
+            module_names_to_exclude.add(last_module_name)
Member
Perhaps we could define a MAP between the task types and the attributes we know we should exclude and use that?

EXCLUSION_MAP = {TaskType.SEQ_CLS: ["score", "classifier"], ...}

...

cls_head = None
for exclude_candidate in EXCLUSION_MAP[TaskType.SEQ_CLS]:
    cls_head = getattr(model, exclude_candidate, None)
    if cls_head is not None:
        break

if cls_head is not None:
    ...

The advantage of that is that we just have to update the MAP in case we discover more attributes, and it should work out nicely.

WDYT?

Member Author
Good point about making this easier to extend. I'm not sure if a map is the right approach, because for causal LM (the if condition above), we use a different approach based on get_output_embeddings(), so the map could not be used consistently for that task. But I will move ["score", "classifier"] into a constant and use that.
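
For illustration, a minimal sketch of what that could look like; the constant name SEQ_CLS_HEAD_NAMES and the helper function are hypothetical and may not match the names used in the actual follow-up commit:

# Hypothetical sketch of the constant-based approach discussed above; names are illustrative only.
SEQ_CLS_HEAD_NAMES = ("score", "classifier")


def find_seq_cls_head(model):
    # Return the first classification-head attribute present on the model, or None if none of the
    # candidate names exist.
    for head_name in SEQ_CLS_HEAD_NAMES:
        cls_head = getattr(model, head_name, None)
        if cls_head is not None:
            return cls_head
    return None

The elif branch for TaskType.SEQ_CLS could then call find_seq_cls_head(model) instead of chaining getattr calls, and supporting a newly discovered head name would only require extending the constant.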

Member
I still think a map or any adjacent approach is better in the long term but you of course know better here.


+
+    linear_module_names -= module_names_to_exclude
     peft_config.target_modules = linear_module_names
     return peft_config

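To illustrate the net effect of the change above (a sketch mirroring the new test below, not part of the diff): for a sequence classification model whose head is named score, the expanded target_modules no longer contains the head, so it is left to modules_to_save instead of receiving a LoRA layer.

from transformers import AutoModelForSequenceClassification

from peft import LoraConfig, get_peft_model

# Tiny checkpoint also used in the new test; it exposes its classification head as `model.score`.
model = AutoModelForSequenceClassification.from_pretrained(
    "HuggingFaceH4/tiny-random-LlamaForCausalLM", num_labels=10
)
config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
peft_model = get_peft_model(model, config)

# After the fix, "score" should not appear in the expanded target modules, and the head stays a
# plain linear layer handled via modules_to_save rather than becoming a lora.Linear.
print(peft_model.peft_config["default"].target_modules)
print(type(peft_model.base_model.score))  # expected: ModulesToSaveWrapper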
27 changes: 25 additions & 2 deletions tests/test_tuners_utils.py
@@ -23,7 +23,13 @@
 from diffusers import StableDiffusionPipeline
 from parameterized import parameterized
 from torch import nn
-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    BitsAndBytesConfig,
+)

 from peft import (
     AdaptionPromptConfig,
@@ -42,7 +48,7 @@
     check_target_module_exists,
     inspect_matched_modules,
 )
-from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, infer_device
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device

 from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu

@@ -330,6 +336,23 @@ def test_maybe_include_all_linear_layers_diffusion(self):
         ):
             model.unet = get_peft_model(model.unet, config)

+    def test_maybe_include_all_linear_does_not_target_classifier_head(self):
+        # See issue 2027
+        # Ensure that if a SEQ_CLS model is being used with target_modules="all-linear", the classification head is not
+        # targeted by the adapter layer.
+        model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=10)
+        # sanity check
+        assert isinstance(model.score, nn.Linear)
+
+        config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
+        model = get_peft_model(model, config)
+        assert isinstance(model.base_model.score, ModulesToSaveWrapper)
+
+        # the bug was that these were lora.Linear instances
+        assert isinstance(model.base_model.score.original_module, nn.Linear)
+        assert isinstance(model.base_model.score.modules_to_save["default"], nn.Linear)
Comment on lines +353 to +354
Member
Sleek!



 class MLP(nn.Module):
     def __init__(self, bias=True):