diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py index d5d11ac83a..cbf13288af 100644 --- a/mmseg/models/backbones/swin.py +++ b/mmseg/models/backbones/swin.py @@ -479,7 +479,7 @@ class SwinTransformer(BaseModule): embed_dims (int): The feature dimension. Default: 96. patch_size (int | tuple[int]): Patch size. Default: 4. window_size (int): Window size. Default: 7. - mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim. Default: 4. depths (tuple[int]): Depths of each Swin Transformer stage. Default: (2, 2, 6, 2). @@ -610,7 +610,7 @@ def __init__(self, stage = SwinBlockSequence( embed_dims=in_channels, num_heads=num_heads[i], - feedforward_channels=mlp_ratio * in_channels, + feedforward_channels=int(mlp_ratio * in_channels), depth=depths[i], window_size=window_size, qkv_bias=qkv_bias,