Commit ff77dfa

Merge pull request #2400 from adamjstewart/types/nn-module
Fix nn.Module type hints
2 parents 47811bc + f5c4d5c commit ff77dfa

6 files changed: +40 -40 lines changed
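The change is the same across all six files: arguments such as act_layer, norm_layer, conv_layer, and mlp_layer receive a layer class (e.g. nn.GELU) that the module later calls to build an instance, so the correct annotation is Type[nn.Module] rather than nn.Module. A minimal sketch of the pattern (ToyBlock and its argument names are illustrative, not code from timm):

from typing import Type

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # The *_layer arguments are classes, not instances.
    def __init__(
            self,
            dim: int,
            act_layer: Type[nn.Module] = nn.GELU,        # a class ...
            norm_layer: Type[nn.Module] = nn.LayerNorm,  # ... instantiated below
    ) -> None:
        super().__init__()
        self.norm = norm_layer(dim)  # the class is called here to create the instance
        self.act = act_layer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(x))

With the old hint (norm_layer: nn.Module), a static checker such as mypy flags the class default (nn.LayerNorm) as incompatible with the declared instance type; Type[nn.Module] makes both the defaults and the calls inside __init__ check cleanly.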

timm/layers/attention2d.py (+2 -2)

@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union

 import torch
 from torch import nn as nn
@@ -106,7 +106,7 @@ def __init__(
             padding: Union[str, int, List[int]] = '',
             attn_drop: float = 0.,
             proj_drop: float = 0.,
-            norm_layer: nn.Module = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             use_bias: bool = False,
     ):
         """Initializer.

timm/models/fastvit.py (+12 -12)

@@ -7,7 +7,7 @@
 #
 import os
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -54,7 +54,7 @@ def __init__(
             use_act: bool = True,
             use_scale_branch: bool = True,
             num_conv_branches: int = 1,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
     ) -> None:
         """Construct a MobileOneBlock module.

@@ -426,7 +426,7 @@ def _fuse_bn(
 def convolutional_stem(
         in_chs: int,
         out_chs: int,
-        act_layer: nn.Module = nn.GELU,
+        act_layer: Type[nn.Module] = nn.GELU,
         inference_mode: bool = False
 ) -> nn.Sequential:
     """Build convolutional stem with MobileOne blocks.
@@ -545,7 +545,7 @@ def __init__(
             stride: int,
             in_chs: int,
             embed_dim: int,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
             lkc_use_act: bool = False,
             use_se: bool = False,
             inference_mode: bool = False,
@@ -718,7 +718,7 @@ def __init__(
             in_chs: int,
             hidden_channels: Optional[int] = None,
             out_chs: Optional[int] = None,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
             drop: float = 0.0,
     ) -> None:
         """Build convolutional FFN module.
@@ -890,7 +890,7 @@ def __init__(
             dim: int,
             kernel_size: int = 3,
             mlp_ratio: float = 4.0,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
             proj_drop: float = 0.0,
             drop_path: float = 0.0,
             layer_scale_init_value: float = 1e-5,
@@ -947,8 +947,8 @@ def __init__(
             self,
             dim: int,
             mlp_ratio: float = 4.0,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             proj_drop: float = 0.0,
             drop_path: float = 0.0,
             layer_scale_init_value: float = 1e-5,
@@ -1007,8 +1007,8 @@ def __init__(
             pos_emb_layer: Optional[nn.Module] = None,
             kernel_size: int = 3,
             mlp_ratio: float = 4.0,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             proj_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
             layer_scale_init_value: Optional[float] = 1e-5,
@@ -1121,8 +1121,8 @@ def __init__(
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
             global_pool: str = 'avg',
-            norm_layer: nn.Module = nn.BatchNorm2d,
-            act_layer: nn.Module = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.GELU,
             inference_mode: bool = False,
     ) -> None:
         super().__init__()

timm/models/hiera.py (+2 -2)

@@ -316,8 +316,8 @@ def __init__(
             mlp_ratio: float = 4.0,
             drop_path: float = 0.0,
             init_values: Optional[float] = None,
-            norm_layer: nn.Module = nn.LayerNorm,
-            act_layer: nn.Module = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            act_layer: Type[nn.Module] = nn.GELU,
             q_stride: int = 1,
             window_size: int = 0,
             use_expand_proj: bool = True,

timm/models/swin_transformer_v2.py (+4 -4)

@@ -13,7 +13,7 @@
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -230,7 +230,7 @@ def __init__(
             attn_drop: float = 0.,
             drop_path: float = 0.,
             act_layer: LayerType = "gelu",
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
     ):
         """
@@ -422,7 +422,7 @@ def __init__(
             self,
             dim: int,
             out_dim: Optional[int] = None,
-            norm_layer: nn.Module = nn.LayerNorm
+            norm_layer: Type[nn.Module] = nn.LayerNorm
     ):
         """
         Args:
@@ -470,7 +470,7 @@ def __init__(
             attn_drop: float = 0.,
             drop_path: float = 0.,
             act_layer: Union[str, Callable] = 'gelu',
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
             output_nchw: bool = False,
     ) -> None:

timm/models/vgg.py (+7 -7)

@@ -5,7 +5,7 @@

 Copyright 2021 Ross Wightman
 """
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Type, Union, cast

 import torch
 import torch.nn as nn
@@ -38,8 +38,8 @@ def __init__(
             kernel_size=7,
             mlp_ratio=1.0,
             drop_rate: float = 0.2,
-            act_layer: nn.Module = None,
-            conv_layer: nn.Module = None,
+            act_layer: Optional[Type[nn.Module]] = None,
+            conv_layer: Optional[Type[nn.Module]] = None,
     ):
         super(ConvMlp, self).__init__()
         self.input_kernel_size = kernel_size
@@ -72,9 +72,9 @@ def __init__(
             in_chans: int = 3,
             output_stride: int = 32,
             mlp_ratio: float = 1.0,
-            act_layer: nn.Module = nn.ReLU,
-            conv_layer: nn.Module = nn.Conv2d,
-            norm_layer: nn.Module = None,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            conv_layer: Type[nn.Module] = nn.Conv2d,
+            norm_layer: Optional[Type[nn.Module]] = None,
             global_pool: str = 'avg',
             drop_rate: float = 0.,
     ) -> None:
@@ -295,4 +295,4 @@ def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
     """
     model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
-    return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
+    return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
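In vgg.py a few of these arguments default to None, so they become Optional[Type[nn.Module]] instead. A minimal sketch of that variant, assuming the common pattern of resolving the None default when the layer is built (function and argument names are illustrative, not the actual vgg.py logic):

from typing import Optional, Type

import torch.nn as nn


def make_norm(dim: int, norm_layer: Optional[Type[nn.Module]] = None) -> nn.Module:
    # A None default is resolved at build time; here it simply means "no norm".
    if norm_layer is None:
        return nn.Identity()
    return norm_layer(dim)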

timm/models/vision_transformer.py (+13 -13)

@@ -67,7 +67,7 @@ def __init__(
             proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -135,9 +135,9 @@ def __init__(
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: nn.Module = Mlp,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -184,9 +184,9 @@ def __init__(
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: nn.Module = Mlp,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
         super().__init__()
         self.init_values = init_values
@@ -247,9 +247,9 @@ def __init__(
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: Optional[nn.Module] = None,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Optional[Type[nn.Module]] = None,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -342,9 +342,9 @@ def __init__(
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: nn.Module = Mlp,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
         super().__init__()
         self.num_parallel = num_parallel
