Conversation

@BenjaminBossan BenjaminBossan commented Jun 10, 2025

See also #2580

Resolves CI errors such as this one:

https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182

This PR resolves 2 issues:

1. attention mask being a dict

After the transformers change in huggingface/transformers#37866, it can happen that:

Models using different types of attention in different layers (e.g. Gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type).

As PEFT operates on the attention mask for prompt learning methods, we need to adjust the code for the possibility of attention_mask being a dict. Right now, I simply extract the single value if the dict has exactly one entry. For other sizes, I raise an error, as I don't know how to deal with them. For our tests this is enough, but we might need to find a better solution in the future.
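
For illustration, here is a minimal sketch of the kind of handling described above (the names and exact error message are hypothetical, not necessarily the code in this PR):

```python
def unwrap_attention_mask(attention_mask):
    # Hypothetical helper: if transformers returns one mask per attention type,
    # unwrap it when there is exactly one entry, otherwise fail loudly since we
    # don't know yet how to handle multiple attention types in PEFT.
    if isinstance(attention_mask, dict):
        if len(attention_mask) != 1:
            raise ValueError(
                f"Expected a single attention mask, got {len(attention_mask)} entries "
                "(one per attention type); this case is not handled yet."
            )
        attention_mask = next(iter(attention_mask.values()))
    return attention_mask
```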

2. torch.compile errors during generation

#2458 fixed an issue with 4d attention masks and added gemma3 to the test suite, which uses 4d attention masks. However, the solution was insufficient, as it involves replacing the 4d attention mask with a 2d mask and handing it off to the model to create the correct 4d attention mask. The problem is that mask creation triggers an error with torch.compile and thus needs to be performed outside of the compile context, i.e. during prepare_inputs_for_generation. This PR now uses the same logic as transformers to do exactly that.
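
To illustrate the idea only (this is not the PR's code, which reuses transformers' own mask-creation logic via create_masks_for_generate), expanding a 2D padding mask into a 4D additive causal mask is the kind of work that now happens in prepare_inputs_for_generation, i.e. outside the compiled region:

```python
import torch

def expand_to_4d_causal_mask(attention_mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # Illustrative sketch: turn a (bs, seq_len) padding mask into a
    # (bs, 1, seq_len, seq_len) additive causal mask, assuming the prefill case
    # where query length equals key/value length.
    bs, seq_len = attention_mask_2d.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = causal[None, None, :, :] & attention_mask_2d[:, None, None, :].bool()
    zero = torch.zeros((), dtype=dtype)
    neg_inf = torch.full((), torch.finfo(dtype).min, dtype=dtype)
    return torch.where(allowed, zero, neg_inf)
```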

There are still issues with prefix tuning and incorrect shapes, which may be solvable, but require further work. Similarly, there is an issue with VBLoRA because this line is not torch.compile friendly:

if self.training and vblora_logits_A[0, 0].isinf().any():
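
The problem with that line is the data-dependent Python branch. A minimal standalone reproduction of the same pattern (not the actual VBLoRA code) looks like this:

```python
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # The `if` forces .any() to be materialized as a Python bool at trace time,
    # which torch.compile(fullgraph=True) rejects as data-dependent control flow.
    if x.isinf().any():
        raise RuntimeError("Found infs in the logits.")
    return x * 2

compiled = torch.compile(forward, fullgraph=True)
try:
    compiled(torch.randn(4))
except Exception as exc:  # a torch._dynamo error about the data-dependent branch
    print(type(exc).__name__)
```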

The corresponding tests are skipped for now.

Finally, for these fixes to work, two more changes are needed on the transformers side:

  1. Await a new transformers release (>4.52) so that we can use create_masks_for_generate.
  2. For prompt learning, we remove the cache_position argument; I'm not sure whether there is a better solution. Because of this, cache_position needs to be recomputed, but models like Gemma recompute it in a way that is not torch.compile-friendly; they should use a compile-friendly method instead. When I locally patch transformers to do so, the tests pass. (Edit: cache_position is no longer removed from the model_kwargs, so the aforementioned problem no longer occurs.)

For these reasons, this PR stays in draft status for now and #2580 is used to make the CI green for the time being.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan BenjaminBossan marked this pull request as draft June 10, 2025 09:56
@BenjaminBossan BenjaminBossan changed the title from "FIX Account for attention mask being a dict" to "FIX Account for attention mask being a dict, fix generate issues with gemma" Jun 11, 2025
Avoid regression, even though I'm not quite sure if the old behavior is
technically correct.
@BenjaminBossan BenjaminBossan marked this pull request as ready for review June 27, 2025 09:20

@githubnemo githubnemo left a comment

Thanks for taking this on :)

LGTM, but it was a bit hard to understand. I added comments about the explanations that would have helped me understand the code faster.

model_kwargs["attention_mask"] = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)

# heuristic to determine if we're in prefill stage
in the comment: please explain what the prefill stage is for and what it relates to (kv cache)

Comment on lines 1969 to 1970
# if in prefill stage, for prompt learning methods that are not prefix tuning, new tokens
# (embeddings) are inserted
in the comment: please explain why prefix tuning is exempt here

currently it reads as if prefix tuning doesn't insert inputs at all, which is confusing because it does, but, IIUC, into the KV cache of all layers.

# if cache_position exists and if we're in the prefill stage
if (
(model_kwargs.get("cache_position") is not None)
and (model_kwargs["cache_position"][0] == 0)
it might make sense to move the "cache_position is not None" and "cache_position == 0" checks higher up and turn them into an is_prefill variable with a proper comment, to avoid repetition and add a bit of semantic context.
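
Something along these lines, purely as an illustration of the suggestion (names and surrounding code are hypothetical, model_kwargs comes from the enclosing prepare_inputs_for_generation logic):

```python
cache_position = model_kwargs.get("cache_position")
# "Prefill" is the first generation step: the whole prompt is processed at once
# and the KV cache is still being filled, so cache_position starts at 0. Later
# decoding steps process one new token at a time and reuse the cache.
is_prefill = (cache_position is not None) and (cache_position[0] == 0)
if is_prefill:
    ...  # insert the prompt-learning tokens / build the full attention mask here
```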

@BenjaminBossan

Note: I checked locally and some Gemma3 tests are failing when run on GPU due to compile errors:

pytest \
  tests/test_decoder_models.py::TestDecoderModels::test_generate[CPTConfig-config_kwargs3-hf-internal-testing/tiny-random-Gemma3ForCausalLM] \
  tests/test_decoder_models.py::TestDecoderModels::test_generate[VBLoRAConfig-config_kwargs13-hf-internal-testing/tiny-random-Gemma3ForCausalLM] \
  tests/test_decoder_models.py::TestDecoderModels::test_generate[PromptEncoderConfig-config_kwargs11-hf-internal-testing/tiny-random-Gemma3ForCausalLM] \
  tests/test_decoder_models.py::TestDecoderModels::test_generate[PromptTuningConfig-config_kwargs12-hf-internal-testing/tiny-random-Gemma3ForCausalLM] \
  tests/test_decoder_models.py::TestDecoderModels::test_generate_pos_args[CPTConfig-config_kwargs3-hf-internal-testing/tiny-random-Gemma3ForCausalLM] \
  tests/test_decoder_models.py::TestDecoderModels::test_generate_pos_args[PromptTuningConfig-config_kwargs12-hf-internal-testing/tiny-random-Gemma3ForCausalLM] \
  tests/test_decoder_models.py::TestDecoderModels::test_generate_pos_args[VBLoRAConfig-config_kwargs13-hf-internal-testing/tiny-random-Gemma3ForCausalLM] \
  tests/test_decoder_models.py::TestDecoderModels::test_generate_pos_args[PromptEncoderConfig-config_kwargs11-hf-internal-testing/tiny-random-Gemma3ForCausalLM]

The error is:

E               torch._dynamo.exc.Unsupported: Data dependent operator
E                 Explanation: Operator `aten._local_scalar_dense.default` has a non-Tensor output whose value is dependent on the data of Tensor inputs.
E                 Hint: Enable tracing of data-dependent output operators with `torch._dynamo.config.capture_scalar_outputs = True`
E               
E                 Developer debug context: aten._local_scalar_dense.default
E               
E               
E               from user code:
E                  File "/home/name/work/forks/transformers/src/transformers/utils/generic.py", line 943, in wrapper
E                   output = func(self, *args, **kwargs)
E                 File "/home/name/work/forks/transformers/src/transformers/models/gemma3/modeling_gemma3.py", line 681, in forward
E                   outputs: BaseModelOutputWithPast = self.model(
E                 File "/home/name/work/forks/transformers/src/transformers/utils/generic.py", line 943, in wrapper
E                   output = func(self, *args, **kwargs)
E                 File "/home/name/work/forks/transformers/src/transformers/models/gemma3/modeling_gemma3.py", line 525, in forward
E                   cache_position = torch.arange(

The reason is that Gemma uses code that is not torch.compile-friendly to generate cache_position. If that code were switched to a compile-friendly version, the tests should pass. Ping @Cyrilvallez
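
For illustration only (a generic reproduction of the pattern, not the actual Gemma3 code): deriving the start index of cache_position from a tensor value requires a tensor-to-Python-int conversion, which is exactly the aten._local_scalar_dense data-dependent operation reported above. Keeping the computation in tensor land avoids it:

```python
import torch

past_seen_tokens = torch.tensor(7)   # e.g. read from the KV cache
seq_len = 3

# compile-unfriendly: int(...) materializes the tensor as a Python scalar
# (aten._local_scalar_dense), which torch.compile flags as data dependent
start = int(past_seen_tokens)
cache_position = torch.arange(start, start + seq_len)

# compile-friendlier: offset a plain arange by the tensor, no scalar extraction
cache_position = torch.arange(seq_len) + past_seen_tokens
print(cache_position)  # tensor([7, 8, 9])
```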

I could confirm that these failures are not caused by this PR; rather, they were previously masked by the error that this PR fixes.


@githubnemo githubnemo left a comment

Perfect, it reads a lot clearer now. Thanks :)

@BenjaminBossan BenjaminBossan merged commit 171da8e into huggingface:main Jun 27, 2025
10 of 14 checks passed
@BenjaminBossan BenjaminBossan deleted the fix-attention-mask-is-dict branch June 27, 2025 11:40
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jul 3, 2025
- Bump versions
- Update a comment to point to the new PR
- Remove a test skip that is obsolete after huggingface#2579
@BenjaminBossan BenjaminBossan mentioned this pull request Jul 3, 2025
BenjaminBossan added a commit that referenced this pull request Jul 3, 2025
- Bump versions
- Update a comment to point to the new PR
- Remove a test skip that is obsolete after #2579
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
- Bump versions
- Update a comment to point to the new PR
- Remove a test skip that is obsolete after huggingface#2579