43 changes: 41 additions & 2 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -262,6 +262,11 @@ class QwenDoubleStreamAttnProcessor2_0:
"""
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
implements joint attention computation where text and image streams are processed together.

Args:
encoder_hidden_states_mask (`torch.BoolTensor`, *optional*):
Boolean mask for text padding tokens. Shape: `[batch_size, text_seq_len]`. `True` indicates tokens that
should be attended to, `False` masks out padding tokens. Only boolean masks are supported.
"""

_attention_backend = None
@@ -278,7 +283,7 @@ def __call__(
attn: Attention,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
encoder_hidden_states_mask: Optional[torch.BoolTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
@@ -330,6 +335,32 @@ def __call__(
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)

# Build a joint attention mask from encoder_hidden_states_mask if no explicit attention_mask was provided.
if encoder_hidden_states_mask is not None and attention_mask is None:
batch_size = hidden_states.shape[0]
image_seq_len = hidden_states.shape[1]
text_seq_len = encoder_hidden_states.shape[1]

if encoder_hidden_states_mask.shape[0] != batch_size:
raise ValueError(
f"encoder_hidden_states_mask batch size ({encoder_hidden_states_mask.shape[0]}) "
f"must match hidden_states batch size ({batch_size})"
)
if encoder_hidden_states_mask.shape[1] != text_seq_len:
raise ValueError(
f"encoder_hidden_states_mask sequence length ({encoder_hidden_states_mask.shape[1]}) "
f"must match encoder_hidden_states sequence length ({text_seq_len})"
)

text_attention_mask = encoder_hidden_states_mask.bool()

@dxqb commented on Nov 22, 2025:

This works if the encoder_hidden_states_mask is already bool, or a float tensor with the same semantics.
Bool attention masks are enough for the usual use case of masking unused text tokens, but if only bool attention masks are supported, that should be clearly documented. Also, maybe change the type hint?

See https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html for how float attention masks are interpreted by torch: a float 0.0 is not masked, while a bool False is masked.

There are some use cases for float attention masks on text sequences, such as putting an emphasis/bias on certain tokens. That is not very common, though, so supporting only bool attention masks makes sense to me - but it requires documentation.
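
For reference, here is a minimal sketch (not part of the diff; it only assumes public PyTorch APIs) of how torch.nn.functional.scaled_dot_product_attention treats the two mask types:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)  # (batch, heads, seq_len, head_dim)

# Bool mask: True = attend, False = masked out.
bool_mask = torch.tensor([True, True, True, False]).view(1, 1, 1, 4)
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

# Float mask: values are added to the attention scores before softmax,
# so 0.0 leaves a position unmasked and float("-inf") masks it out.
float_mask = torch.tensor([0.0, 0.0, 0.0, float("-inf")]).view(1, 1, 1, 4)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

torch.testing.assert_close(out_bool, out_float)  # same masking behavior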

image_attention_mask = torch.ones(
(batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
)

joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1)
# Broadcastable mask shape for SDPA: [batch_size, 1, 1, text_seq_len + image_seq_len]
attention_mask = joint_attention_mask_1d[:, None, None, :]

# Compute joint attention
joint_hidden_states = dispatch_attention_fn(
joint_query,
@@ -630,7 +661,15 @@ def forward(
else self.time_text_embed(timestep, guidance, hidden_states)
)

image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
# Use padded sequence length for RoPE when mask is present.
# The attention mask will handle excluding padding tokens.
if encoder_hidden_states_mask is not None:
txt_seq_lens_for_rope = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
else:
txt_seq_lens_for_rope = (
txt_seq_lens if txt_seq_lens is not None else [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens_for_rope, device=hidden_states.device)

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
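
For context (illustrative, not part of the diff), a short sketch of how a caller could construct the boolean padding mask that this code path and the tests below expect, starting from per-sample counts of real text tokens:

import torch

text_seq_len = 7
txt_seq_lens = [5, 3]  # number of real (non-padding) text tokens per sample
positions = torch.arange(text_seq_len)  # [text_seq_len]
encoder_hidden_states_mask = positions[None, :] < torch.tensor(txt_seq_lens)[:, None]
# tensor([[ True,  True,  True,  True,  True, False, False],
#         [ True,  True,  True, False, False, False, False]])

A 0/1 integer mask (as used in the tests below) works as well, since the processor calls .bool() on it before building the joint mask.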
118 changes: 118 additions & 0 deletions tests/models/transformers/test_models_transformer_qwenimage.py
@@ -91,6 +91,124 @@ def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_attention_mask_with_padding(self):
"""Test that encoder_hidden_states_mask properly handles padded sequences."""
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device).eval()

batch_size = 2
height = width = 4
num_latent_channels = embedding_dim = 16
text_seq_len = 7
vae_scale_factor = 4

# Create inputs with padding
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)

# First sample: 5 real tokens, 2 padding
# Second sample: 3 real tokens, 4 padding
encoder_hidden_states_mask = torch.tensor(
[[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
).to(torch_device)

# Zero out padding in embeddings
encoder_hidden_states = encoder_hidden_states * encoder_hidden_states_mask.unsqueeze(-1).float()

timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()

inputs_with_mask = {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
"txt_seq_lens": txt_seq_lens,
}

# Run with proper mask
with torch.no_grad():
output_with_mask = model(**inputs_with_mask).sample

# Run with all-ones mask (treating padding as real tokens)
inputs_without_mask = {
"hidden_states": hidden_states.clone(),
"encoder_hidden_states": encoder_hidden_states.clone(),
"encoder_hidden_states_mask": torch.ones_like(encoder_hidden_states_mask),
"timestep": timestep,
"img_shapes": img_shapes,
"txt_seq_lens": [text_seq_len] * batch_size,
}

with torch.no_grad():
output_without_mask = model(**inputs_without_mask).sample

# Outputs should differ when mask is applied correctly
diff = (output_with_mask - output_without_mask).abs().mean().item()
assert diff > 1e-5, f"Mask appears to be ignored (diff={diff})"

def test_attention_mask_padding_isolation(self):
"""Test that changing padding content doesn't affect output when mask is used."""
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device).eval()

batch_size = 2
height = width = 4
num_latent_channels = embedding_dim = 16
text_seq_len = 7
vae_scale_factor = 4

# Create inputs
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
encoder_hidden_states_mask = torch.tensor(
[[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
).to(torch_device)

timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()

inputs1 = {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
"txt_seq_lens": txt_seq_lens,
}

with torch.no_grad():
output1 = model(**inputs1).sample

# Modify padding content with large noise
encoder_hidden_states2 = encoder_hidden_states.clone()
mask = encoder_hidden_states_mask.unsqueeze(-1).float()
noise = torch.randn_like(encoder_hidden_states2) * 10.0
encoder_hidden_states2 = encoder_hidden_states2 + noise * (1 - mask)

inputs2 = {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states2,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
"txt_seq_lens": txt_seq_lens,
}

with torch.no_grad():
output2 = model(**inputs2).sample

# Outputs should be nearly identical (padding is masked out)
diff = (output1 - output2).abs().mean().item()
assert diff < 1e-4, f"Padding content affected output (diff={diff})"


class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel