-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FIX: Don't target the classification head when using target_modules="all-linear" #2033
Merged
BenjaminBossan
merged 3 commits into
huggingface:main
from
BenjaminBossan:fix-all-linear-dont-target-classifier-head
Aug 23, 2024
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,13 @@ | |
from diffusers import StableDiffusionPipeline | ||
from parameterized import parameterized | ||
from torch import nn | ||
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig | ||
from transformers import ( | ||
AutoModel, | ||
AutoModelForCausalLM, | ||
AutoModelForSeq2SeqLM, | ||
AutoModelForSequenceClassification, | ||
BitsAndBytesConfig, | ||
) | ||
|
||
from peft import ( | ||
AdaptionPromptConfig, | ||
|
@@ -42,7 +48,7 @@ | |
check_target_module_exists, | ||
inspect_matched_modules, | ||
) | ||
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, infer_device | ||
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device | ||
|
||
from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu | ||
|
||
|
@@ -330,6 +336,23 @@ def test_maybe_include_all_linear_layers_diffusion(self): | |
): | ||
model.unet = get_peft_model(model.unet, config) | ||
|
||
def test_maybe_include_all_linear_does_not_target_classifier_head(self): | ||
# See issue 2027 | ||
# Ensure that if a SEQ_CLS model is being used with target_modules="all-linear", the classification head is not | ||
# targeted by the adapter layer. | ||
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM" | ||
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=10) | ||
# sanity check | ||
assert isinstance(model.score, nn.Linear) | ||
|
||
config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear") | ||
model = get_peft_model(model, config) | ||
assert isinstance(model.base_model.score, ModulesToSaveWrapper) | ||
|
||
# the bug was that these were lora.Linear instances | ||
assert isinstance(model.base_model.score.original_module, nn.Linear) | ||
assert isinstance(model.base_model.score.modules_to_save["default"], nn.Linear) | ||
Comment on lines
+353
to
+354
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sleek! |
||
|
||
|
||
class MLP(nn.Module): | ||
def __init__(self, bias=True): | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we could define a MAP between the task types and the attributes we know we should exclude and use that?
The advantage of that is we just have to update the MAP in case we discover more attrbiutes and it should work out nicely.
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point about making this easier to extend. I'm not sure if a map is the right approach, because for causal LM (the
if
condition above), we use a different approach based onget_output_embeddings()
, so the map could not be used consistently for that task. But I will move["score", "classifier"]
into a constant and use that.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still think a map or any adjacent approach is better in the long term but you of course know better here.