File tree 1 file changed +2
-2
lines changed
1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -479,7 +479,7 @@ class SwinTransformer(BaseModule):
479
479
embed_dims (int): The feature dimension. Default: 96.
480
480
patch_size (int | tuple[int]): Patch size. Default: 4.
481
481
window_size (int): Window size. Default: 7.
482
- mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
482
+ mlp_ratio (int | float ): Ratio of mlp hidden dim to embedding dim.
483
483
Default: 4.
484
484
depths (tuple[int]): Depths of each Swin Transformer stage.
485
485
Default: (2, 2, 6, 2).
@@ -610,7 +610,7 @@ def __init__(self,
610
610
stage = SwinBlockSequence (
611
611
embed_dims = in_channels ,
612
612
num_heads = num_heads [i ],
613
- feedforward_channels = mlp_ratio * in_channels ,
613
+ feedforward_channels = int ( mlp_ratio * in_channels ) ,
614
614
depth = depths [i ],
615
615
window_size = window_size ,
616
616
qkv_bias = qkv_bias ,
You can’t perform that action at this time.
0 commit comments