From e71c4f9ef69eb045ba2ab49697f32366640ed9a7 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 22 Aug 2025 21:10:25 +0800
Subject: [PATCH] fix(transformerblock): conditionally initialize cross
attention components

Initialize the cross-attention layers (norm_cross_attn and cross_attn) only
when with_cross_attention is True, so a block that never uses cross attention
does not allocate unused parameters, saving memory and avoiding unnecessary
computation.
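
A minimal sanity check (sketch only, not part of the patch): it constructs the
block with and without cross attention and compares parameter counts. The
values hidden_size=256, mlp_dim=1024, num_heads=8 are illustrative and assume
the current TransformerBlock constructor arguments.

    from monai.networks.blocks.transformerblock import TransformerBlock

    # Same settings, with and without cross attention (example values).
    blk_ca = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8,
                              with_cross_attention=True)
    blk_plain = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8,
                                 with_cross_attention=False)

    # With this change the plain block no longer allocates the unused
    # cross-attention submodules, so it carries fewer parameters.
    assert not hasattr(blk_plain, "cross_attn")
    print(sum(p.numel() for p in blk_ca.parameters()),
          sum(p.numel() for p in blk_plain.parameters()))
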
Signed-off-by: li.yunhao
---
monai/networks/blocks/transformerblock.py | 20 ++++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 6f0da73e7b..03b722a731 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -79,16 +79,16 @@ def __init__(
         )
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
-
-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
 
     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None