Skip to content

Commit 2a772ed

Browse files
authored
Merge pull request #582 from ProGamerGov/optim-wip-colorspace-fix
Optim wip: Fix InceptionV1 color space
2 parents 885ea4b + 09869b0 commit 2a772ed

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

captum/optim/_models/inception_v1.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple, Union, cast
2+
13
import torch
24
import torch.nn as nn
35
import torch.nn.functional as F
@@ -25,10 +27,15 @@ def googlenet(
2527
training. Default: 1008 when pretrained is True.
2628
transform_input (bool): If True, preprocesses the input according to
2729
the method with which it was trained on ImageNet. Default: *False*
30+
bgr_transform (bool): If True and transform_input is True, perform an
31+
RGB to BGR transform in the internal preprocessing.
32+
Default: *False*
2833
"""
2934
if pretrained:
3035
if "transform_input" not in kwargs:
3136
kwargs["transform_input"] = True
37+
if "bgr_transform" not in kwargs:
38+
kwargs["bgr_transform"] = False
3239
if "aux_logits" not in kwargs:
3340
kwargs["aux_logits"] = False
3441
if "out_features" not in kwargs:
@@ -56,10 +63,12 @@ def __init__(
5663
out_features: int = 1008,
5764
aux_logits: bool = False,
5865
transform_input: bool = False,
66+
bgr_transform: bool = False,
5967
) -> None:
6068
super(InceptionV1, self).__init__()
6169
self.aux_logits = aux_logits
6270
self.transform_input = transform_input
71+
self.bgr_transform = bgr_transform
6372
lrn_vals = (9, 9.99999974738e-05, 0.5, 1.0)
6473

6574
self.conv1 = nn.Conv2d(
@@ -125,10 +134,12 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
125134
assert x.min() >= 0.0 and x.max() <= 1.0
126135
x = x.unsqueeze(0) if x.dim() == 3 else x
127136
x = x * 255 - 117
128-
x = x.clone()[:, [2, 1, 0]] # RGB to BGR
137+
x = x[:, [2, 1, 0]] if self.bgr_transform else x
129138
return x
130139

131-
def forward(self, x: torch.Tensor) -> torch.Tensor:
140+
def forward(
141+
self, x: torch.Tensor
142+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
132143
x = self._transform_input(x)
133144
x = F.pad(x, (2, 3, 2, 3))
134145
x = self.conv1(x)
@@ -173,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
173184
x = self.drop(x)
174185
x = self.fc(x)
175186
if not self.aux_logits:
176-
return x
187+
return cast(torch.Tensor, x)
177188
else:
178189
return x, aux1_output, aux2_output
179190

tests/optim/models/test_models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@ def test_transform_inceptionv1(self) -> None:
3030
x = torch.randn(1, 3, 224, 224).clamp(0, 1)
3131
model = googlenet(pretrained=True)
3232
output = model._transform_input(x)
33+
expected_output = x * 255 - 117
34+
assertTensorAlmostEqual(self, output, expected_output, 0)
35+
36+
def test_transform_bgr_inceptionv1(self) -> None:
37+
if torch.__version__ <= "1.2.0":
38+
raise unittest.SkipTest(
39+
"Skipping inceptionV1 internal transform"
40+
+ " BGR due to insufficient Torch version."
41+
)
42+
x = torch.randn(1, 3, 224, 224).clamp(0, 1)
43+
model = googlenet(pretrained=True, bgr_transform=True)
44+
output = model._transform_input(x)
3345
expected_output = x[:, [2, 1, 0]] * 255 - 117
3446
assertTensorAlmostEqual(self, output, expected_output, 0)
3547

0 commit comments

Comments
 (0)