Commit 981c763

Disable Attention Pooling by default because of its input size restrictions
1 parent 599d8e1 commit 981c763

File tree: 2 files changed (+51, -9 lines)

captum/optim/models/_image/clip_resnet50x4_image.py

Lines changed: 17 additions & 4 deletions
@@ -23,8 +23,7 @@ def clip_resnet50x4_image(
     This model can be combined with the CLIP ResNet 50x4 Text model to create the full
     CLIP ResNet 50x4 model.
 
-    Note that model inputs are expected to have a shape of: [B, 3, 288, 288] or
-    [3, 288, 288].
+    Note that the model was trained on inputs with a shape of: [B, 3, 288, 288].
 
     See here for more details:
     https://github.com/openai/CLIP
@@ -48,6 +47,10 @@ def clip_resnet50x4_image(
         transform_input (bool, optional): If True, preprocesses the input according to
             the method with which it was trained.
             Default: *True* when pretrained is True otherwise *False*
+        use_attnpool (bool, optional): Whether or not to use the final AttentionPool2d
+            layer in the forward function. If set to True, model inputs are required
+            to have a shape of: [B, 3, 288, 288] or [3, 288, 288].
+            Default: False
 
     Returns:
         **CLIP_ResNet50x4Image** (CLIP_ResNet50x4Image): A CLIP ResNet 50x4 model's
@@ -60,6 +63,8 @@ def clip_resnet50x4_image(
             kwargs["replace_relus_with_redirectedrelu"] = True
         if "use_linear_modules_only" not in kwargs:
             kwargs["use_linear_modules_only"] = False
+        if "use_attnpool" not in kwargs:
+            kwargs["use_attnpool"] = False
 
     model = CLIP_ResNet50x4Image(**kwargs)
 
@@ -81,13 +86,14 @@ class CLIP_ResNet50x4Image(nn.Module):
     Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
     """
 
-    __constants__ = ["transform_input"]
+    __constants__ = ["transform_input", "use_attnpool"]
 
     def __init__(
         self,
         transform_input: bool = False,
         replace_relus_with_redirectedrelu: bool = False,
         use_linear_modules_only: bool = False,
+        use_attnpool: bool = True,
     ) -> None:
         """
         Args:
@@ -101,6 +107,11 @@ def __init__(
             transform_input (bool, optional): If True, preprocesses the input according
                 to the method with which it was trained on.
                 Default: False
+            use_attnpool (bool, optional): Whether or not to use the final
+                AttentionPool2d layer in the forward function. If set to True, model
+                inputs are required to have a shape of: [B, 3, 288, 288] or
+                [3, 288, 288].
+                Default: True
         """
         super().__init__()
         if use_linear_modules_only:
@@ -112,6 +123,7 @@ def __init__(
             activ = nn.ReLU
 
         self.transform_input = transform_input
+        self.use_attnpool = use_attnpool
 
         # Stem layers
         self.conv1 = nn.Conv2d(3, 40, kernel_size=3, stride=2, padding=1, bias=False)
@@ -216,7 +228,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.layer4(x)
 
         # Attention Pooling
-        x = self.attnpool(x)
+        if self.use_attnpool:
+            x = self.attnpool(x)
         return x
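
For reference, here is a minimal usage sketch of the behavior after this commit (not part of the diff itself). The import path is an assumption based on the repository layout, and the commented shapes mirror the updated tests below: with the new default of use_attnpool=False the image trunk returns a spatial feature map and accepts other input sizes, while use_attnpool=True restores the previous 640-dimensional embedding and again requires 288x288 inputs.

# Minimal sketch of the new default; the import path is an assumption based on the
# repository layout, and the commented shapes mirror the updated tests.
import torch

from captum.optim.models import clip_resnet50x4_image

# New default: use_attnpool=False -> spatial features, no 288x288 restriction.
model = clip_resnet50x4_image(pretrained=False)
feats = model(torch.zeros(1, 3, 288, 288))
print(list(feats.shape))  # [1, 2560, 9, 9] per the tests

# Opting back in to attention pooling returns the pooled embedding, but only
# for inputs shaped [B, 3, 288, 288] or [3, 288, 288].
model_attn = clip_resnet50x4_image(pretrained=False, use_attnpool=True)
embedding = model_attn(torch.zeros(1, 3, 288, 288))
print(list(embedding.shape))  # [1, 640]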

tests/optim/models/test_clip_resnet50x4_image.py

Lines changed: 34 additions & 5 deletions
@@ -81,9 +81,10 @@ def test_clip_resnet50x4_image_load_and_forward(self) -> None:
                 + " insufficient Torch version."
             )
         x = torch.zeros(1, 3, 288, 288)
-        model = clip_resnet50x4_image(pretrained=True)
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=True)
         output = model(x)
         self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
 
     def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.6.0"):
@@ -92,9 +93,10 @@ def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
                 + " insufficient Torch version."
             )
         x = torch.zeros(1, 3, 288, 288)
-        model = clip_resnet50x4_image(pretrained=False)
+        model = clip_resnet50x4_image(pretrained=False, use_attnpool=True)
         output = model(x)
         self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
 
     def test_clip_resnet50x4_image_warning(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.6.0"):
@@ -109,6 +111,30 @@ def test_clip_resnet50x4_image_warning(self) -> None:
         with self.assertWarns(UserWarning):
             _ = model._transform_input(x)
 
+    def test_clip_resnet50x4_image_use_attnpool_false(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic pretrained CLIP ResNet 50x4 Image use_attnpool"
+                + " forward due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=False)
+        output = model(x)
+        self.assertEqual(list(output.shape), [1, 2560, 9, 9])
+        self.assertFalse(model.use_attnpool)
+
+    def test_clip_resnet50x4_image_use_attnpool_false_size_128(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic pretrained CLIP ResNet 50x4 Image use_attnpool"
+                + " forward with 128x128 input due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 128, 128)
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=False)
+        output = model(x)
+        self.assertEqual(list(output.shape), [1, 2560, 4, 4])
+        self.assertFalse(model.use_attnpool)
+
     def test_clip_resnet50x4_image_forward_cuda(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
@@ -121,11 +147,12 @@ def test_clip_resnet50x4_image_forward_cuda(self) -> None:
                 + " not supporting CUDA."
             )
         x = torch.zeros(1, 3, 288, 288).cuda()
-        model = clip_resnet50x4_image(pretrained=True).cuda()
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=True).cuda()
         output = model(x)
 
         self.assertTrue(output.is_cuda)
         self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
 
     def test_clip_resnet50x4_image_jit_module_no_redirected_relu(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.8.0"):
@@ -135,11 +162,12 @@ def test_clip_resnet50x4_image_jit_module_no_redirected_relu(self) -> None:
             )
         x = torch.zeros(1, 3, 288, 288)
         model = clip_resnet50x4_image(
-            pretrained=True, replace_relus_with_redirectedrelu=False
+            pretrained=True, replace_relus_with_redirectedrelu=False, use_attnpool=True
         )
         jit_model = torch.jit.script(model)
         output = jit_model(x)
         self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
 
     def test_clip_resnet50x4_image_jit_module_with_redirected_relu(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.8.0"):
@@ -149,8 +177,9 @@ def test_clip_resnet50x4_image_jit_module_with_redirected_relu(self) -> None:
             )
         x = torch.zeros(1, 3, 288, 288)
         model = clip_resnet50x4_image(
-            pretrained=True, replace_relus_with_redirectedrelu=True
+            pretrained=True, replace_relus_with_redirectedrelu=True, use_attnpool=True
        )
         jit_model = torch.jit.script(model)
         output = jit_model(x)
         self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
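
As an aside on why the restriction exists (my reading of the OpenAI CLIP RN50x4 architecture, not something this commit states): the trunk downsamples its input by an overall factor of 32, and AttentionPool2d learns a positional embedding sized for the resulting 9x9 grid plus one pooled token, so only 288x288 inputs line up with that embedding. A small sketch of the arithmetic, cross-checked against the shapes asserted in the new tests:

# Sketch only: the overall downsample factor of 32 is an assumption about the
# CLIP RN50x4 backbone, cross-checked against the spatial shapes in the new tests.
from typing import Tuple


def attnpool_grid(height: int, width: int, downsample: int = 32) -> Tuple[int, int]:
    """Spatial grid that reaches the attention pooling layer."""
    return (height // downsample, width // downsample)


print(attnpool_grid(288, 288))  # (9, 9)  -> matches [1, 2560, 9, 9]
print(attnpool_grid(128, 128))  # (4, 4)  -> matches [1, 2560, 4, 4]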
