Commit 4088e8a
Add Support for Z-Image Series (#12703)
* Add Support for Z-Image.
* Reformatting with make style, black & isort.
* Remove init, modify import utils, merge forward in transformer block, remove once func in pipeline.
* Modified main model forward; freqs_cis left.
* Refactored to add a batch (B) dim.
* Fixed stack issue.
* Fixed modulation bug.
* Fixed modulation bug.
* Fix bug.
* Remove value_from_time_aware_config.
* Styling.
* Fix neg embed and divide bug; reuse pad zero tensor; turn cat -> repeat; add hint for attn processor.
* Replace padding with pad_sequence; add gradient checkpointing.
* Fix flash_attn3 in dispatch attn backend via _flash_attn_forward, replacing its original implementation; add docstring in pipeline for that.
* Fix docstring and make style.
* Revert "Fix flash_attn3 in dispatch attn backend via _flash_attn_forward, replacing its original implementation; add docstring in pipeline for that." This reverts commit fbf26b7.
* Update z-image docstring.
* Revert attention dispatcher.
* Update z-image docstring.
* Styling.
* Recover attention_dispatch.py with its original implementation; a later commit will handle fa3 compatibility.
* Fix previous bug, and support passing pre-encoded prompt_embeds as a list of torch Tensors.
* Remove einops dependency.
* Remove redundant imports & make fix-copies.
* Fix import.

---------

Co-authored-by: liudongyang <liudongyang0114@gmail.com>
1 parent d33d9f6 commit 4088e8a
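
To make the new pipeline concrete, here is a minimal usage sketch. It assumes a Z-Image checkpoint published in the usual diffusers layout; the repo id, prompt, and any call arguments beyond prompt are placeholders, not taken from this diff. (Per the commit message, pre-encoded prompt_embeds can also be passed as a list of torch tensors instead of a prompt.)

# Minimal sketch, assuming a Z-Image checkpoint exists on the Hub;
# "<z-image-checkpoint>" is a placeholder, not a real repo id.
import torch
from diffusers import ZImagePipeline

pipe = ZImagePipeline.from_pretrained("<z-image-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(prompt="A photo of a red fox in the snow").images[0]
image.save("z_image_sample.png")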

File tree

11 files changed: +1382 −0 lines

src/diffusers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -271,6 +271,7 @@
             "WanAnimateTransformer3DModel",
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
+            "ZImageTransformer2DModel",
             "attention_backend",
         ]
     )
@@ -647,6 +648,7 @@
             "WuerstchenCombinedPipeline",
             "WuerstchenDecoderPipeline",
             "WuerstchenPriorPipeline",
+            "ZImagePipeline",
         ]
     )
 
@@ -1329,6 +1331,7 @@
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,
+        ZImagePipeline,
     )
 
     try:

src/diffusers/hooks/_helpers.py

Lines changed: 20 additions & 0 deletions
@@ -111,6 +111,7 @@ def _register_attention_processors_metadata():
     from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
     from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
     from ..models.transformers.transformer_wan import WanAttnProcessor2_0
+    from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
 
     # AttnProcessor2_0
     AttentionProcessorRegistry.register(
@@ -158,6 +159,14 @@ def _register_attention_processors_metadata():
         ),
     )
 
+    # ZSingleStreamAttnProcessor
+    AttentionProcessorRegistry.register(
+        model_class=ZSingleStreamAttnProcessor,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
+        ),
+    )
+
 
 def _register_transformer_blocks_metadata():
     from ..models.attention import BasicTransformerBlock
@@ -179,6 +188,7 @@ def _register_transformer_blocks_metadata():
     from ..models.transformers.transformer_mochi import MochiTransformerBlock
     from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
     from ..models.transformers.transformer_wan import WanTransformerBlock
+    from ..models.transformers.transformer_z_image import ZImageTransformerBlock
 
     # BasicTransformerBlock
     TransformerBlockRegistry.register(
@@ -312,6 +322,15 @@ def _register_transformer_blocks_metadata():
         ),
     )
 
+    # ZImage
+    TransformerBlockRegistry.register(
+        model_class=ZImageTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=None,
+        ),
+    )
+
 
 # fmt: off
 def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -338,4 +357,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 _skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
 # fmt: on
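
The registrations above follow one pattern: each block or processor class is mapped to metadata telling the hook system where hidden_states sit in the forward return value (index 0 for ZImageTransformerBlock, with no separate encoder_hidden_states output), or which function to call when skipping the attention processor. A minimal sketch of the same pattern for a hypothetical block follows; the class name and the diffusers.hooks._helpers import path are assumptions for illustration, not part of this commit.

# Sketch only: registering hook metadata for a placeholder block, mirroring the
# ZImageTransformerBlock registration added in this file.
import torch.nn as nn
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry  # assumed import path

class MyBlock(nn.Module):  # placeholder block, not part of this commit
    def forward(self, hidden_states, temb=None):
        return hidden_states  # single tensor return, hence index 0 below

TransformerBlockRegistry.register(
    model_class=MyBlock,
    metadata=TransformerBlockMetadata(
        return_hidden_states_index=0,             # hidden_states is the (only) return value
        return_encoder_hidden_states_index=None,  # block does not return encoder_hidden_states
    ),
)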

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -110,6 +110,7 @@
     _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
     _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
     _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
+    _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
     _import_structure["unets.unet_2d"] = ["UNet2DModel"]
     _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -218,6 +219,7 @@
             WanAnimateTransformer3DModel,
             WanTransformer3DModel,
             WanVACETransformer3DModel,
+            ZImageTransformer2DModel,
         )
         from .unets import (
             I2VGenXLUNet,

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -44,3 +44,4 @@
     from .transformer_wan import WanTransformer3DModel
     from .transformer_wan_animate import WanAnimateTransformer3DModel
     from .transformer_wan_vace import WanVACETransformer3DModel
+    from .transformer_z_image import ZImageTransformer2DModel
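
With the export wiring above in place, the new transformer class should be importable from the top-level package. A minimal sketch, assuming a checkpoint with the usual transformer subfolder layout (none is named in this commit):

import torch
from diffusers import ZImageTransformer2DModel  # top-level export added in src/diffusers/__init__.py

# "<z-image-repo>" and the subfolder layout are assumptions, not taken from this diff.
transformer = ZImageTransformer2DModel.from_pretrained(
    "<z-image-repo>", subfolder="transformer", torch_dtype=torch.bfloat16
)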
