From e71c4f9ef69eb045ba2ab49697f32366640ed9a7 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 22 Aug 2025 21:10:25 +0800
Subject: [PATCH] fix(transformerblock): conditionally initialize cross attention components

Initialize cross attention layers only when with_cross_attention is True
to avoid unnecessary computation and memory usage

Signed-off-by: li.yunhao
---
 monai/networks/blocks/transformerblock.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 6f0da73e7b..03b722a731 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -79,16 +79,16 @@ def __init__(
         )
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
-
-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
 
     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
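
A minimal sketch (not part of the patch) of how the intended effect can be checked: a block built with with_cross_attention=False should no longer allocate cross-attention parameters, while a block built with with_cross_attention=True is unchanged. The keyword names hidden_size, num_heads, and with_cross_attention come from the diff above; mlp_dim, the default values, and the import path follow the usual MONAI TransformerBlock signature and may need adjusting for your MONAI version.

    import torch
    from monai.networks.blocks.transformerblock import TransformerBlock

    def n_params(module: torch.nn.Module) -> int:
        # Total number of learnable parameters held by the module.
        return sum(p.numel() for p in module.parameters())

    # Identical configuration except for the cross-attention switch.
    self_attn_only = TransformerBlock(
        hidden_size=256, mlp_dim=1024, num_heads=8, with_cross_attention=False
    )
    with_cross = TransformerBlock(
        hidden_size=256, mlp_dim=1024, num_heads=8, with_cross_attention=True
    )

    # Before this patch both counts were equal because cross_attn was always built;
    # with the patch applied, the self-attention-only block is strictly smaller.
    print(n_params(self_attn_only), n_params(with_cross))
    assert n_params(self_attn_only) < n_params(with_cross)

Assuming forward() already guards the cross-attention branch on self.with_cross_attention (as in the current MONAI implementation), skipping norm_cross_attn and cross_attn here only removes modules that were never exercised when with_cross_attention is False.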