@@ -179,15 +179,9 @@ def __init__(
         self.num_heads = num_heads
         self.attention_dropout = attention_dropout
         self.dropout = dropout
-
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
 
-        # define a parameter table of relative position bias
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
-        )  # 2*Wh-1 * 2*Ww-1, nH
-
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size)
         coords_w = torch.arange(self.window_size)
@@ -199,22 +193,25 @@ def __init__(
         relative_coords[:, :, 1] += self.window_size - 1
         relative_coords[:, :, 0] *= 2 * self.window_size - 1
         relative_position_index = relative_coords.sum(-1).view(-1)  # Wh*Ww*Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
-
-    def forward(self, x: Tensor):
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
+
+        # define a parameter table of relative position bias
+        relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)  # 2*Wh-1 * 2*Ww-1, nH
+        nn.init.trunc_normal_(relative_position_bias_table, std=0.02)
+
+        relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
         relative_position_bias = relative_position_bias.view(
             self.window_size * self.window_size, self.window_size * self.window_size, -1
         )
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
+        self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0))
+
+    def forward(self, x: Tensor):
+
 
         return shifted_window_attention(
             x,
             self.qkv.weight,
             self.proj.weight,
-            relative_position_bias,
+            self.relative_position_bias,
             self.window_size,
             self.num_heads,
             shift_size=self.shift_size,
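
For reference, a minimal standalone sketch of the relative-position-bias lookup that the patched __init__ now performs once at construction time rather than on every forward pass: the window geometry is turned into a flat index, which gathers rows from a learnable (2*Wh-1)*(2*Ww-1) x nH table and reshapes them into the (1, nH, Wh*Ww, Wh*Ww) bias passed to shifted_window_attention. The values window_size=7 and num_heads=3 below are illustrative assumptions, not taken from the commit.

import torch
import torch.nn as nn

window_size, num_heads = 7, 3  # assumed example values

# Pair-wise relative position index for each token inside the window.
coords_h = torch.arange(window_size)
coords_w = torch.arange(window_size)
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size - 1
relative_coords[:, :, 0] *= 2 * window_size - 1
relative_position_index = relative_coords.sum(-1).view(-1)  # (Wh*Ww)*(Wh*Ww)

# Learnable table with one bias value per head and per relative offset.
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
)
nn.init.trunc_normal_(relative_position_bias_table, std=0.02)

# Lookup and reshape into the bias added to the attention logits.
bias = relative_position_bias_table[relative_position_index]
bias = bias.view(window_size * window_size, window_size * window_size, -1)
bias = bias.permute(2, 0, 1).contiguous().unsqueeze(0)
print(bias.shape)  # torch.Size([1, 3, 49, 49])

In this sketch the gather result is kept as a plain tensor; the commit instead wraps the reshaped result in nn.Parameter, so the per-pair bias itself becomes the learned quantity and the table is only used for initialization.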