6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
@@ -163,16 +163,19 @@
    )
    _import_structure["hooks"].extend(
        [
            "FLUX_MAG_RATIOS",
            "FasterCacheConfig",
            "FirstBlockCacheConfig",
            "HookRegistry",
            "LayerSkipConfig",
            "MagCacheConfig",
            "PyramidAttentionBroadcastConfig",
            "SmoothedEnergyGuidanceConfig",
            "TaylorSeerCacheConfig",
            "apply_faster_cache",
            "apply_first_block_cache",
            "apply_layer_skip",
            "apply_mag_cache",
            "apply_pyramid_attention_broadcast",
            "apply_taylorseer_cache",
        ]
@@ -900,16 +903,19 @@
            TangentialClassifierFreeGuidance,
        )
        from .hooks import (
            FLUX_MAG_RATIOS,
            FasterCacheConfig,
            FirstBlockCacheConfig,
            HookRegistry,
            LayerSkipConfig,
            MagCacheConfig,
            PyramidAttentionBroadcastConfig,
            SmoothedEnergyGuidanceConfig,
            TaylorSeerCacheConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_layer_skip,
            apply_mag_cache,
            apply_pyramid_attention_broadcast,
            apply_taylorseer_cache,
        )
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -23,6 +23,7 @@
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .mag_cache import FLUX_MAG_RATIOS, MagCacheConfig, apply_mag_cache
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
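
For orientation, a minimal usage sketch of the newly exported MagCache hook. Only the names FLUX_MAG_RATIOS, MagCacheConfig, and apply_mag_cache come from this diff; the checkpoint, the default-constructed config, and the (module, config) call convention (mirroring the other apply_* cache hooks) are assumptions for illustration, not the final API.

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache

# Illustrative checkpoint; any Flux pipeline would do for this sketch.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Assumption: MagCacheConfig can be default-constructed; its actual fields
# (thresholds, calibrated ratios such as FLUX_MAG_RATIOS, ...) are not shown
# in this diff.
config = MagCacheConfig()
apply_mag_cache(pipe.transformer, config)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("flux_magcache.png")
```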
8 changes: 7 additions & 1 deletion src/diffusers/hooks/_common.py
@@ -23,7 +23,13 @@
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
    "blocks",
    "transformer_blocks",
    "single_transformer_blocks",
    "layers",

Member: For ZImage I am guessing?

Contributor Author: I think this change is not from my branch. I only appended "visual_transformer_blocks" to the list to support Kandinsky 5.0.

"visual_transformer_blocks",
)
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
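
For context, a rough sketch of how identifier tuples like these are typically consumed. The helper below is an assumption for illustration, not the library's actual matching code; only the identifier names come from this diff.

```python
import torch

# Assumed illustration: hooks walk the model and pick up ModuleLists whose
# attribute name matches one of the identifier tuples above, so a newly added
# identifier like "visual_transformer_blocks" makes Kandinsky 5.0's block
# stack discoverable without model-specific code.
def _iter_spatial_block_lists(model: torch.nn.Module, identifiers):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.ModuleList) and name.split(".")[-1] in identifiers:
            yield name, module
```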

22 changes: 21 additions & 1 deletion src/diffusers/hooks/_helpers.py
@@ -26,6 +26,7 @@ class AttentionProcessorMetadata:
class TransformerBlockMetadata:
    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None
    hidden_states_argument_name: str = "hidden_states"

    _cls: Type = None
    _cached_parameter_indices: Dict[str, int] = None
@@ -169,7 +170,7 @@ def _register_attention_processors_metadata():


def _register_transformer_blocks_metadata():
    from ..models.attention import BasicTransformerBlock
    from ..models.attention import BasicTransformerBlock, JointTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_bria import BriaTransformerBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
@@ -184,6 +185,7 @@ def _register_transformer_blocks_metadata():
        HunyuanImageSingleTransformerBlock,
        HunyuanImageTransformerBlock,
    )
    from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -331,6 +333,24 @@ def _register_transformer_blocks_metadata():
        ),
    )

    TransformerBlockRegistry.register(
        model_class=JointTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )

Comment on lines +336 to +342

Member: For SD3 I am guessing?

Contributor Author: yes

    # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
    TransformerBlockRegistry.register(
        model_class=Kandinsky5TransformerDecoderBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
            hidden_states_argument_name="visual_embed",
        ),
    )


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
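
To make the new metadata concrete: return_hidden_states_index / return_encoder_hidden_states_index record where each tensor sits in a block's return value (the 1/0 pair registered for JointTransformerBlock suggests it returns (encoder_hidden_states, hidden_states), if I read the indices correctly), while hidden_states_argument_name tells hooks what the input tensor is called when a block departs from the usual "hidden_states" name, as Kandinsky's "visual_embed" does. The snippet below is an illustrative sketch, not the library's actual lookup code; get_hidden_states and the inspect-based positional fallback are assumptions.

```python
import inspect

# Illustrative only — _helpers.py has its own lookup machinery; this just shows
# the idea behind hidden_states_argument_name.
def get_hidden_states(block, metadata, args, kwargs):
    name = metadata.hidden_states_argument_name  # "visual_embed" for Kandinsky5TransformerDecoderBlock
    if name in kwargs:
        return kwargs[name]
    # Fall back to the positional slot of that parameter in the block's forward().
    params = list(inspect.signature(block.forward).parameters)
    return args[params.index(name)]
```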