diff --git a/invokeai/backend/z_image/extensions/regional_prompting_extension.py b/invokeai/backend/z_image/extensions/regional_prompting_extension.py
index 3bb11d5ead4..26f91749f70 100644
--- a/invokeai/backend/z_image/extensions/regional_prompting_extension.py
+++ b/invokeai/backend/z_image/extensions/regional_prompting_extension.py
@@ -66,12 +66,16 @@ def _prepare_regional_attn_mask(
 ) -> torch.Tensor | None:
     """Prepare a regional attention mask for Z-Image.
 
-    The mask controls which tokens can attend to each other:
-    - Image tokens within a region attend only to each other
+    This uses an 'unrestricted' image self-attention approach (similar to FLUX):
+    - Image tokens can attend to ALL other image tokens (unrestricted self-attention)
     - Image tokens attend only to their corresponding regional text
     - Text tokens attend only to their corresponding regional image
     - Text tokens attend to themselves
 
+    The unrestricted image self-attention allows the model to maintain global
+    coherence across regions, preventing the generation of separate/disconnected
+    images for each region.
+
     Z-Image sequence order: [img_tokens, txt_tokens]
 
     Args:
@@ -129,12 +133,6 @@ def _prepare_regional_attn_mask(
             # 3. txt attends to corresponding regional img
             # Reshape mask to (1, img_seq_len) for broadcasting
             regional_attention_mask[txt_start:txt_end, :img_seq_len] = mask_flat.view(1, img_seq_len)
-
-            # 4. img self-attention within region
-            # mask @ mask.T creates pairwise attention within the masked region
-            regional_attention_mask[:img_seq_len, :img_seq_len] += mask_flat.view(img_seq_len, 1) @ mask_flat.view(
-                1, img_seq_len
-            )
         else:
             # Global prompt: allow attention to/from background regions only
             if background_region_mask is not None:
@@ -152,10 +150,10 @@
                 regional_attention_mask[:img_seq_len, txt_start:txt_end] = 1.0
                 regional_attention_mask[txt_start:txt_end, :img_seq_len] = 1.0
 
-    # Allow background regions to attend to themselves
-    if background_region_mask is not None:
-        bg_mask = background_region_mask.view(img_seq_len, 1)
-        regional_attention_mask[:img_seq_len, :img_seq_len] += bg_mask @ bg_mask.T
+    # 4. Allow unrestricted image self-attention
+    # This is the key difference from the restricted approach - all image tokens
+    # can attend to each other, which helps maintain global coherence across regions
+    regional_attention_mask[:img_seq_len, :img_seq_len] = 1.0
 
     # Convert to boolean mask
     regional_attention_mask = regional_attention_mask > 0.5
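
For reviewers, here is a minimal, self-contained sketch of the masking scheme this diff moves to. The helper name `build_regional_attn_mask` and its signature are illustrative assumptions rather than the extension's actual API; only the four attention rules from the updated docstring (and the unrestricted image self-attention at the end) are taken from the change itself.

```python
import torch


def build_regional_attn_mask(
    region_masks: list[torch.Tensor],  # each (img_seq_len,), 1.0 inside the region
    txt_seq_lens: list[int],           # text token count per regional prompt
    img_seq_len: int,
) -> torch.Tensor:
    """Sketch of a regional mask with unrestricted image self-attention.

    Sequence order matches the docstring: [img_tokens, txt_tokens].
    """
    seq_len = img_seq_len + sum(txt_seq_lens)
    mask = torch.zeros(seq_len, seq_len)

    txt_start = img_seq_len
    for region, txt_len in zip(region_masks, txt_seq_lens):
        txt_end = txt_start + txt_len
        # Text tokens attend to themselves.
        mask[txt_start:txt_end, txt_start:txt_end] = 1.0
        # Image tokens attend only to their corresponding regional text
        # (broadcast the region column across this prompt's text tokens).
        mask[:img_seq_len, txt_start:txt_end] = region.view(img_seq_len, 1)
        # Text tokens attend only to their corresponding regional image.
        mask[txt_start:txt_end, :img_seq_len] = region.view(1, img_seq_len)
        txt_start = txt_end

    # Unrestricted image self-attention: every image token may attend to every
    # other image token regardless of region -- the key change in this diff,
    # replacing the per-region mask @ mask.T blocks removed above.
    mask[:img_seq_len, :img_seq_len] = 1.0

    return mask > 0.5


# Example: a 4-token image split into two halves, one 2-token prompt per half.
left = torch.tensor([1.0, 1.0, 0.0, 0.0])
right = torch.tensor([0.0, 0.0, 1.0, 1.0])
attn_mask = build_regional_attn_mask([left, right], [2, 2], img_seq_len=4)
```

With the removed per-region `mask @ mask.T` terms, an image token in the left half could never see the right half; after this change the top-left `img x img` block is all ones, while the `img x txt` blocks still route each prompt only to its own region.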