
Commit 99f872c

Merge pull request #574 from ProGamerGov/optim-wip-circuits

Optim-wip: Add model linearization, and expanded weights spatial positions

Parents: 2a772ed + 2fc46d0

11 files changed: +1106 −238 lines

.circleci/config.yml

Lines changed: 3 additions & 3 deletions
@@ -102,7 +102,9 @@ commands:
     steps:
       - run:
           name: "Simple PIP install"
-          command: python -m pip install -e .[dev]
+          command: |
+            python -m pip install --upgrade pip
+            python -m pip install -e .[dev]

   py_3_7_setup:
     description: "Set python version to 3.7 and install pip and pytest"
@@ -112,8 +114,6 @@ commands:
         command: |
           pyenv versions
           pyenv global 3.7.0
-          sudo python -m pip install --upgrade pip
-          sudo python -m pip install pytest

   install_cuda:
     description: "Install CUDA for GPU Machine"

captum/optim/_models/inception_v1.py

Lines changed: 82 additions & 45 deletions
@@ -1,10 +1,16 @@
-from typing import Tuple, Union, cast
+from typing import Optional, Tuple, Union, cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-import captum.optim._utils.models as model_utils
+from captum.optim._utils.models import (
+    AvgPool2dConstrained,
+    LocalResponseNormLayer,
+    RedirectedReluLayer,
+    ReluLayer,
+    SkipLayer,
+)

 GS_SAVED_WEIGHTS_URL = (
     "https://github.com/pytorch/captum/raw/"
@@ -13,29 +19,42 @@


 def googlenet(
-    pretrained: bool = False, progress: bool = True, model_path: str = None, **kwargs
+    pretrained: bool = False,
+    progress: bool = True,
+    model_path: Optional[str] = None,
+    **kwargs
 ):
     r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from
     `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
     Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        model_path (str): Optional path for InceptionV1 model file
-        aux_logits (bool): If True, adds two auxiliary branches that can improve
-            training. Default: *False* when pretrained is True otherwise *True*
-        out_features (int): Number of output features in the model used for
+        pretrained (bool, optional): If True, returns a model pre-trained on ImageNet.
+        progress (bool, optional): If True, displays a progress bar of the download to
+            stderr
+        model_path (str, optional): Optional path for InceptionV1 model file.
+        replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained
+            model with Redirected ReLU in place of ReLU layers.
+        use_linear_modules_only (bool, optional): If True, return pretrained
+            model with all nonlinear layers replaced with linear equivalents.
+        aux_logits (bool, optional): If True, adds two auxiliary branches that can
+            improve training. Default: *False* when pretrained is True otherwise *True*
+        out_features (int, optional): Number of output features in the model used for
             training. Default: 1008 when pretrained is True.
-        transform_input (bool): If True, preprocesses the input according to
+        transform_input (bool, optional): If True, preprocesses the input according to
             the method with which it was trained on ImageNet. Default: *False*
-        bgr_transform (bool): If True and transform_input is True, perform an
+        bgr_transform (bool, optional): If True and transform_input is True, perform an
             RGB to BGR transform in the internal preprocessing.
             Default: *False*
     """
+
     if pretrained:
         if "transform_input" not in kwargs:
             kwargs["transform_input"] = True
         if "bgr_transform" not in kwargs:
             kwargs["bgr_transform"] = False
+        if "replace_relus_with_redirectedrelu" not in kwargs:
+            kwargs["replace_relus_with_redirectedrelu"] = True
+        if "use_linear_modules_only" not in kwargs:
+            kwargs["use_linear_modules_only"] = False
         if "aux_logits" not in kwargs:
             kwargs["aux_logits"] = False
         if "out_features" not in kwargs:
@@ -50,27 +69,38 @@ def googlenet(
         else:
             state_dict = torch.load(model_path, map_location="cpu")
         model.load_state_dict(state_dict)
-        model_utils.replace_layers(model)
         return model

     return InceptionV1(**kwargs)


-# Better version of Inception V1/GoogleNet for Inception5h
+# Better version of Inception V1 / GoogleNet for Inception5h
 class InceptionV1(nn.Module):
     def __init__(
         self,
         out_features: int = 1008,
         aux_logits: bool = False,
         transform_input: bool = False,
         bgr_transform: bool = False,
+        replace_relus_with_redirectedrelu: bool = False,
+        use_linear_modules_only: bool = False,
     ) -> None:
         super(InceptionV1, self).__init__()
         self.aux_logits = aux_logits
         self.transform_input = transform_input
         self.bgr_transform = bgr_transform
         lrn_vals = (9, 9.99999974738e-05, 0.5, 1.0)

+        if use_linear_modules_only:
+            activ = SkipLayer
+            pool = AvgPool2dConstrained
+        else:
+            if replace_relus_with_redirectedrelu:
+                activ = RedirectedReluLayer
+            else:
+                activ = ReluLayer
+            pool = nn.MaxPool2d
+
         self.conv1 = nn.Conv2d(
             in_channels=3,
             out_channels=64,
@@ -79,9 +109,9 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv1_relu = model_utils.ReluLayer()
-        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.localresponsenorm1 = model_utils.LocalResponseNormLayer(*lrn_vals)
+        self.conv1_relu = activ()
+        self.pool1 = pool(kernel_size=3, stride=2, padding=0)
+        self.local_response_norm1 = LocalResponseNormLayer(*lrn_vals)

         self.conv2 = nn.Conv2d(
             in_channels=64,
@@ -91,7 +121,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv2_relu = model_utils.ReluLayer()
+        self.conv2_relu = activ()
         self.conv3 = nn.Conv2d(
             in_channels=64,
             out_channels=192,
@@ -100,29 +130,29 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv3_relu = model_utils.ReluLayer()
-        self.localresponsenorm2 = model_utils.LocalResponseNormLayer(*lrn_vals)
+        self.conv3_relu = activ()
+        self.local_response_norm2 = LocalResponseNormLayer(*lrn_vals)

-        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32)
-        self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64)
-        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.mixed4a = InceptionModule(480, 192, 96, 204, 16, 48, 64)
+        self.pool2 = pool(kernel_size=3, stride=2, padding=0)
+        self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32, activ, pool)
+        self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64, activ, pool)
+        self.pool3 = pool(kernel_size=3, stride=2, padding=0)
+        self.mixed4a = InceptionModule(480, 192, 96, 204, 16, 48, 64, activ, pool)

         if self.aux_logits:
-            self.aux1 = AuxBranch(508, out_features)
+            self.aux1 = AuxBranch(508, out_features, activ)

-        self.mixed4b = InceptionModule(508, 160, 112, 224, 24, 64, 64)
-        self.mixed4c = InceptionModule(512, 128, 128, 256, 24, 64, 64)
-        self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64)
+        self.mixed4b = InceptionModule(508, 160, 112, 224, 24, 64, 64, activ, pool)
+        self.mixed4c = InceptionModule(512, 128, 128, 256, 24, 64, 64, activ, pool)
+        self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64, activ, pool)

         if self.aux_logits:
-            self.aux2 = AuxBranch(528, out_features)
+            self.aux2 = AuxBranch(528, out_features, activ)

-        self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128)
-        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
-        self.mixed5a = InceptionModule(832, 256, 160, 320, 48, 128, 128)
-        self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128)
+        self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128, activ, pool)
+        self.pool4 = pool(kernel_size=3, stride=2, padding=0)
+        self.mixed5a = InceptionModule(832, 256, 160, 320, 48, 128, 128, activ, pool)
+        self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128, activ, pool)

         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.drop = nn.Dropout(0.4000000059604645)
@@ -146,14 +176,14 @@ def forward(
         x = self.conv1_relu(x)
         x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))
         x = self.pool1(x)
-        x = self.localresponsenorm1(x)
+        x = self.local_response_norm1(x)

         x = self.conv2(x)
         x = self.conv2_relu(x)
         x = F.pad(x, (1, 1, 1, 1))
         x = self.conv3(x)
         x = self.conv3_relu(x)
-        x = self.localresponsenorm2(x)
+        x = self.local_response_norm2(x)

         x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))
         x = self.pool2(x)
@@ -199,6 +229,8 @@ def __init__(
         c5x5reduce: int,
         c5x5: int,
         pool_proj: int,
+        activ=ReluLayer,
+        p_layer=nn.MaxPool2d,
     ) -> None:
         super(InceptionModule, self).__init__()
         self.conv_1x1 = nn.Conv2d(
@@ -209,7 +241,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_1x1_relu = model_utils.ReluLayer()
+        self.conv_1x1_relu = activ()

         self.conv_3x3_reduce = nn.Conv2d(
             in_channels=in_channels,
@@ -219,7 +251,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_3x3_reduce_relu = model_utils.ReluLayer()
+        self.conv_3x3_reduce_relu = activ()
         self.conv_3x3 = nn.Conv2d(
             in_channels=c3x3reduce,
             out_channels=c3x3,
@@ -228,7 +260,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_3x3_relu = model_utils.ReluLayer()
+        self.conv_3x3_relu = activ()

         self.conv_5x5_reduce = nn.Conv2d(
             in_channels=in_channels,
@@ -238,7 +270,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_5x5_reduce_relu = model_utils.ReluLayer()
+        self.conv_5x5_reduce_relu = activ()
         self.conv_5x5 = nn.Conv2d(
             in_channels=c5x5reduce,
             out_channels=c5x5,
@@ -247,9 +279,9 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.conv_5x5_relu = model_utils.ReluLayer()
+        self.conv_5x5_relu = activ()

-        self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=0)
+        self.pool = p_layer(kernel_size=3, stride=1, padding=0)
         self.pool_proj = nn.Conv2d(
             in_channels=in_channels,
             out_channels=pool_proj,
@@ -258,7 +290,7 @@ def __init__(
             groups=1,
             bias=True,
         )
-        self.pool_proj_relu = model_utils.ReluLayer()
+        self.pool_proj_relu = activ()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         c1x1 = self.conv_1x1(x)
@@ -284,7 +316,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class AuxBranch(nn.Module):
-    def __init__(self, in_channels: int = 508, out_features: int = 1008) -> None:
+    def __init__(
+        self,
+        in_channels: int = 508,
+        out_features: int = 1008,
+        activ=ReluLayer,
+    ) -> None:
         super(AuxBranch, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))
         self.loss_conv = nn.Conv2d(
@@ -295,9 +332,9 @@ def __init__(self, in_channels: int = 508, out_features: int = 1008) -> None:
             groups=1,
             bias=True,
         )
-        self.loss_conv_relu = model_utils.ReluLayer()
+        self.loss_conv_relu = activ()
         self.loss_fc = nn.Linear(in_features=2048, out_features=1024, bias=True)
-        self.loss_fc_relu = model_utils.ReluLayer()
+        self.loss_fc_relu = activ()
         self.loss_dropout = nn.Dropout(0.699999988079071)
         self.loss_classifier = nn.Linear(
             in_features=1024, out_features=out_features, bias=True
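
Usage sketch (mine, not part of the commit): InceptionV1.__init__ now picks the activation and pooling classes once and threads them through every InceptionModule and AuxBranch, so a fully linearized copy of the network can be requested straight from the googlenet() factory. This assumes the pretrained weights download succeeds and that a 224x224 RGB input is appropriate for Inception5h:

import torch

from captum.optim._models.inception_v1 import googlenet

# Pretrained default: googlenet() fills in
# replace_relus_with_redirectedrelu=True and aux_logits=False
# for pretrained models unless overridden.
model = googlenet(pretrained=True)

# Linearized copy for circuit analysis: SkipLayer stands in for every
# activation and AvgPool2dConstrained for every nn.MaxPool2d.
linear_model = googlenet(pretrained=True, use_linear_modules_only=True)

x = torch.randn(1, 3, 224, 224)  # assumed input size; not specified in this diff
with torch.no_grad():
    out = linear_model(x)
print(out.shape)  # expected torch.Size([1, 1008]) with the pretrained head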

captum/optim/_param/image/transform.py

Lines changed: 26 additions & 5 deletions
@@ -135,14 +135,22 @@ class CenterCrop(torch.nn.Module):
         pixels_from_edges (bool, optional): Whether to treat crop size
             values as the number of pixels from the tensor's edge, or an
             exact shape in the center.
+        offset_left (bool, optional): If the cropped away sides are not
+            equal in size, offset center by +1 to the left and/or top.
+            Default is set to False. This parameter is only valid when
+            pixels_from_edges is False.
     """

     def __init__(
-        self, size: IntSeqOrIntType = 0, pixels_from_edges: bool = False
+        self,
+        size: IntSeqOrIntType = 0,
+        pixels_from_edges: bool = False,
+        offset_left: bool = False,
     ) -> None:
         super(CenterCrop, self).__init__()
         self.crop_vals = size
         self.pixels_from_edges = pixels_from_edges
+        self.offset_left = offset_left

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         """
@@ -153,11 +161,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             tensor (torch.Tensor): A center cropped tensor.
         """

-        return center_crop(input, self.crop_vals, self.pixels_from_edges)
+        return center_crop(
+            input, self.crop_vals, self.pixels_from_edges, self.offset_left
+        )


 def center_crop(
-    input: torch.Tensor, crop_vals: IntSeqOrIntType, pixels_from_edges: bool = False
+    input: torch.Tensor,
+    crop_vals: IntSeqOrIntType,
+    pixels_from_edges: bool = False,
+    offset_left: bool = False,
 ) -> torch.Tensor:
     """
     Center crop a specified amount from a tensor.
@@ -167,6 +180,10 @@ def center_crop(
         pixels_from_edges (bool, optional): Whether to treat crop size
             values as the number of pixels from the tensor's edge, or an
             exact shape in the center.
+        offset_left (bool, optional): If the cropped away sides are not
+            equal in size, offset center by +1 to the left and/or top.
+            Default is set to False. This parameter is only valid when
+            pixels_from_edges is False.
     Returns:
         *tensor*: A center cropped tensor.
     """
@@ -188,8 +205,12 @@ def center_crop(
         sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
         x = input[..., sh : sh + h_crop, sw : sw + w_crop]
     else:
-        h_crop = h - int(round((h - crop_vals[0]) / 2.0))
-        w_crop = w - int(round((w - crop_vals[1]) / 2.0))
+        h_crop = h - int(math.ceil((h - crop_vals[0]) / 2.0))
+        w_crop = w - int(math.ceil((w - crop_vals[1]) / 2.0))
+        if h % 2 == 0 and crop_vals[0] % 2 != 0 or h % 2 != 0 and crop_vals[0] % 2 == 0:
+            h_crop = h_crop + 1 if offset_left else h_crop
+        if w % 2 == 0 and crop_vals[1] % 2 != 0 or w % 2 != 0 and crop_vals[1] % 2 == 0:
+            w_crop = w_crop + 1 if offset_left else w_crop
         x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
     return x
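
A minimal sketch of the new offset_left behavior (example values are mine, not from the diff): when an input dimension and the requested crop size have different parities, the window cannot sit exactly at center, and offset_left decides which side absorbs the extra pixel.

import torch

from captum.optim._param.image.transform import center_crop

# 6x6 input cropped to an exact 5x5 center shape: 6 and 5 have different
# parities, so the crop window must land one pixel off true center.
x = torch.arange(36.0).reshape(1, 1, 6, 6)

crop_default = center_crop(x, [5, 5], pixels_from_edges=False)
crop_offset = center_crop(x, [5, 5], pixels_from_edges=False, offset_left=True)

print(crop_default.shape)  # torch.Size([1, 1, 5, 5])
print(crop_offset.shape)   # torch.Size([1, 1, 5, 5])

# Same shape, but the two windows are shifted by one row and one column
# relative to each other.
print(torch.equal(crop_default, crop_offset))  # False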
195216
