FIX: Don't target the classification head when using target_modules="all-linear" #2033
Conversation
Fixes huggingface#2027

When using a transformers sequence classification model, target_modules="all-linear" should not wrap the classification head with LoRA. This is because it is already wrapped with ModulesToSave, i.e. it will be fully fine-tuned, which is the generally desired behavior.

Before this bug fix, the classification head would be double-wrapped. With huggingface#2028, this now raises an error. With this PR, it is avoided completely. Still, keeping huggingface#2028 is good because it helps prevent other situations where double-wrapping might occur due to misconfiguration.

Note that there is no fool-proof method to detect the classification head, so we have to rely on the transformers naming convention.
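For context, a minimal sketch of the user-facing behavior this fix enables (the checkpoint name is only illustrative; any transformers sequence classification model with a "score" or "classifier" head would do): with target_modules="all-linear", the classification head should end up as a ModulesToSave wrapper rather than a LoRA layer.

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

# Illustrative checkpoint; OPT's classification head is named "score".
base_model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-125m", num_labels=2)
config = LoraConfig(task_type=TaskType.SEQ_CLS, target_modules="all-linear")
model = get_peft_model(base_model, config)

# With this fix, "all-linear" skips the head: it stays a ModulesToSave wrapper
# (fully fine-tuned) while the other linear layers receive LoRA adapters.
print(model.base_model.score)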
Just a single comment regarding the design. But not a blocker.
src/peft/tuners/tuners_utils.py
Outdated
cls_head = getattr(model, "score", None) or getattr(model, "classifier", None)
if cls_head is not None:
    last_module_name = [name for name, module in model.named_modules() if module is cls_head][0]
    module_names_to_exclude.add(last_module_name)
Perhaps we could define a MAP between the task types and the attributes we know we should exclude and use that?
EXCLUSION_MAP = {TaskType.SEQ_CLS: ["score", "classifier"], ...}
...
cls_head = None
for exclude_candidate in EXCLUSION_MAP[TaskType.SEQ_CLS]:
    cls_head = getattr(model, exclude_candidate, None)
    if cls_head is not None:
        ...
The advantage of that is we just have to update the MAP in case we discover more attributes, and it should work out nicely.
WDYT?
Good point about making this easier to extend. I'm not sure if a map is the right approach, because for causal LM (the if condition above), we use a different approach based on get_output_embeddings(), so the map could not be used consistently for that task. But I will move ["score", "classifier"] into a constant and use that.
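To make that concrete, here is a rough sketch of what moving ["score", "classifier"] into a constant could look like; the constant name and surrounding variable names are hypothetical and not necessarily what was merged.

# Hypothetical constant name; the merged code may differ.
SEQ_CLS_HEAD_NAMES = ("score", "classifier")

cls_head = None
for attr_name in SEQ_CLS_HEAD_NAMES:
    cls_head = getattr(model, attr_name, None)
    if cls_head is not None:
        break

if cls_head is not None:
    # Exclude the classification head so "all-linear" does not add LoRA on top of ModulesToSave.
    head_name = [name for name, module in model.named_modules() if module is cls_head][0]
    module_names_to_exclude.add(head_name)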
I still think a map or a similar approach is better in the long term, but you of course know better here.
assert isinstance(model.base_model.score.original_module, nn.Linear)
assert isinstance(model.base_model.score.modules_to_save["default"], nn.Linear)
Sleek!