Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
42658fa
Add Support for Z-Image.
JerryWu-code Nov 23, 2025
3e74bb2
Reformatting with make style, black & isort.
JerryWu-code Nov 23, 2025
a4b89a0
Remove init, Modify import utils, Merge forward in transformers block…
JerryWu-code Nov 24, 2025
7df350d
modified main model forward, freqs_cis left
ChrisLiu6 Nov 24, 2025
1dd587b
Merge remote-tracking branch 'JerryWu-code/z-image-dev' into fork/Jer…
ChrisLiu6 Nov 24, 2025
aae03cf
refactored to add B dim
ChrisLiu6 Nov 24, 2025
21d8130
fixed stack issue
ChrisLiu6 Nov 24, 2025
e3dfa9e
fixed modulation bug
ChrisLiu6 Nov 24, 2025
a7fa731
fixed modulation bug
ChrisLiu6 Nov 24, 2025
1e0cefe
fix bug
ChrisLiu6 Nov 24, 2025
7adaae8
remove value_from_time_aware_config
ChrisLiu6 Nov 24, 2025
5b4c907
styling
ChrisLiu6 Nov 24, 2025
2bb39f4
Fix neg embed and devide / bug; Reuse pad zero tensor; Turn cat -> re…
JerryWu-code Nov 24, 2025
71e8049
Replace padding with pad_sequence; Add gradient checkpointing.
JerryWu-code Nov 24, 2025
fbf26b7
Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, repl…
JerryWu-code Nov 24, 2025
6c0c059
Fix Docstring and Make Style.
JerryWu-code Nov 24, 2025
28685dd
Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forwa…
ChrisLiu6 Nov 25, 2025
8e391b7
update z-image docstring
ChrisLiu6 Nov 25, 2025
3b22e84
Revert attention dispatcher
ChrisLiu6 Nov 25, 2025
3d1a7aa
update z-image docstring
ChrisLiu6 Nov 25, 2025
336c5ce
styling
ChrisLiu6 Nov 25, 2025
38a89ed
Recover attention_dispatch.py with its origin impl, later would speci…
JerryWu-code Nov 25, 2025
69d61e5
Fix prev bug, and support for prompt_embeds pass in args after prompt…
JerryWu-code Nov 25, 2025
549ad57
Merge branch 'z-image-dev-ql' into z-image-dev
JerryWu-code Nov 25, 2025
1dd8f3c
Remove einop dependency.
JerryWu-code Nov 25, 2025
2f2d8c3
Merge branch 'z-image-dev' into z-image
JerryWu-code Nov 25, 2025
a74a0c4
Merge remote-tracking branch 'origin/main' into z-image
JerryWu-code Nov 25, 2025
e49a1f9
remove redundant imports & make fix-copies
ChrisLiu6 Nov 25, 2025
1048d0a
fix import
ChrisLiu6 Nov 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"ZImageTransformer2DModel",
"attention_backend",
]
)
Expand Down Expand Up @@ -647,6 +648,7 @@
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"ZImagePipeline",
]
)

Expand Down Expand Up @@ -1329,6 +1331,7 @@
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
ZImagePipeline,
)

try:
Expand Down
20 changes: 20 additions & 0 deletions src/diffusers/hooks/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _register_attention_processors_metadata():
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor

# AttnProcessor2_0
AttentionProcessorRegistry.register(
Expand Down Expand Up @@ -158,6 +159,14 @@ def _register_attention_processors_metadata():
),
)

# ZSingleStreamAttnProcessor
AttentionProcessorRegistry.register(
model_class=ZSingleStreamAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
),
)


def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
Expand All @@ -179,6 +188,7 @@ def _register_transformer_blocks_metadata():
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
from ..models.transformers.transformer_z_image import ZImageTransformerBlock

# BasicTransformerBlock
TransformerBlockRegistry.register(
Expand Down Expand Up @@ -312,6 +322,15 @@ def _register_transformer_blocks_metadata():
),
)

# ZImage
TransformerBlockRegistry.register(
model_class=ZImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
Expand All @@ -338,4 +357,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
Expand Down Expand Up @@ -218,6 +219,7 @@
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageTransformer2DModel,
)
from .unets import (
I2VGenXLUNet,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
from .transformer_z_image import ZImageTransformer2DModel
Loading
Loading