File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -193,19 +193,20 @@ def __init__(
193193 relative_coords [:, :, 1 ] += self .window_size - 1
194194 relative_coords [:, :, 0 ] *= 2 * self .window_size - 1
195195 relative_position_index = relative_coords .sum (- 1 ).view (- 1 ) # Wh*Ww*Wh*Ww
196-
196+
197197 # define a parameter table of relative position bias
198- relative_position_bias_table = torch .zeros ((2 * window_size - 1 ) * (2 * window_size - 1 ), num_heads ) # 2*Wh-1 * 2*Ww-1, nH
198+ relative_position_bias_table = torch .zeros (
199+ (2 * window_size - 1 ) * (2 * window_size - 1 ), num_heads
200+ ) # 2*Wh-1 * 2*Ww-1, nH
199201 nn .init .trunc_normal_ (relative_position_bias_table , std = 0.02 )
200-
202+
201203 relative_position_bias = relative_position_bias_table [relative_position_index ] # type: ignore[index]
202204 relative_position_bias = relative_position_bias .view (
203205 self .window_size * self .window_size , self .window_size * self .window_size , - 1
204206 )
205207 self .relative_position_bias = nn .Parameter (relative_position_bias .permute (2 , 0 , 1 ).contiguous ().unsqueeze (0 ))
206208
207209 def forward (self , x : Tensor ):
208-
209210
210211 return shifted_window_attention (
211212 x ,
You can’t perform that action at this time.
0 commit comments