@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
@@ -143,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
@@ -207,11 +209,11 @@ def shifted_window_attention(
         attn = attn.view(-1, num_heads, x.size(1), x.size(1))
 
     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)
 
     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)
 
     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -286,7 +288,7 @@ def get_relative_position_bias(self) -> torch.Tensor:
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )
 
-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]
@@ -306,6 +308,7 @@ def forward(self, x: Tensor):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )
 
 
@@ -391,6 +394,7 @@ def forward(self, x: Tensor):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )
 
 
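
Why the change matters: `F.dropout` defaults to `training=True`, so the previous functional calls kept dropping activations even after `module.eval()`; the patch threads the module's `self.training` flag through to both dropout calls so evaluation mode disables them. A minimal sketch of the mechanism the patch relies on (plain tensors, not the Swin attention module itself):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)

# Before the patch: the functional call implicitly uses training=True,
# so elements are zeroed (and the rest rescaled) regardless of eval mode.
y_old = F.dropout(x, p=0.5)

# After the patch: the module's self.training flag is forwarded, so in
# eval mode dropout becomes a no-op and the output equals the input.
y_eval = F.dropout(x, p=0.5, training=False)
assert torch.equal(y_eval, x)

With the diff applied, ShiftedWindowAttention and ShiftedWindowAttentionV2 produce deterministic outputs under eval() even when attention_dropout or dropout are non-zero.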