diff --git a/captum/optim/_models/inception_v1.py b/captum/optim/_models/inception_v1.py
index 7fa6441720..5275264403 100644
--- a/captum/optim/_models/inception_v1.py
+++ b/captum/optim/_models/inception_v1.py
@@ -1,3 +1,5 @@
+from typing import Tuple, Union, cast
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -25,10 +27,15 @@ def googlenet(
             training. Default: 1008 when pretrained is True.
         transform_input (bool): If True, preprocesses the input according to the
             method with which it was trained on ImageNet. Default: *False*
+        bgr_transform (bool): If True and transform_input is True, perform an
+            RGB to BGR transform in the internal preprocessing.
+            Default: *False*
     """
     if pretrained:
         if "transform_input" not in kwargs:
             kwargs["transform_input"] = True
+        if "bgr_transform" not in kwargs:
+            kwargs["bgr_transform"] = False
         if "aux_logits" not in kwargs:
             kwargs["aux_logits"] = False
         if "out_features" not in kwargs:
@@ -56,10 +63,12 @@ def __init__(
         out_features: int = 1008,
         aux_logits: bool = False,
         transform_input: bool = False,
+        bgr_transform: bool = False,
     ) -> None:
         super(InceptionV1, self).__init__()
         self.aux_logits = aux_logits
         self.transform_input = transform_input
+        self.bgr_transform = bgr_transform
         lrn_vals = (9, 9.99999974738e-05, 0.5, 1.0)
 
         self.conv1 = nn.Conv2d(
@@ -125,10 +134,12 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
             assert x.min() >= 0.0 and x.max() <= 1.0
             x = x.unsqueeze(0) if x.dim() == 3 else x
             x = x * 255 - 117
-            x = x.clone()[:, [2, 1, 0]]  # RGB to BGR
+            x = x[:, [2, 1, 0]] if self.bgr_transform else x
         return x
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         x = self._transform_input(x)
         x = F.pad(x, (2, 3, 2, 3))
         x = self.conv1(x)
@@ -173,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.drop(x)
         x = self.fc(x)
         if not self.aux_logits:
-            return x
+            return cast(torch.Tensor, x)
         else:
             return x, aux1_output, aux2_output
 
diff --git a/tests/optim/models/test_models.py b/tests/optim/models/test_models.py
index b4cdab6b4c..44a4455a0c 100644
--- a/tests/optim/models/test_models.py
+++ b/tests/optim/models/test_models.py
@@ -30,6 +30,18 @@ def test_transform_inceptionv1(self) -> None:
         x = torch.randn(1, 3, 224, 224).clamp(0, 1)
         model = googlenet(pretrained=True)
         output = model._transform_input(x)
+        expected_output = x * 255 - 117
+        assertTensorAlmostEqual(self, output, expected_output, 0)
+
+    def test_transform_bgr_inceptionv1(self) -> None:
+        if torch.__version__ <= "1.2.0":
+            raise unittest.SkipTest(
+                "Skipping inceptionV1 internal transform"
+                + " BGR due to insufficient Torch version."
+            )
+        x = torch.randn(1, 3, 224, 224).clamp(0, 1)
+        model = googlenet(pretrained=True, bgr_transform=True)
+        output = model._transform_input(x)
         expected_output = x[:, [2, 1, 0]] * 255 - 117
         assertTensorAlmostEqual(self, output, expected_output, 0)
 
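
For anyone trying the change locally, a minimal usage sketch of the new bgr_transform flag (not part of the patch; the import path simply mirrors the file path in the diff above, and pretrained=True downloads weights as the existing googlenet helper already does):

    import torch
    from captum.optim._models.inception_v1 import googlenet

    # Pretrained InceptionV1 with the new BGR preprocessing enabled.
    # bgr_transform now defaults to False, so it must be requested explicitly.
    model = googlenet(pretrained=True, bgr_transform=True)
    model.eval()

    # Inputs are expected in [0, 1]; _transform_input scales them to x * 255 - 117
    # and, only when bgr_transform=True, swaps channels RGB -> BGR internally.
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)  # plain Tensor, since aux_logits defaults to False here
    print(logits.shape)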