FIX Account for attention mask being a dict, fix generate issues with gemma #2579
Conversation
Resolves CI errors such as this one: https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182

After resolving that error, other errors can occur, but they're unrelated and are being investigated independently.

After the transformers change in huggingface/transformers#37866, it can happen that:

> Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type)

As PEFT operates on the attention mask for prompt learning methods, we need to adjust the code for the possibility of attention_mask being a dict. Right now, I simply extract the single value if the dict has just one element. For other sizes, I just raise an error, as I don't know how to deal with that. For our tests, this is enough, but we might need to find a better solution in the future.
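A minimal sketch of that dict handling (illustrative only, not the literal diff; the helper name is hypothetical):

```python
def _unwrap_attention_mask(attention_mask):
    # Illustrative helper, not the literal code in this PR: after
    # huggingface/transformers#37866, the attention mask may be a dict with one entry
    # per attention type instead of a single tensor.
    if isinstance(attention_mask, dict):
        if len(attention_mask) == 1:
            # only one attention type: unwrap the single mask and proceed as before
            return next(iter(attention_mask.values()))
        # several attention types (e.g. full vs. sliding window): unclear how PEFT
        # should handle this for prompt learning, so raise instead of guessing
        raise ValueError("Expected a single attention mask, but got multiple attention types.")
    return attention_mask
```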
Avoid regression, even though I'm not quite sure if the old behavior is technically correct.
Thanks for taking this on :)
LGTM, but it was a bit hard to understand. I added comments on what explanations would have helped to understand the code faster.
        
          
src/peft/peft_model.py (Outdated)
```python
model_kwargs["attention_mask"] = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)

# heuristic to determine if we're in prefill stage
```
in the comment: please explain what the prefill stage is for and what it relates to (kv cache)
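For instance, the comment could say something along these lines (just a sketch):

```python
# "Prefill" is the first forward pass of generate(): the whole prompt is processed at
# once and its key/value states are written to the kv cache. Later decoding steps only
# process the newly generated token and reuse that cache, which is why cache_position
# starts at 0 only during prefill.
```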
        
          
src/peft/peft_model.py (Outdated)
```python
# if in prefill stage, for prompt learning methods that are not prefix tuning, new tokens
# (embeddings) are inserted
```
in the comment: please explain why prefix tuning is exempt here
Currently it reads as if prefix tuning doesn't insert inputs at all, which is confusing because it does, but, IIUC, into the kv cache of all layers.
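For example, the comment could spell this out along these lines (sketch, following the reviewer's reading of prefix tuning):

```python
# Prefix tuning is exempt because it does not prepend virtual token embeddings to the
# input sequence; instead, its learned prefixes are injected into the kv cache of every
# layer, so the length of the inputs_embeds does not grow here.
```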
        
          
src/peft/peft_model.py (Outdated)
```python
# if cache_position exists and if we're in the prefill stage
if (
    (model_kwargs.get("cache_position") is not None)
    and (model_kwargs["cache_position"][0] == 0)
```
it might make sense to move the "cache_position is not None" and "cache_position[0] == 0" checks higher up and turn them into an is_prefill variable with a proper comment, to reduce repetition and add a bit of semantic context.
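A rough sketch of that suggestion (names illustrative, not the actual diff):

```python
# cache_position starts at 0 exactly once, namely during prefill (the initial prompt
# pass that populates the kv cache); later decoding steps continue from the cache length.
cache_position = model_kwargs.get("cache_position")
is_prefill = (cache_position is not None) and (cache_position[0] == 0)

if is_prefill and peft_config.peft_type != PeftType.PREFIX_TUNING:
    # prompt learning (except prefix tuning) prepends virtual token embeddings here,
    # so the attention mask needs to grow accordingly
    ...
```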
Note: I checked locally and some Gemma3 tests are failing when run on GPU due to compile errors. The reason is that Gemma uses some code that is not torch.compile-friendly. I could confirm that these failures are not caused by this PR, but rather that these failures were masked by the error that is fixed in this PR.
Perfect, it reads a lot clearer now. Thanks :)
- Bump versions
- Update a comment to point to the new PR
- Remove a test skip that is obsolete after #2579
See also #2580
Resolves CI errors such as this one:
https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182
This PR resolves 2 issues:
1. attention mask being a dict
After the transformers change in huggingface/transformers#37866, it can happen that:

> Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type)
As PEFT operates on the attention mask for prompt learning methods, we need to adjust the code for the possibility of attention_mask being a dict. Right now, I simply extract the single value if the dict has just one element. For other sizes, I just raise an error, as I don't know how to deal with that. For our tests, this is enough, but we might need to find a better solution in the future.

2. torch.compile errors during generation
#2458 fixed an issue with 4d attention masks and added gemma3, which uses 4d attention masks, to the test suite. However, the solution was insufficient, as it involved replacing the 4d attention mask with a 2d mask and handing it off to the model to create the correct 4d attention mask. The problem is that mask creation triggers an error with torch.compile and thus needs to be performed outside of the compile context, i.e. during prepare_inputs_for_generation. This PR now uses the same logic as transformers to do exactly that.
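Illustrative pattern only (the helpers below are simplified stand-ins, not the actual transformers/PEFT code): the 4d mask is built eagerly in prepare_inputs_for_generation, so the compiled forward only consumes it and never traces its construction.

```python
import torch

def build_4d_causal_mask(attention_mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # (bs, seq_len) -> (bs, 1, seq_len, seq_len): 0.0 where attention is allowed, a large
    # negative value on future and padded positions (prefill-only simplification, no kv-cache offset)
    bs, seq_len = attention_mask_2d.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask_2d.device))
    allowed = causal[None, None, :, :] & attention_mask_2d[:, None, None, :].bool()
    mask = torch.full((bs, 1, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype,
                      device=attention_mask_2d.device)
    return mask.masked_fill(allowed, 0.0)

def prepare_inputs_for_generation(input_ids, attention_mask_2d, **model_kwargs):
    # mask creation happens here, in eager mode, outside any torch.compile'd region;
    # the compiled forward then receives a ready-made 4d mask
    model_kwargs["input_ids"] = input_ids
    model_kwargs["attention_mask"] = build_4d_causal_mask(attention_mask_2d)
    return model_kwargs
```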
There are still issues with prefix tuning and incorrect shapes, which may be solvable but require further work. Similarly, there is an issue with VBLoRA, because this line is not torch.compile-friendly: peft/src/peft/tuners/vblora/layer.py, line 191 (at e67052b).
The corresponding tests are skipped for now.
Finally, for these fixes to work, two more changes are needed on the transformers side:
- create_masks_for_generate. For prompt learning, we remove the cache_position argument; I'm not quite sure if there is not a better solution. Anyway, because of this it needs to be recomputed, but models like gemma recompute it in a way that is not torch.compile-friendly. They should use a compile-friendly method instead (see the sketch below). When I locally patch transformers to do so, the tests pass.
- cache_position is no longer being removed from the model_kwargs, thus the aforementioned problem does not occur.

For these reasons, this PR stays in draft status for now and #2580 is used to make the CI green for the time being.
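As an illustration of the first point (an assumption about the typical failure mode, not the actual gemma code): recomputing cache_position stays torch.compile-friendly as long as it avoids pulling values out of tensors (e.g. via .item()) inside the compiled region, for example:

```python
import torch

# Hypothetical sketch: `past_seen_tokens` is assumed to be a plain Python int
# (e.g. from the cache API), so the arange below does not trigger a graph break.
def recompute_cache_position(inputs_embeds: torch.Tensor, past_seen_tokens: int) -> torch.Tensor:
    seq_len = inputs_embeds.shape[1]
    return torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
```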