-from typing import Tuple, Union, cast
+from typing import Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-import captum.optim._utils.models as model_utils
+from captum.optim._utils.models import (
+    AvgPool2dConstrained,
+    LocalResponseNormLayer,
+    RedirectedReluLayer,
+    ReluLayer,
+    SkipLayer,
+)
 
 GS_SAVED_WEIGHTS_URL = (
     "https://github.com/pytorch/captum/raw/"
 
 
 def googlenet(
-    pretrained: bool = False, progress: bool = True, model_path: str = None, **kwargs
+    pretrained: bool = False,
+    progress: bool = True,
+    model_path: Optional[str] = None,
+    **kwargs
 ):
1827 r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from
1928 `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
2029 Args:
21- pretrained (bool): If True, returns a model pre-trained on ImageNet
22- progress (bool): If True, displays a progress bar of the download to stderr
23- model_path (str): Optional path for InceptionV1 model file
24- aux_logits (bool): If True, adds two auxiliary branches that can improve
25- training. Default: *False* when pretrained is True otherwise *True*
26- out_features (int): Number of output features in the model used for
30+ pretrained (bool, optional): If True, returns a model pre-trained on ImageNet.
31+ progress (bool, optional): If True, displays a progress bar of the download to
32+ stderr
33+ model_path (str, optional): Optional path for InceptionV1 model file.
34+ replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained
35+ model with Redirected ReLU in place of ReLU layers.
36+ use_linear_modules_only (bool, optional): If True, return pretrained
37+ model with all nonlinear layers replaced with linear equivalents.
38+ aux_logits (bool, optional): If True, adds two auxiliary branches that can
39+ improve training. Default: *False* when pretrained is True otherwise *True*
40+ out_features (int, optional): Number of output features in the model used for
2741 training. Default: 1008 when pretrained is True.
28- transform_input (bool): If True, preprocesses the input according to
42+ transform_input (bool, optional ): If True, preprocesses the input according to
2943 the method with which it was trained on ImageNet. Default: *False*
30- bgr_transform (bool): If True and transform_input is True, perform an
44+ bgr_transform (bool, optional ): If True and transform_input is True, perform an
3145 RGB to BGR transform in the internal preprocessing.
3246 Default: *False*
3347 """
+
     if pretrained:
         if "transform_input" not in kwargs:
             kwargs["transform_input"] = True
         if "bgr_transform" not in kwargs:
             kwargs["bgr_transform"] = False
+        if "replace_relus_with_redirectedrelu" not in kwargs:
+            kwargs["replace_relus_with_redirectedrelu"] = True
+        if "use_linear_modules_only" not in kwargs:
+            kwargs["use_linear_modules_only"] = False
         if "aux_logits" not in kwargs:
             kwargs["aux_logits"] = False
         if "out_features" not in kwargs:
@@ -50,27 +69,38 @@ def googlenet(
         else:
             state_dict = torch.load(model_path, map_location="cpu")
         model.load_state_dict(state_dict)
-        model_utils.replace_layers(model)
         return model
 
     return InceptionV1(**kwargs)
 
 
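For review context, a minimal usage sketch of the updated entry point (not part of the commit; the checkpoint filename is hypothetical). With pretrained=True the new defaults select RedirectedReluLayer activations, commonly used so that gradients still flow through inactive ReLUs during feature-visualization-style optimization, while use_linear_modules_only=True swaps every nonlinearity for a linear stand-in:

    model = googlenet(pretrained=True)  # RedirectedReluLayer + nn.MaxPool2d by default
    relu_model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False)
    linear_model = googlenet(pretrained=True, use_linear_modules_only=True)
    local_model = googlenet(pretrained=True, model_path="inception5h_local.pth", progress=False)
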
-# Better version of Inception V1/ GoogleNet for Inception5h
+# Better version of Inception V1 / GoogleNet for Inception5h
class InceptionV1(nn.Module):
     def __init__(
         self,
         out_features: int = 1008,
         aux_logits: bool = False,
         transform_input: bool = False,
         bgr_transform: bool = False,
+        replace_relus_with_redirectedrelu: bool = False,
+        use_linear_modules_only: bool = False,
     ) -> None:
         super(InceptionV1, self).__init__()
         self.aux_logits = aux_logits
         self.transform_input = transform_input
         self.bgr_transform = bgr_transform
         lrn_vals = (9, 9.99999974738e-05, 0.5, 1.0)
 
+        if use_linear_modules_only:
+            activ = SkipLayer
+            pool = AvgPool2dConstrained
+        else:
+            if replace_relus_with_redirectedrelu:
+                activ = RedirectedReluLayer
+            else:
+                activ = ReluLayer
+            pool = nn.MaxPool2d
+
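Reading note: use_linear_modules_only takes precedence over replace_relus_with_redirectedrelu. A hypothetical helper restating the selection (the function below is not in the commit; the classes are the ones imported at the top of the file):

    def _select_layers(use_linear_modules_only: bool, replace_relus_with_redirectedrelu: bool):
        # Equivalent to the branch above; returns (activation class, pooling class).
        if use_linear_modules_only:
            return SkipLayer, AvgPool2dConstrained
        if replace_relus_with_redirectedrelu:
            return RedirectedReluLayer, nn.MaxPool2d
        return ReluLayer, nn.MaxPool2d
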
         self.conv1 = nn.Conv2d(
             in_channels=3,
             out_channels=64,
@@ -79,9 +109,9 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv1_relu = model_utils.ReluLayer()
-        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.localresponsenorm1 = model_utils.LocalResponseNormLayer(*lrn_vals)
+        self.conv1_relu = activ()
+        self.pool1 = pool(kernel_size=3, stride=2, padding=0)
+        self.local_response_norm1 = LocalResponseNormLayer(*lrn_vals)
 
         self.conv2 = nn.Conv2d(
             in_channels=64,
@@ -91,7 +121,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv2_relu = model_utils.ReluLayer()
+        self.conv2_relu = activ()
         self.conv3 = nn.Conv2d(
             in_channels=64,
             out_channels=192,
@@ -100,29 +130,29 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv3_relu = model_utils.ReluLayer()
-        self.localresponsenorm2 = model_utils.LocalResponseNormLayer(*lrn_vals)
+        self.conv3_relu = activ()
+        self.local_response_norm2 = LocalResponseNormLayer(*lrn_vals)
 
-        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32)
-        self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64)
-        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.mixed4a = InceptionModule(480, 192, 96, 204, 16, 48, 64)
+        self.pool2 = pool(kernel_size=3, stride=2, padding=0)
+        self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32, activ, pool)
+        self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64, activ, pool)
+        self.pool3 = pool(kernel_size=3, stride=2, padding=0)
+        self.mixed4a = InceptionModule(480, 192, 96, 204, 16, 48, 64, activ, pool)
 
         if self.aux_logits:
-            self.aux1 = AuxBranch(508, out_features)
+            self.aux1 = AuxBranch(508, out_features, activ)
 
-        self.mixed4b = InceptionModule(508, 160, 112, 224, 24, 64, 64)
-        self.mixed4c = InceptionModule(512, 128, 128, 256, 24, 64, 64)
-        self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64)
+        self.mixed4b = InceptionModule(508, 160, 112, 224, 24, 64, 64, activ, pool)
+        self.mixed4c = InceptionModule(512, 128, 128, 256, 24, 64, 64, activ, pool)
+        self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64, activ, pool)
 
         if self.aux_logits:
-            self.aux2 = AuxBranch(528, out_features)
+            self.aux2 = AuxBranch(528, out_features, activ)
 
-        self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128)
-        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.mixed5a = InceptionModule(832, 256, 160, 320, 48, 128, 128)
-        self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128)
+        self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128, activ, pool)
+        self.pool4 = pool(kernel_size=3, stride=2, padding=0)
+        self.mixed5a = InceptionModule(832, 256, 160, 320, 48, 128, 128, activ, pool)
+        self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128, activ, pool)
 
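As a sanity check on the (unchanged) widths: following the standard Inception design, each InceptionModule concatenates its four branches along the channel dimension, so c1x1 + c3x3 + c5x5 + pool_proj must equal the next module's in_channels. A standalone sketch:

    # Branch-width arithmetic for the modules constructed above (not in the commit).
    assert 64 + 128 + 32 + 32 == 256     # mixed3a -> mixed3b
    assert 128 + 192 + 96 + 64 == 480    # mixed3b -> mixed4a
    assert 192 + 204 + 48 + 64 == 508    # mixed4a -> mixed4b and aux1
    assert 112 + 288 + 64 + 64 == 528    # mixed4d -> mixed4e and aux2
    assert 256 + 320 + 128 + 128 == 832  # mixed4e -> mixed5a
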
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.drop = nn.Dropout(0.4000000059604645)
@@ -146,14 +176,14 @@ def forward(
         x = self.conv1_relu(x)
         x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))
         x = self.pool1(x)
-        x = self.localresponsenorm1(x)
+        x = self.local_response_norm1(x)
 
         x = self.conv2(x)
         x = self.conv2_relu(x)
         x = F.pad(x, (1, 1, 1, 1))
         x = self.conv3(x)
         x = self.conv3_relu(x)
-        x = self.localresponsenorm2(x)
+        x = self.local_response_norm2(x)
 
         x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))
         x = self.pool2(x)
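A note on the pre-existing F.pad(..., value=float("-inf")) calls: padding with negative infinity before a max pool reproduces the asymmetric "SAME" padding of the original TensorFlow Inception5h graph, since -inf can never win the max (presumably also why the linear variant gets a dedicated AvgPool2dConstrained rather than plain nn.AvgPool2d). A standalone shape sketch:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(1, 64, 112, 112)
    x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))  # pad right/bottom only -> 113x113
    x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
    print(x.shape)  # torch.Size([1, 64, 56, 56]) == ceil(112 / 2), TF-style "SAME"
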
@@ -199,6 +229,8 @@ def __init__(
         c5x5reduce: int,
         c5x5: int,
         pool_proj: int,
+        activ=ReluLayer,
+        p_layer=nn.MaxPool2d,
     ) -> None:
         super(InceptionModule, self).__init__()
         self.conv_1x1 = nn.Conv2d(
@@ -209,7 +241,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_1x1_relu = model_utils.ReluLayer()
+        self.conv_1x1_relu = activ()
 
         self.conv_3x3_reduce = nn.Conv2d(
             in_channels=in_channels,
@@ -219,7 +251,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_3x3_reduce_relu = model_utils.ReluLayer()
+        self.conv_3x3_reduce_relu = activ()
         self.conv_3x3 = nn.Conv2d(
             in_channels=c3x3reduce,
             out_channels=c3x3,
@@ -228,7 +260,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_3x3_relu = model_utils.ReluLayer()
+        self.conv_3x3_relu = activ()
 
         self.conv_5x5_reduce = nn.Conv2d(
             in_channels=in_channels,
@@ -238,7 +270,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_5x5_reduce_relu = model_utils.ReluLayer()
+        self.conv_5x5_reduce_relu = activ()
         self.conv_5x5 = nn.Conv2d(
             in_channels=c5x5reduce,
             out_channels=c5x5,
@@ -247,9 +279,9 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_5x5_relu = model_utils.ReluLayer()
+        self.conv_5x5_relu = activ()
 
-        self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=0)
+        self.pool = p_layer(kernel_size=3, stride=1, padding=0)
         self.pool_proj = nn.Conv2d(
             in_channels=in_channels,
             out_channels=pool_proj,
@@ -258,7 +290,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.pool_proj_relu = model_utils.ReluLayer()
+        self.pool_proj_relu = activ()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         c1x1 = self.conv_1x1(x)
@@ -284,7 +316,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class AuxBranch(nn.Module):
-    def __init__(self, in_channels: int = 508, out_features: int = 1008) -> None:
+    def __init__(
+        self,
+        in_channels: int = 508,
+        out_features: int = 1008,
+        activ=ReluLayer,
+    ) -> None:
         super(AuxBranch, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))
         self.loss_conv = nn.Conv2d(
@@ -295,9 +332,9 @@ def __init__(self, in_channels: int = 508, out_features: int = 1008) -> None:
             groups=1,
             bias=True,
         )
-        self.loss_conv_relu = model_utils.ReluLayer()
+        self.loss_conv_relu = activ()
         self.loss_fc = nn.Linear(in_features=2048, out_features=1024, bias=True)
-        self.loss_fc_relu = model_utils.ReluLayer()
+        self.loss_fc_relu = activ()
         self.loss_dropout = nn.Dropout(0.699999988079071)
         self.loss_classifier = nn.Linear(
             in_features=1024, out_features=out_features, bias=True