Fix qwen encoder hidden states mask #12655
Conversation
Improves attention mask handling for the QwenImage transformer by:
- Adding support for variable-length sequence masking
- Implementing dynamic attention mask generation from encoder_hidden_states_mask
- Ensuring RoPE embedding works correctly with padded sequences
- Adding comprehensive test coverage for masked input scenarios

Performance and flexibility benefits:
- Enables more efficient processing of sequences with padding
- Prevents padding tokens from contributing to attention computations
- Maintains model performance with minimal overhead
Improves the file naming convention for the Qwen image mask performance benchmark script. Enhances code organization by using a more descriptive and consistent filename that clearly indicates the script's purpose.
@cdutr it's great that you have also included the benchmarking script for full transparency. But we can remove that from this PR and instead have it as a GitHub gist. The benchmark numbers make sense to me. Some comments:
Also, I think a natural next step would be to see how well this performs when combined with FA varlen. WDYT? @naykun what do you think about the changes?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks @sayakpaul! I removed the benchmark script and moved all tests to this gist.

torch.compile test

Also tested the performance with torch.compile. Tested on an NVIDIA A100 80GB PCIe; also validated on an RTX 4050 6GB (laptop) with similar results (2.38x speedup). The mask implementation is fully compatible with torch.compile.

Image outputs

Tested end-to-end image generation: successfully generated images using QwenImagePipeline, and the pipeline runs without errors. Here is the output generated:
FA varlen

FA varlen is the natural next step, yes! I'm interested in working on it. Should I keep iterating in this PR, or should we merge it and create a new issue? The mask infrastructure this PR adds would translate well to varlen: instead of masking padding tokens, we'd pack only valid tokens using the same sequence length information.
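(As a sketch of that packing idea only, not code in this PR: a helper along these lines could derive the metadata that FlashAttention-style varlen kernels expect — packed token indices and cumulative sequence lengths — from the same mask. `mask_to_varlen` is a hypothetical name.)

```python
# Hypothetical sketch, not part of this PR: derive varlen metadata from the padding mask.
import torch
import torch.nn.functional as F

def mask_to_varlen(mask: torch.Tensor):
    """Convert a (batch, seq_len) bool padding mask (True = valid token) into
    packed token indices, cumulative sequence lengths, and the max sequence length."""
    seqlens = mask.sum(dim=1, dtype=torch.int32)                 # valid tokens per sample
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    indices = mask.flatten().nonzero(as_tuple=False).squeeze(1)  # flat positions of valid tokens
    return indices, cu_seqlens, int(seqlens.max())

# Example: two prompts, the second one padded to length 4.
mask = torch.tensor([[True, True, True, True],
                     [True, True, False, False]])
indices, cu_seqlens, max_len = mask_to_varlen(mask)
# indices -> tensor([0, 1, 2, 3, 4, 5]); cu_seqlens -> tensor([0, 4, 6]); max_len -> 4
```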
Thanks for the results! Looks quite nice.
I think it's fine to first merge this PR and then we work on it afterwards. We're adding easier support for Sage and FA2 in this PR: #12439, so after that's merged, it will be quite easy to work on that (thanks to the changes introduced there).

Could we also check if the outputs deviate with and without the masks, i.e., the outputs we get on main vs. this branch?
@dxqb would you maybe be interested in checking this PR out as well?
| f"must match encoder_hidden_states sequence length ({text_seq_len})" | ||
| ) | ||
|
|
||
| text_attention_mask = encoder_hidden_states_mask.bool() |
This works if the encoder_hidden_states_mask is already bool, or a float tensor with the same semantics.
Bool attention masks are enough for the usual use case of masking unused text tokens, but if only bool attention masks are supported, this should be clearly documented. Also, maybe change the type hint?
See https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html for how float attention masks are interpreted by torch: a float 0.0 is not masked, while a bool False is masked.
There are some use cases for float attention masks for text sequences, like putting an emphasis/bias on certain tokens. Not very common though, so if you decide to only support bool attention masks, that makes sense to me - but it requires documentation.
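For reference, a minimal sketch (not from this PR) of the SDPA semantics described above: a bool False is masked, while the equivalent float mask masks by adding -inf to the attention scores.

```python
# Minimal illustration of torch SDPA mask semantics (bool vs. float); not PR code.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 4, 64)  # (batch, heads, seq_len, head_dim)

# Bool mask: True = attend, False = masked out (e.g. the last token is padding).
bool_mask = torch.tensor([[True, True, True, False]])
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

# Float mask: values are added to the attention scores,
# so 0.0 means "not masked" and -inf means "masked".
float_mask = torch.tensor([[0.0, 0.0, 0.0, float("-inf")]])
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

assert torch.allclose(out_bool, out_float, atol=1e-6)
```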
@naykun WDYT about this PR?
Hey. I would like to work on the latest inputs and comments. I was off last week attending a conference, but I should get back to it tomorrow. Any additional feedback or request is appreciated, of course.
Updates the type hint for the encoder_hidden_states_mask parameter to use Optional[torch.BoolTensor]. Clarifies the expected mask type for text padding tokens in the attention processor, making the interface more precise and type-safe. Adds documentation to explain the purpose and expected format of the mask parameter.
Modifies attention mask generation for improved compatibility with scaled dot product attention (SDPA)
Hi, I am back and will reply to the points below.

@dxqb

Comment 1: encoder_hidden_states_mask Type Hints
Updated to use Optional[torch.BoolTensor].

Comment 2: 2D Attention Mask Optimization
Implemented. Changed from generating a full attention mask to a 2D one.

@sayakpaul

torch.compile Performance
I've tested with torch.compile.

Without compile:
With compile:
The mask implementation is compile-friendly and maintains performance. See full results in the gist.

Output Comparison (main vs PR branch)

I ran comprehensive tests comparing outputs between main and this PR branch:

1. Numerical Comparison:
2. Image Generation Comparison:
With mask:
Without mask:
Difference:
3. Main Branch Status: This is the exact bug this PR fixes. All test scripts and results are available in the gist: https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f
Thanks for your work here! Do you think there might be some overlap between #12702 and this PR?
@cdutr yes, it's somewhat duplicate work to what I have. I can remove that part from my PR, and then you can potentially point your PR to mine with these changes; that way we can merge both things. What do you think?
Hey @kashif, the comparisons on your PR are really nice! Yes, I can point my PR to yours. Let me know once you've removed the overlapping parts and I'll rebase on top of your branch.
@kashif @cdutr sorry this happened, but we are really impressed with the work here and will do the MVP rewards accordingly :)
Adding on top of @yiyixuxu: @cdutr, since you have already set up the benchmarking for this setup, maybe you could contribute that to @kashif's PR as well? Maybe we could document all of that inside https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage? Again, we're sorry that this happened. But regardless, we would like to honor your contributions and grant you the MVP rewards as Yiyi mentioned (after the PR is merged).
Thanks a lot for the MVP rewards, I really appreciate it, and I'm really enjoying working on this project. It makes total sense to move forward with merging the other PR since it's already in review. For sure, I can contribute the benchmarks and documentation there. I'll join the discussion there.
@cdutr you should be able to commit directly to my branch.




What does this PR do?
Fixes the QwenImage encoder to properly apply encoder_hidden_states_mask when passed to the model. Previously, the mask parameter was accepted but ignored, causing padding tokens to incorrectly influence attention computation.

Changes
- QwenDoubleStreamAttnProcessor2_0 now creates a 2D attention mask from the 1D encoder_hidden_states_mask, properly masking text padding tokens while keeping all image tokens unmasked
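For illustration only, a hedged sketch of that idea rather than the PR's exact implementation: `build_joint_attention_mask` and `img_seq_len` are hypothetical names, and the text-before-image ordering of the joint sequence is an assumption here.

```python
# Hedged sketch of the idea; not the PR's exact code.
import torch

def build_joint_attention_mask(
    encoder_hidden_states_mask: torch.Tensor,  # (batch, text_seq_len), False/0 = padding
    img_seq_len: int,
) -> torch.Tensor:
    batch_size = encoder_hidden_states_mask.shape[0]
    # Image tokens are never padded, so they are always attendable.
    image_mask = torch.ones(
        batch_size, img_seq_len,
        dtype=torch.bool, device=encoder_hidden_states_mask.device,
    )
    # Assume the joint sequence concatenates text tokens before image tokens.
    joint_mask = torch.cat([encoder_hidden_states_mask.bool(), image_mask], dim=1)
    # Shape (batch, 1, 1, text_seq_len + img_seq_len): SDPA broadcasts this over
    # heads and query positions, so padded text keys are ignored by every query.
    return joint_mask[:, None, None, :]
```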
Impact

This fix enables proper Classifier-Free Guidance (CFG) batching with variable-length text sequences, which is common when batching conditional and unconditional prompts together.
Benchmark Results
Overhead: +2.8% for mask processing without padding, +18.7% with actual padding (realistic CFG scenario)
The higher overhead with padding is expected and acceptable, as it represents the cost of properly handling variable-length sequences in batched inference. This is a necessary correctness fix rather than an optimization. Tests were run on an RTX 4070 12GB.
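(For context on how such overhead numbers can be measured, here is a minimal, hypothetical timing-harness sketch using CUDA events; it is not the benchmark script from the gist.)

```python
# Hypothetical minimal timing harness; not the actual benchmark script.
import torch

def time_forward_ms(model, inputs: dict, warmup: int = 3, iters: int = 10) -> float:
    """Return the mean forward-pass latency in milliseconds on CUDA."""
    with torch.no_grad():
        for _ in range(warmup):
            model(**inputs)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            model(**inputs)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Overhead comparison: time the transformer once with and once without an
# encoder_hidden_states_mask in `inputs`, then compare the two mean latencies.
```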
Fixes #12294
Before submitting
Who can review?
@yiyixuxu @sayakpaul - Would appreciate your review, especially regarding the benchmarking approach. I used a custom benchmark rather than BenchmarkMixin because:

Note: The benchmark file is named benchmarking_qwenimage_mask.py (with a "benchmarking" prefix) rather than benchmark_qwenimage_mask.py to prevent it from being picked up by run_all.py, since it doesn't use BenchmarkMixin and produces a different CSV schema. If you prefer, I can adapt it to use the standard format instead.

Happy to adjust the approach if you have suggestions!