
Conversation

@cdutr (Contributor) commented Nov 14, 2025

What does this PR do?

Fixes the QwenImage encoder to properly apply encoder_hidden_states_mask when passed to the model. Previously, the mask parameter was accepted but ignored, causing padding tokens to incorrectly influence attention computation.

Changes

  • Attention mask application: Modified QwenDoubleStreamAttnProcessor2_0 to create a 2D attention mask from the 1D encoder_hidden_states_mask, properly masking text padding tokens while keeping all image tokens unmasked (see the sketch after this list)
  • RoPE adjustment: Updated positional embedding computation to use the full padded sequence length when a mask is present, ensuring correct position indices
  • Tests: Added comprehensive tests validating that:
    • Padding tokens are properly isolated and don't affect outputs
    • Masked outputs differ significantly from unmasked outputs
  • Benchmarks: Included performance analysis showing acceptable overhead (<20% for inference, ~19% for training scenarios)
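To make this concrete, here is a minimal sketch of the masking approach from the first bullet (function name, joint-sequence ordering, and shapes are illustrative assumptions, not the exact processor code):

```python
import torch
import torch.nn.functional as F

def build_joint_attention_mask(encoder_hidden_states_mask, image_seq_len):
    # encoder_hidden_states_mask: bool [batch, text_seq_len];
    # True = real text token, False = padding.
    batch_size = encoder_hidden_states_mask.shape[0]
    image_mask = torch.ones(
        batch_size, image_seq_len,
        dtype=torch.bool, device=encoder_hidden_states_mask.device,
    )
    # Joint sequence assumed to be [text tokens, image tokens]; image tokens are never masked.
    joint_mask = torch.cat([encoder_hidden_states_mask.bool(), image_mask], dim=1)
    # [batch, 1, 1, text+image]: broadcasts over heads and query positions in SDPA.
    return joint_mask[:, None, None, :]

# Hypothetical usage inside an attention processor:
batch, heads, text_len, img_len, dim = 2, 4, 16, 32, 64
q = torch.randn(batch, heads, text_len + img_len, dim)
k, v = torch.randn_like(q), torch.randn_like(q)
mask_1d = torch.ones(batch, text_len, dtype=torch.bool)
mask_1d[1, 9:] = False  # second sample has 7 padding tokens
attn_mask = build_joint_attention_mask(mask_1d, img_len)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Because only key positions are masked, padded text tokens receive no attention weight from any query, while image tokens remain fully attended.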

Impact

This fix enables proper Classifier-Free Guidance (CFG) batching with variable-length text sequences, which is common when batching conditional and unconditional prompts together.
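As a hedged usage sketch of the CFG scenario (the checkpoint below is a placeholder assumption; the real pipeline builds this mask internally from its own text encoder), the attention_mask produced by padded batching is exactly the 1D mask the transformer needs:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; substitute the actual QwenImage text-encoder repo.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Batching a short and a long prompt (as in a CFG batch) pads the short one;
# the resulting padding tokens must not influence attention.
prompts = ["A cat", "A watercolor painting of a cat sitting on a windowsill"]
tokens = tokenizer(prompts, padding=True, return_tensors="pt")

encoder_hidden_states_mask = tokens.attention_mask.bool()  # [batch, padded_text_len]
print(encoder_hidden_states_mask.sum(dim=1))  # real-token count per prompt
# With this PR, the transformer actually applies this mask in attention
# instead of silently ignoring it.
```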

Benchmark Results

Scenario                      Latency (ms)    Peak Memory (MB)    Throughput (iter/s)
Baseline (no mask)            11.68 ± 0.23    301.5               84.70
Mask all-ones (no padding)    12.01 ± 0.26    301.5               82.34
Mask with padding (CFG)       13.86 ± 0.24    301.5               71.42

Overhead: +2.8% for mask processing without padding, +18.7% with actual padding (realistic CFG scenario)

The higher overhead with padding is expected and acceptable, as it represents the cost of properly handling variable-length sequences in batched inference. This is a necessary correctness fix rather than an optimization. Tests were run on an RTX 4070 (12 GB).

Fixes #12294


Before submitting

  • This PR fixes a bug in the code
  • This PR adds tests that verify the fix
  • This PR includes benchmarks demonstrating performance impact

Who can review?

@yiyixuxu @sayakpaul - Would appreciate your review, especially regarding the benchmarking approach. I used a custom benchmark rather than BenchmarkMixin because:

  1. This tests a specific bug fix (mask application) rather than optimization strategies
  2. The fix uses synthetic models to isolate the mask handling logic
  3. Standard benchmarks focus on pretrained model performance with different quantization/offloading strategies
  4. The metrics needed are different (latency distribution, throughput) vs the standard format (compile/plain time comparison); a minimal timing sketch follows this list
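For reference, a minimal sketch of the kind of latency/throughput measurement used (not the benchmark script itself; the model call is a placeholder):

```python
import time
import torch

@torch.no_grad()
def benchmark(fn, warmup=10, iters=100):
    # Warm up to exclude one-time allocation and kernel-launch effects.
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    latencies_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        latencies_ms.append((time.perf_counter() - start) * 1e3)
    lat = torch.tensor(latencies_ms)
    return lat.mean().item(), lat.std().item(), 1e3 / lat.mean().item()  # ms, ms, iter/s

# Hypothetical call: mean_ms, std_ms, throughput = benchmark(lambda: model(x, mask))
```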

Note: The benchmark file is named benchmarking_qwenimage_mask.py (with "benchmarking" prefix) rather than benchmark_qwenimage_mask.py to prevent it from being picked up by run_all.py, since it doesn't use BenchmarkMixin and produces a different CSV schema. If you prefer, I can adapt it to use the standard format instead.

Happy to adjust the approach if you have suggestions!

Commit: Improves attention mask handling for QwenImage transformer by:
- Adding support for variable-length sequence masking
- Implementing dynamic attention mask generation from encoder_hidden_states_mask
- Ensuring RoPE embedding works correctly with padded sequences
- Adding comprehensive test coverage for masked input scenarios

Performance and flexibility benefits:
- Enables more efficient processing of sequences with padding
- Prevents padding tokens from contributing to attention computations
- Maintains model performance with minimal overhead

Commit: Improves file naming convention for the Qwen image mask performance benchmark script

Enhances code organization by using a more descriptive and consistent filename that clearly indicates the script's purpose
@sayakpaul (Member)

@cdutr it's great that you have also included the benchmarking script for fullest transparency. But we can remove that from this PR and instead have that as a GitHub gist.

The benchmark numbers make sense to me. Some comments:

  • Could we also check the performance with torch.compile?
  • Could we also see some image outputs with and without the changes introduced in this PR?

Also, I think a natural next step would be to see how well this performs when combined with FA varlen. WDYT?

@naykun what do you think about the changes?

@sayakpaul requested a review from yiyixuxu on November 20, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@cdutr (Contributor, Author) commented Nov 22, 2025

Thanks @sayakpaul! I removed the benchmark script and moved all tests to this gist.

torch.compile test

I also tested the performance with torch.compile; the results were similar, details below.

Tested on NVIDIA A100 80GB PCIe:

Without compile: 4.70ms per iteration
With compile:    1.93ms per iteration
Speedup:         2.43x
Compilation overhead: 7.25s (one-time cost)

Also validated on RTX 4050 6GB (laptop) with similar results (2.38x speedup). The mask implementation is fully compatible with torch.compile.
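A rough sketch of how the compile comparison can be reproduced (CUDA-event timing; the transformer and inputs are placeholders):

```python
import torch

@torch.no_grad()
def time_forward(model, inputs, iters=50, warmup=5):
    # Warm up (for the compiled model this also pays the one-time compilation cost).
    for _ in range(warmup):
        model(**inputs)
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        model(**inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration

# Hypothetical usage:
# eager_ms = time_forward(transformer, inputs)
# compiled_ms = time_forward(torch.compile(transformer), inputs)
# print(f"speedup: {eager_ms / compiled_ms:.2f}x")
```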

Image outputs

End-to-end image generation tested: images were successfully generated using QwenImagePipeline and the pipeline runs without errors. Here is the generated output:

[image: test_output]

FA Varlen

FA varlen is the natural next step, yes! I'm interested in working on it. Should I keep iterating in this PR, or should we merge it and create a new issue?

The mask infrastructure this PR adds would translate well to varlen: instead of masking padding tokens, we'd pack only the valid tokens using the same sequence-length information (see the sketch below).
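To make that concrete, a small pure-PyTorch sketch (names assumed) of how the same boolean mask would yield the packed tokens and cumulative sequence lengths a varlen attention kernel expects:

```python
import torch
import torch.nn.functional as F

def pack_for_varlen(hidden_states, mask):
    # hidden_states: [batch, seq_len, dim]; mask: bool [batch, seq_len]
    seqlens = mask.sum(dim=1).to(torch.int32)                 # valid tokens per sample
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0))  # [0, len_0, len_0+len_1, ...]
    packed = hidden_states[mask]                              # [total_valid_tokens, dim]
    return packed, cu_seqlens.to(torch.int32), int(seqlens.max())

hidden = torch.randn(2, 14, 8)
mask = torch.ones(2, 14, dtype=torch.bool)
mask[0, 7:] = False  # "A cat": 7 real tokens, 7 padding tokens
packed, cu_seqlens, max_len = pack_for_varlen(hidden, mask)
print(packed.shape, cu_seqlens, max_len)  # torch.Size([21, 8]) tensor([0, 7, 21], ...) 14
```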

@sayakpaul (Member)

Thanks for the results! Looks quite nice.

FA varlen is the natural next step, yes! I'm interested in working on it. Should I keep iterating in this PR, or should we merge it and create a new issue?

I think it's fine to first merge this PR and then we work on it afterwards. We're adding easier support for Sage and FA2 in this PR: #12439, so after that's merged, it will be quite easy to work on that (thanks to the kernels lib).

Could we also check if the outputs deviate with and without the masks, i.e., the outputs we get on main and this PR branch?

@sayakpaul (Member) commented Nov 22, 2025

@dxqb would you maybe be interested in checking this PR out as well?

f"must match encoder_hidden_states sequence length ({text_seq_len})"
)

text_attention_mask = encoder_hidden_states_mask.bool()
@dxqb commented on the diff above, Nov 22, 2025:

This works if the encoder_hidden_states_mask is already bool, or a float tensor with the same semantics.
Bool attention masks are enough for the usual use case of masking unused text tokens, but if only bool attention masks are supported, this should be clearly documented. Also, maybe change the type hint?

See https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html for how float attention masks are interpreted by torch: a float 0.0 is not masked, while a bool False is masked.

There are some use cases for float attention masks on text sequences, like putting an emphasis/bias on certain tokens. Not very common though, so if you decide to only support bool attention masks, that makes sense to me - but it requires documentation.
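To illustrate the semantics being pointed out, a minimal check (not part of the PR) comparing the bool and float mask conventions in SDPA:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)
k, v = torch.randn(1, 1, 6, 8), torch.randn(1, 1, 6, 8)

# Bool mask: True = attend, False = masked out.
bool_mask = torch.tensor([True, True, True, True, False, False]).view(1, 1, 1, 6)

# Equivalent float mask: values are *added* to the attention logits,
# so 0.0 = attend and -inf = masked out.
float_mask = torch.zeros(1, 1, 1, 6)
float_mask[..., 4:] = float("-inf")

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
print(torch.allclose(out_bool, out_float))  # True: the two encodings agree

# The pitfall: a bool False is masked, but a float 0.0 is NOT masked, so passing
# an all-zeros float mask silently attends to everything, including padding.
```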

@sayakpaul (Member)

@naykun WDYT about this PR?

@cdutr (Contributor, Author) commented Dec 8, 2025

Hey, I would like to work on the last inputs and comments. I was off last week attending a conference but should get back to it tomorrow. Any additional feedback or requests are appreciated, of course.

cdutr added 3 commits December 9, 2025 23:15
Updates type hint for encoder_hidden_states_mask parameter to use Optional[torch.BoolTensor]

Clarifies the expected mask type for text padding tokens in the attention processor, making the interface more precise and type-safe. Adds documentation to explain the purpose and expected format of the mask parameter.

Modifies attention mask generation for improved compatibility with scaled dot product attention (SDPA)
@cdutr (Contributor, Author) commented Dec 10, 2025

Hi, I'm back; I'll reply to the points below:

@dxqb

Comment 1: encoder_hidden_states_mask Type Hints

Updated to use torch.BoolTensor type hint and added documentation clarifying that only boolean masks are supported. Float masks have different semantics in PyTorch SDPA (0.0 = not masked, False = masked), and supporting both would require additional conversion logic. Bool masks cover the standard use case of excluding padding tokens.

Comment 2: 2D Attention Mask Optimization

Implemented. Changed from generating full [batch, seq_len, seq_len] masks to broadcasting-friendly [batch, 1, 1, seq_len] shape. This significantly reduces memory usage (e.g., from 73,728 to 384 elements for batch=2, seq_len=192) while maintaining correctness through broadcasting in SDPA.
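A quick sketch of the broadcasting equivalence (shapes only; illustrative, not the processor code):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, dim = 2, 4, 192, 64
q = torch.randn(batch, heads, seq_len, dim)
k, v = torch.randn_like(q), torch.randn_like(q)

key_mask = torch.ones(batch, seq_len, dtype=torch.bool)
key_mask[1, 150:] = False  # padding in the second sample

full_mask = key_mask[:, None, :].expand(batch, seq_len, seq_len)  # [2, 192, 192] -> 73,728 elements
slim_mask = key_mask[:, None, None, :]                            # [2, 1, 1, 192] ->    384 elements

out_full = F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask.unsqueeze(1))
out_slim = F.scaled_dot_product_attention(q, k, v, attn_mask=slim_mask)
print(full_mask.numel(), slim_mask.numel())           # 73728 384
print(torch.allclose(out_full, out_slim, atol=1e-6))  # same attention output
```

SDPA broadcasts the slim mask over heads and query positions, so the memory saving comes essentially for free.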

@sayakpaul

torch.compile Performance

I've tested with torch.compile() and the results show no performance regression:

Without compile:

  • With mask: 0.0439s per iteration
  • Without mask: 0.0438s per iteration

With compile:

  • With mask: 0.0437s per iteration
  • Without mask: 0.0436s per iteration

The mask implementation is compile-friendly and maintains performance. See full results in the gist.

Output Comparison (main vs PR branch)

I ran comprehensive tests comparing outputs between main and this PR branch:

1. Numerical Comparison:

  • PR branch: Mean difference = 1.08e-04 (mask functional)
  • Main branch: Mean difference = 0.00e+00 (mask ignored)

2. Image Generation Comparison:
Using batched prompts with different lengths (natural padding scenario):

  • Prompt: "A cat" (7 tokens) batched with longer prompt (14 tokens) creates 7 padding tokens
  • With mask (PR): Padding properly excluded from attention
  • Without mask (simulating bug): Padding treated as real tokens
  • Result: 87.17% of pixels differ (mean absolute difference: 9.35/255)

With mask: [image: with_mask]

Without mask: [image: without_mask]

Difference: [image: difference_map]

3. Main Branch Status:
Main branch crashes when using masks with batched variable-length prompts:

RuntimeError: The size of tensor a (14) must match the size of tensor b (7)

This is the exact bug this PR fixes.

All test scripts and results available in the gist: https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f

@sayakpaul (Member)

Thanks for your work here!

Do you think there might be some overlap between #12702 and this PR?

@kashif (Contributor) commented Dec 11, 2025

@cdutr yes, it's somewhat duplicate work to what I have. I can remove the overlapping parts from my PR, and then you can point your PR to mine with these changes; that way we can merge both things. What do you think?

@cdutr (Contributor, Author) commented Dec 11, 2025

Hey @kashif, the comparisons on your PR are really nice! Yes, I can point my PR to yours. Let me know once you've removed the overlapping parts and I'll rebase on top of your branch.

@yiyixuxu (Collaborator) commented Dec 11, 2025

@kashif
Is it OK if we just go with your PR and make @cdutr a co-author in your PR instead?
The PR is already under review by the Qwen team; it is a big PR and will take some work for them to review and test. Let's not make it harder than necessary.

@cdutr sorry this happened, but we are really impressed with the work here and will grant the MVP rewards accordingly :)

@sayakpaul (Member)

Adding on top of @yiyixuxu,

@cdutr since you have already set up the benchmarking for this setup, maybe you could contribute that to @kashif's PR as well? Maybe we could document all of that inside https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage?

Again, we're sorry that this happened. But regardless, we would like to honor your contributions and grant you the MVP rewards as Yiyi mentioned (after the PR is merged).

@cdutr (Contributor, Author) commented Dec 11, 2025

Thanks a lot for the MVP rewards, I really appreciate it and I'm really enjoying working on this project.

It makes total sense to move forward with merging the other PR since it's already in review. For sure, I can contribute the benchmarks and documentation there. I'll join the discussion there.

@kashif (Contributor) commented Dec 11, 2025

@cdutr you should be able to commit directly to my branch


Successfully merging this pull request may close these issues.

[Qwen-image] encoder_hidden_states_mask is not used
