diff --git a/captum/optim/models/__init__.py b/captum/optim/models/__init__.py index a970e68ec4..f1a8d45999 100755 --- a/captum/optim/models/__init__.py +++ b/captum/optim/models/__init__.py @@ -8,6 +8,13 @@ ) from ._image.inception5h_classes import INCEPTION5H_CLASSES # noqa: F401 from ._image.inception_v1 import InceptionV1, googlenet # noqa: F401 +from ._image.inception_v1_places365 import ( # noqa: F401 + InceptionV1Places365, + googlenet_places365, +) +from ._image.inception_v1_places365_classes import ( # noqa: F401 + INCEPTIONV1_PLACES365_CLASSES, +) __all__ = [ "RedirectedReluLayer", @@ -19,4 +26,6 @@ "InceptionV1", "googlenet", "INCEPTION5H_CLASSES", + "googlenet_places365", + "INCEPTIONV1_PLACES365_CLASSES", ] diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index 50cb903fd0..cf9a33955a 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -1,324 +1,325 @@ -import math -from inspect import signature -from typing import Dict, List, Optional, Tuple, Type, Union, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from captum.optim._core.output_hook import ActivationFetcher -from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType - - -def get_model_layers(model: nn.Module) -> List[str]: - """ - Return a list of hookable layers for the target model. - """ - layers = [] - - def get_layers(net: nn.Module, prefix: List = []) -> None: - if hasattr(net, "_modules"): - for name, layer in net._modules.items(): - if layer is None: - continue - separator = "" if str(name).isdigit() else "." - name = "[" + str(name) + "]" if str(name).isdigit() else name - layers.append(separator.join(prefix + [name])) - get_layers(layer, prefix=prefix + [name]) - - get_layers(model) - return layers - - -class RedirectedReLU(torch.autograd.Function): - """ - A workaround when there is no gradient flow from an initial random input. - ReLU layers will block the gradient flow during backpropagation when their - input is less than 0. This means that it can be impossible to visualize a - target without allowing negative values to pass through ReLU layers during - backpropagation. - See: - https://github.com/tensorflow/lucid/blob/master/lucid/misc/redirected_relu_grad.py - """ - - @staticmethod - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - self.save_for_backward(input_tensor) - return input_tensor.clamp(min=0) - - @staticmethod - def backward(self, grad_output: torch.Tensor) -> torch.Tensor: - (input_tensor,) = self.saved_tensors - relu_grad = grad_output.clone() - relu_grad[input_tensor < 0] = 0 - if torch.equal(relu_grad, torch.zeros_like(relu_grad)): - # Let "wrong" gradients flow if gradient is completely 0 - return grad_output.clone() - return relu_grad - - -class RedirectedReluLayer(nn.Module): - """ - Class for applying RedirectedReLU - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return RedirectedReLU.apply(input) - - -def replace_layers( - model: nn.Module, - layer1: Type[nn.Module], - layer2: Type[nn.Module], - transfer_vars: bool = False, - **kwargs -) -> None: - """ - Replace all target layers with new layers inside the specified model, - possibly with the same initialization variables. - - Args: - model: (nn.Module): A PyTorch model instance. - layer1: (Type[nn.Module]): The layer class that you want to transfer - initialization variables from. - layer2: (Type[nn.Module]): The layer class to create with the variables - from layer1. 
- transfer_vars (bool, optional): Wether or not to try and copy - initialization variables from layer1 instances to the replacement - layer2 instances. - kwargs: (Any, optional): Any additional variables to use when creating - the new layer. - """ - - for name, child in model._modules.items(): - if isinstance(child, layer1): - if transfer_vars: - new_layer = _transfer_layer_vars(child, layer2, **kwargs) - else: - new_layer = layer2(**kwargs) - setattr(model, name, new_layer) - elif child is not None: - replace_layers(child, layer1, layer2, transfer_vars, **kwargs) - - -def _transfer_layer_vars( - layer1: nn.Module, layer2: Type[nn.Module], **kwargs -) -> nn.Module: - """ - Given a layer instance, create a new layer instance of another class - with the same initialization variables as the original layer. - Args: - layer1: (nn.Module): A layer instance that you want to transfer - initialization variables from. - layer2: (nn.Module): The layer class to create with the variables - from of layer1. - kwargs: (Any, optional): Any additional variables to use when creating - the new layer. - Returns: - layer2 instance (nn.Module): An instance of layer2 with the initialization - variables that it shares with layer1, and any specified additional - initialization variables. - """ - - l2_vars = list(signature(layer2.__init__).parameters.values()) - l2_vars = [ - str(l2_vars[i]).split()[0] - for i in range(len(l2_vars)) - if str(l2_vars[i]) != "self" - ] - l2_vars = [p.split(":")[0] if ":" in p and "=" not in p else p for p in l2_vars] - l2_vars = [p.split("=")[0] if "=" in p and ":" not in p else p for p in l2_vars] - layer2_vars: Dict = {k: [] for k in dict.fromkeys(l2_vars).keys()} - - layer1_vars = {k: v for k, v in vars(layer1).items() if not k.startswith("_")} - shared_vars = {k: v for k, v in layer1_vars.items() if k in layer2_vars} - new_vars = dict(item for d in (shared_vars, kwargs) for item in d.items()) - return layer2(**new_vars) - - -class Conv2dSame(nn.Conv2d): - """ - Tensorflow like 'SAME' convolution wrapper for 2D convolutions. - TODO: Replace with torch.nn.Conv2d when support for padding='same' - is in stable version - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]] = 1, - padding: Union[int, Tuple[int, int]] = 0, - dilation: Union[int, Tuple[int, int]] = 1, - groups: int = 1, - bias: bool = True, - ) -> None: - super().__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias - ) - - def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - ih, iw = x.size()[-2:] - kh, kw = self.weight.size()[-2:] - pad_h = self.calc_same_pad(i=ih, k=kh, s=self.stride[0], d=self.dilation[0]) - pad_w = self.calc_same_pad(i=iw, k=kw, s=self.stride[1], d=self.dilation[1]) - - if pad_h > 0 or pad_w > 0: - x = F.pad( - x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] - ) - return F.conv2d( - x, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - -def collect_activations( - model: nn.Module, - targets: Union[nn.Module, List[nn.Module]], - model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224), -) -> ModuleOutputMapping: - """ - Collect target activations for a model. 
- """ - if not hasattr(targets, "__iter__"): - targets = [targets] - catch_activ = ActivationFetcher(model, targets) - activ_out = catch_activ(model_input) - return activ_out - - -class SkipLayer(torch.nn.Module): - """ - This layer is made to take the place of any layer that needs to be skipped over - during the forward pass. Use cases include removing nonlinear activation layers - like ReLU for circuits research. - - This layer works almost exactly the same way that nn.Indentiy does, except it also - ignores any additional arguments passed to the forward function. Any layer replaced - by SkipLayer must have the same input and output shapes. - - See nn.Identity for more details: - https://pytorch.org/docs/stable/generated/torch.nn.Identity.html - - Args: - args (Any): Any argument. Arguments will be safely ignored. - kwargs (Any) Any keyword argument. Arguments will be safely ignored. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__() - - def forward( - self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - """ - Args: - x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors. - args (Any): Any argument. Arguments will be safely ignored. - kwargs (Any) Any keyword argument. Arguments will be safely ignored. - Returns: - x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or - tensors. - """ - return x - - -def skip_layers( - model: nn.Module, layers: Union[List[Type[nn.Module]], Type[nn.Module]] -) -> None: - """ - This function is a wrapper function for - replace_layers and replaces the target layer - with layers that do nothing. - This is useful for removing the nonlinear ReLU - layers when creating expanded weights. - Args: - model (nn.Module): A PyTorch model instance. - layers (nn.Module or list of nn.Module): The layer - class type to replace in the model. - """ - if not hasattr(layers, "__iter__"): - layers = cast(Type[nn.Module], layers) - replace_layers(model, layers, SkipLayer) - else: - layers = cast(List[Type[nn.Module]], layers) - for target_layer in layers: - replace_layers(model, target_layer, SkipLayer) - - -class MaxPool2dRelaxed(torch.nn.Module): - """ - A relaxed pooling layer, that's useful for calculating attributions of spatial - positions. Noise in the gradient is reduced by the continuous relaxation of the - gradient of models using this layer. - - This layer is meant to be combined with forward-mode AD, so that the class - attributions of spatial posititions can be estimated using the rate at which - increasing the neuron affects the output classes. - - This layer peforms a MaxPool2d operation on the input, while using an equivalent - AvgPool2d layer to compute the gradient. This means that the forward pass returns - nn.MaxPool2d(input) while the backward pass uses nn.AvgPool2d(input). - - Carter, et al., "Activation Atlas", Distill, 2019. 
- https://distill.pub/2019/activation-atlas/ - - The Lucid equivalent of this class can be found here: - https://github.com/ - tensorflow/lucid/blob/master/lucid/optvis/overrides/smoothed_maxpool_grad.py - - An additional Lucid reference implementation can be found here: - https://colab.research.google.com/github/tensorflow/ - lucid/blob/master/notebooks/building-blocks/AttrSpatial.ipynb - """ - - def __init__( - self, - kernel_size: Union[int, Tuple[int, ...]], - stride: Optional[Union[int, Tuple[int, ...]]] = None, - padding: Union[int, Tuple[int, ...]] = 0, - ceil_mode: bool = False, - ) -> None: - """ - Args: - - kernel_size (int or tuple of int): The size of the window to perform max & - average pooling with. - stride (int or tuple of int, optional): The stride window size to use. - Default: None - padding (int or tuple of int): The amount of zero padding to add to both - sides in the nn.MaxPool2d & nn.AvgPool2d modules. - Default: 0 - ceil_mode (bool, optional): Whether to use ceil or floor for creating the - output shape. - Default: False - """ - super().__init__() - self.maxpool = torch.nn.MaxPool2d( - kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - self.avgpool = torch.nn.AvgPool2d( - kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - - x (torch.Tensor): An input tensor to run the pooling operations on. - - Returns: - x (torch.Tensor): A max pooled x tensor with gradient of an equivalent avg - pooled tensor. - """ - return self.maxpool(x.detach()) + self.avgpool(x) - self.avgpool(x.detach()) +import math +from inspect import signature +from typing import Dict, List, Optional, Tuple, Type, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from captum.optim._core.output_hook import ActivationFetcher +from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType + + +def get_model_layers(model: nn.Module) -> List[str]: + """ + Return a list of hookable layers for the target model. + """ + layers = [] + + def get_layers(net: nn.Module, prefix: List = []) -> None: + if hasattr(net, "_modules"): + for name, layer in net._modules.items(): + if layer is None: + continue + separator = "" if str(name).isdigit() else "." + name = "[" + str(name) + "]" if str(name).isdigit() else name + layers.append(separator.join(prefix + [name])) + get_layers(layer, prefix=prefix + [name]) + + get_layers(model) + return layers + + +class RedirectedReLU(torch.autograd.Function): + """ + A workaround when there is no gradient flow from an initial random input. + ReLU layers will block the gradient flow during backpropagation when their + input is less than 0. This means that it can be impossible to visualize a + target without allowing negative values to pass through ReLU layers during + backpropagation. 
+ See:
+ https://github.com/tensorflow/lucid/blob/master/lucid/misc/redirected_relu_grad.py
+ """
+
+ @staticmethod
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+ self.save_for_backward(input_tensor)
+ return input_tensor.clamp(min=0)
+
+ @staticmethod
+ def backward(self, grad_output: torch.Tensor) -> torch.Tensor:
+ (input_tensor,) = self.saved_tensors
+ relu_grad = grad_output.clone()
+ relu_grad[input_tensor < 0] = 0
+ if torch.equal(relu_grad, torch.zeros_like(relu_grad)):
+ # Let "wrong" gradients flow if gradient is completely 0
+ return grad_output.clone()
+ return relu_grad
+
+
+class RedirectedReluLayer(nn.Module):
+ """
+ Class for applying RedirectedReLU
+ """
+
+ @torch.jit.ignore
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return RedirectedReLU.apply(input)
+
+
+def replace_layers(
+ model: nn.Module,
+ layer1: Type[nn.Module],
+ layer2: Type[nn.Module],
+ transfer_vars: bool = False,
+ **kwargs
+) -> None:
+ """
+ Replace all target layers with new layers inside the specified model,
+ possibly with the same initialization variables.
+
+ Args:
+ model: (nn.Module): A PyTorch model instance.
+ layer1: (Type[nn.Module]): The layer class that you want to transfer
+ initialization variables from.
+ layer2: (Type[nn.Module]): The layer class to create with the variables
+ from layer1.
+ transfer_vars (bool, optional): Whether or not to try to copy
+ initialization variables from layer1 instances to the replacement
+ layer2 instances.
+ kwargs: (Any, optional): Any additional variables to use when creating
+ the new layer.
+ """
+
+ for name, child in model._modules.items():
+ if isinstance(child, layer1):
+ if transfer_vars:
+ new_layer = _transfer_layer_vars(child, layer2, **kwargs)
+ else:
+ new_layer = layer2(**kwargs)
+ setattr(model, name, new_layer)
+ elif child is not None:
+ replace_layers(child, layer1, layer2, transfer_vars, **kwargs)
+
+
+def _transfer_layer_vars(
+ layer1: nn.Module, layer2: Type[nn.Module], **kwargs
+) -> nn.Module:
+ """
+ Given a layer instance, create a new layer instance of another class
+ with the same initialization variables as the original layer.
+ Args:
+ layer1: (nn.Module): A layer instance that you want to transfer
+ initialization variables from.
+ layer2: (nn.Module): The layer class to create with the variables
+ from layer1.
+ kwargs: (Any, optional): Any additional variables to use when creating
+ the new layer.
+ Returns:
+ layer2 instance (nn.Module): An instance of layer2 with the initialization
+ variables that it shares with layer1, and any specified additional
+ initialization variables.
+ """
+
+ l2_vars = list(signature(layer2.__init__).parameters.values())
+ l2_vars = [
+ str(l2_vars[i]).split()[0]
+ for i in range(len(l2_vars))
+ if str(l2_vars[i]) != "self"
+ ]
+ l2_vars = [p.split(":")[0] if ":" in p and "=" not in p else p for p in l2_vars]
+ l2_vars = [p.split("=")[0] if "=" in p and ":" not in p else p for p in l2_vars]
+ layer2_vars: Dict = {k: [] for k in dict.fromkeys(l2_vars).keys()}
+
+ layer1_vars = {k: v for k, v in vars(layer1).items() if not k.startswith("_")}
+ shared_vars = {k: v for k, v in layer1_vars.items() if k in layer2_vars}
+ new_vars = dict(item for d in (shared_vars, kwargs) for item in d.items())
+ return layer2(**new_vars)
+
+
+class Conv2dSame(nn.Conv2d):
+ """
+ TensorFlow-like 'SAME' convolution wrapper for 2D convolutions.
+ TODO: Replace with torch.nn.Conv2d when support for padding='same'
+ is in stable version
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ ) -> None:
+ super().__init__(
+ in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
+ )
+
+ def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
+ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ ih, iw = x.size()[-2:]
+ kh, kw = self.weight.size()[-2:]
+ pad_h = self.calc_same_pad(i=ih, k=kh, s=self.stride[0], d=self.dilation[0])
+ pad_w = self.calc_same_pad(i=iw, k=kw, s=self.stride[1], d=self.dilation[1])
+
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(
+ x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+ )
+ return F.conv2d(
+ x,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ )
+
+
+def collect_activations(
+ model: nn.Module,
+ targets: Union[nn.Module, List[nn.Module]],
+ model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224),
+) -> ModuleOutputMapping:
+ """
+ Collect target activations for a model.
+ """
+ if not hasattr(targets, "__iter__"):
+ targets = [targets]
+ catch_activ = ActivationFetcher(model, targets)
+ activ_out = catch_activ(model_input)
+ return activ_out
+
+
+class SkipLayer(torch.nn.Module):
+ """
+ This layer is made to take the place of any layer that needs to be skipped over
+ during the forward pass. Use cases include removing nonlinear activation layers
+ like ReLU for circuits research.
+
+ This layer works almost exactly the same way that nn.Identity does, except it also
+ ignores any additional arguments passed to the forward function. Any layer replaced
+ by SkipLayer must have the same input and output shapes.
+
+ See nn.Identity for more details:
+ https://pytorch.org/docs/stable/generated/torch.nn.Identity.html
+
+ Args:
+ args (Any): Any argument. Arguments will be safely ignored.
+ kwargs (Any): Any keyword argument. Arguments will be safely ignored.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__()
+
+ def forward(
+ self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+ """
+ Args:
+ x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors.
+ args (Any): Any argument. Arguments will be safely ignored.
+ kwargs (Any): Any keyword argument. Arguments will be safely ignored.
+ Returns:
+ x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or
+ tensors.
+ """
+ return x
+
+
+def skip_layers(
+ model: nn.Module, layers: Union[List[Type[nn.Module]], Type[nn.Module]]
+) -> None:
+ """
+ This function is a wrapper function for
+ replace_layers and replaces the target layer
+ with layers that do nothing.
+ This is useful for removing the nonlinear ReLU
+ layers when creating expanded weights.
+ Args:
+ model (nn.Module): A PyTorch model instance.
+ layers (nn.Module or list of nn.Module): The layer
+ class type to replace in the model.
+ """ + if not hasattr(layers, "__iter__"): + layers = cast(Type[nn.Module], layers) + replace_layers(model, layers, SkipLayer) + else: + layers = cast(List[Type[nn.Module]], layers) + for target_layer in layers: + replace_layers(model, target_layer, SkipLayer) + + +class MaxPool2dRelaxed(torch.nn.Module): + """ + A relaxed pooling layer, that's useful for calculating attributions of spatial + positions. Noise in the gradient is reduced by the continuous relaxation of the + gradient of models using this layer. + + This layer is meant to be combined with forward-mode AD, so that the class + attributions of spatial posititions can be estimated using the rate at which + increasing the neuron affects the output classes. + + This layer peforms a MaxPool2d operation on the input, while using an equivalent + AvgPool2d layer to compute the gradient. This means that the forward pass returns + nn.MaxPool2d(input) while the backward pass uses nn.AvgPool2d(input). + + Carter, et al., "Activation Atlas", Distill, 2019. + https://distill.pub/2019/activation-atlas/ + + The Lucid equivalent of this class can be found here: + https://github.com/ + tensorflow/lucid/blob/master/lucid/optvis/overrides/smoothed_maxpool_grad.py + + An additional Lucid reference implementation can be found here: + https://colab.research.google.com/github/tensorflow/ + lucid/blob/master/notebooks/building-blocks/AttrSpatial.ipynb + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, ...]], + stride: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Union[int, Tuple[int, ...]] = 0, + ceil_mode: bool = False, + ) -> None: + """ + Args: + + kernel_size (int or tuple of int): The size of the window to perform max & + average pooling with. + stride (int or tuple of int, optional): The stride window size to use. + Default: None + padding (int or tuple of int): The amount of zero padding to add to both + sides in the nn.MaxPool2d & nn.AvgPool2d modules. + Default: 0 + ceil_mode (bool, optional): Whether to use ceil or floor for creating the + output shape. + Default: False + """ + super().__init__() + self.maxpool = torch.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + self.avgpool = torch.nn.AvgPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to run the pooling operations on. + + Returns: + x (torch.Tensor): A max pooled x tensor with gradient of an equivalent avg + pooled tensor. + """ + return self.maxpool(x.detach()) + self.avgpool(x) - self.avgpool(x.detach()) diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index 102581c095..b9e534b91f 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -1,4 +1,5 @@ -from typing import Optional, Tuple, Type, Union, cast +from typing import Optional, Tuple, Type, Union +from warnings import warn import torch import torch.nn as nn @@ -19,24 +20,37 @@ def googlenet( ) -> "InceptionV1": r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from `"Going Deeper with Convolutions" `_. + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. + Default: False progress (bool, optional): If True, displays a progress bar of the download to stderr + Default: True model_path (str, optional): Optional path for InceptionV1 model file. 
+ Default: None replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained model with Redirected ReLU in place of ReLU layers. + Default: *True* when pretrained is True otherwise *False* use_linear_modules_only (bool, optional): If True, return pretrained model with all nonlinear layers replaced with linear equivalents. + Default: False aux_logits (bool, optional): If True, adds two auxiliary branches that can - improve training. Default: *False* when pretrained is True otherwise *True* + improve training. + Default: False out_features (int, optional): Number of output features in the model used for - training. Default: 1008 when pretrained is True. + training. + Default: 1008 transform_input (bool, optional): If True, preprocesses the input according to - the method with which it was trained on ImageNet. Default: *False* + the method with which it was trained on ImageNet. + Default: False bgr_transform (bool, optional): If True and transform_input is True, perform an RGB to BGR transform in the internal preprocessing. - Default: *False* + Default: False + + Returns: + **InceptionV1** (InceptionV1): An Inception5h model. """ if pretrained: @@ -69,6 +83,8 @@ def googlenet( # Better version of Inception V1 / GoogleNet for Inception5h class InceptionV1(nn.Module): + __constants__ = ["aux_logits", "transform_input", "bgr_transform"] + def __init__( self, out_features: int = 1008, @@ -78,7 +94,29 @@ def __init__( replace_relus_with_redirectedrelu: bool = False, use_linear_modules_only: bool = False, ) -> None: - super(InceptionV1, self).__init__() + """ + Args: + + replace_relus_with_redirectedrelu (bool, optional): If True, return + pretrained model with Redirected ReLU in place of ReLU layers. + Default: False + use_linear_modules_only (bool, optional): If True, return pretrained + model with all nonlinear layers replaced with linear equivalents. + Default: False + aux_logits (bool, optional): If True, adds two auxiliary branches that can + improve training. + Default: False + out_features (int, optional): Number of output features in the model used + for training. + Default: 1008 + transform_input (bool, optional): If True, preprocesses the input according + to the method with which it was trained on ImageNet. + Default: False + bgr_transform (bool, optional): If True and transform_input is True, + perform an RGB to BGR transform in the internal preprocessing. + Default: False + """ + super().__init__() self.aux_logits = aux_logits self.transform_input = transform_input self.bgr_transform = bgr_transform @@ -99,7 +137,6 @@ def __init__( out_channels=64, kernel_size=(7, 7), stride=(2, 2), - padding=3, groups=1, bias=True, ) @@ -121,7 +158,6 @@ def __init__( out_channels=192, kernel_size=(3, 3), stride=(1, 1), - padding=1, groups=1, bias=True, ) @@ -163,9 +199,18 @@ def __init__( self.fc = nn.Linear(1024, out_features) def _transform_input(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to normalize and scale the values of. + + Returns: + x (torch.Tensor): A transformed tensor. 
+ """ if self.transform_input: assert x.dim() == 3 or x.dim() == 4 - assert x.min() >= 0.0 and x.max() <= 1.0 + if x.min() < 0.0 or x.max() > 1.0: + warn("Model input has values outside of the range [0, 1].") x = x.unsqueeze(0) if x.dim() == 3 else x x = x * 255 - 117 x = x[:, [2, 1, 0]] if self.bgr_transform else x @@ -174,6 +219,15 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor: def forward( self, x: torch.Tensor ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Args: + + x (torch.Tensor): An input tensor to normalize and scale the values of. + + Returns: + x (torch.Tensor or tuple of torch.Tensor): A single or multiple output + tensors from the model. + """ x = self._transform_input(x) x = self.conv1(x) x = self.conv1_relu(x) @@ -212,7 +266,7 @@ def forward( x = self.drop(x) x = self.fc(x) if not self.aux_logits: - return cast(torch.Tensor, x) + return x else: return x, aux1_output, aux2_output @@ -230,7 +284,25 @@ def __init__( activ: Type[nn.Module] = nn.ReLU, p_layer: Type[nn.Module] = nn.MaxPool2d, ) -> None: - super(InceptionModule, self).__init__() + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + inception module. + c1x1 (int, optional): + c3x3reduce (int, optional): + c3x3 (int, optional): + c5x5reduce (int, optional): + c5x5 (int, optional): + pool_proj (int, optional): + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. + Default: nn.ReLU + p_layer (type of nn.Module, optional): The nn.Module class type to use for + pooling layers. + Default: nn.MaxPool2d + """ + super().__init__() self.conv_1x1 = nn.Conv2d( in_channels=in_channels, out_channels=c1x1, @@ -254,7 +326,6 @@ def __init__( out_channels=c3x3, kernel_size=(3, 3), stride=(1, 1), - padding=1, groups=1, bias=True, ) @@ -273,7 +344,6 @@ def __init__( out_channels=c5x5, kernel_size=(5, 5), stride=(1, 1), - padding=1, groups=1, bias=True, ) @@ -289,6 +359,14 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to pass through the Inception Module. + + Returns: + x (torch.Tensor): The output tensor of the Inception Module. + """ c1x1 = self.conv_1x1(x) c3x3 = self.conv_3x3_reduce(x) @@ -311,9 +389,22 @@ def __init__( out_features: int = 1008, activ: Type[nn.Module] = nn.ReLU, ) -> None: - super(AuxBranch, self).__init__() + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + auxiliary branch. + Default: 508 + out_features (int, optional): The number of output features to use for the + auxiliary branch. + Default: 1008 + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. 
+ Default: nn.ReLU
+ """
+ super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))
- self.loss_conv = nn.Conv2d(
+ self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=128,
kernel_size=(1, 1),
@@ -321,21 +412,28 @@ def __init__(
groups=1,
bias=True,
)
- self.loss_conv_relu = activ()
- self.loss_fc = nn.Linear(in_features=2048, out_features=1024, bias=True)
- self.loss_fc_relu = activ()
- self.loss_dropout = nn.Dropout(0.699999988079071)
- self.loss_classifier = nn.Linear(
- in_features=1024, out_features=out_features, bias=True
- )
+ self.conv_relu = activ()
+ self.fc1 = nn.Linear(in_features=2048, out_features=1024, bias=True)
+ self.fc1_relu = activ()
+ self.dropout = nn.Dropout(0.699999988079071)
+ self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True)

def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+
+ x (torch.Tensor): An input tensor to pass through the auxiliary branch
+ module.
+
+ Returns:
+ x (torch.Tensor): The output tensor of the auxiliary branch module.
+ """
x = self.avg_pool(x)
- x = self.loss_conv(x)
- x = self.loss_conv_relu(x)
+ x = self.conv(x)
+ x = self.conv_relu(x)
x = torch.flatten(x, 1)
- x = self.loss_fc(x)
- x = self.loss_fc_relu(x)
- x = self.loss_dropout(x)
- x = self.loss_classifier(x)
+ x = self.fc1(x)
+ x = self.fc1_relu(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
return x
diff --git a/captum/optim/models/_image/inception_v1_places365.py b/captum/optim/models/_image/inception_v1_places365.py
new file mode 100644
index 0000000000..8fb2fd8924
--- /dev/null
+++ b/captum/optim/models/_image/inception_v1_places365.py
@@ -0,0 +1,434 @@
+from typing import Any, Optional, Tuple, Type, Union
+from warnings import warn
+
+import torch
+import torch.nn as nn
+
+from captum.optim.models._common import Conv2dSame, RedirectedReluLayer, SkipLayer
+
+GS_SAVED_WEIGHTS_URL = (
+ "https://pytorch-tutorial-assets.s3.amazonaws.com/"
+ + "captum/inceptionv1_places365.pth"
+)
+
+
+def googlenet_places365(
+ pretrained: bool = False,
+ progress: bool = True,
+ model_path: Optional[str] = None,
+ **kwargs: Any
+) -> "InceptionV1Places365":
+ r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from
+ `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
+
+ The pretrained GoogleNet model was trained using the MIT Places365 Standard
+ dataset. See here for more information: https://arxiv.org/abs/1610.02055
+
+ Args:
+ pretrained (bool, optional): If True, returns a model pre-trained on the MIT
+ Places365 Standard dataset.
+ Default: False
+ progress (bool, optional): If True, displays a progress bar of the download to
+ stderr
+ Default: True
+ model_path (str, optional): Optional path for InceptionV1 model file.
+ Default: None
+ replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained
+ model with Redirected ReLU in place of ReLU layers.
+ Default: *True* when pretrained is True otherwise *False*
+ use_linear_modules_only (bool, optional): If True, return pretrained
+ model with all nonlinear layers replaced with linear equivalents.
+ Default: False
+ aux_logits (bool, optional): If True, adds two auxiliary branches that can
+ improve training.
+ Default: True
+ out_features (int, optional): Number of output features in the model used for
+ training.
+ Default: 365
+ transform_input (bool, optional): If True, preprocesses the input according to
+ the method with which it was trained on Places365.
+ Default: True + """ + + if pretrained: + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "replace_relus_with_redirectedrelu" not in kwargs: + kwargs["replace_relus_with_redirectedrelu"] = True + if "use_linear_modules_only" not in kwargs: + kwargs["use_linear_modules_only"] = False + if "aux_logits" not in kwargs: + kwargs["aux_logits"] = True + if "out_features" not in kwargs: + kwargs["out_features"] = 365 + + model = InceptionV1Places365(**kwargs) + + if model_path is None: + state_dict = torch.hub.load_state_dict_from_url( + GS_SAVED_WEIGHTS_URL, progress=progress, check_hash=False + ) + else: + state_dict = torch.load(model_path, map_location="cpu") + model.load_state_dict(state_dict) + return model + + return InceptionV1Places365(**kwargs) + + +class InceptionV1Places365(nn.Module): + """ + MIT Places365 variant of the InceptionV1 model. + """ + + __constants__ = ["aux_logits", "transform_input"] + + def __init__( + self, + out_features: int = 365, + aux_logits: bool = True, + transform_input: bool = True, + replace_relus_with_redirectedrelu: bool = False, + use_linear_modules_only: bool = False, + ) -> None: + """ + Args: + + out_features (int, optional): Number of output features in the model used + for training. + Default: 365 + aux_logits (bool, optional): If True, adds two auxiliary branches that can + improve training. + Default: True + transform_input (bool, optional): If True, preprocesses the input according + to the method with which it was trained on Places365. + Default: True + replace_relus_with_redirectedrelu (bool, optional): If True, return + pretrained model with Redirected ReLU in place of ReLU layers. + Default: False + use_linear_modules_only (bool, optional): If True, return pretrained model + with all nonlinear layers replaced with linear equivalents. 
+ Default: False + """ + super().__init__() + self.aux_logits = aux_logits + self.transform_input = transform_input + lrn_vals = (5, 9.999999747378752e-05, 0.75, 1.0) + + if use_linear_modules_only: + activ = SkipLayer + pool = nn.AvgPool2d + else: + if replace_relus_with_redirectedrelu: + activ = RedirectedReluLayer + else: + activ = nn.ReLU + pool = nn.MaxPool2d + + self.conv1 = Conv2dSame( + in_channels=3, + out_channels=64, + kernel_size=(7, 7), + stride=(2, 2), + groups=1, + bias=True, + ) + self.conv1_relu = activ() + self.pool1 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.local_response_norm1 = nn.LocalResponseNorm(*lrn_vals) + + self.conv2 = nn.Conv2d( + in_channels=64, + out_channels=64, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv2_relu = activ() + self.conv3 = Conv2dSame( + in_channels=64, + out_channels=192, + kernel_size=(3, 3), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv3_relu = activ() + self.local_response_norm2 = nn.LocalResponseNorm(*lrn_vals) + + self.pool2 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32, activ, pool) + self.mixed3a_relu = activ() + self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64, activ, pool) + self.mixed3b_relu = activ() + self.pool3 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.mixed4a = InceptionModule(480, 192, 96, 208, 16, 48, 64, activ, pool) + self.mixed4a_relu = activ() + + if self.aux_logits: + self.aux1 = AuxBranch(512, out_features, activ) + + self.mixed4b = InceptionModule(512, 160, 112, 224, 24, 64, 64, activ, pool) + self.mixed4b_relu = activ() + self.mixed4c = InceptionModule(512, 128, 128, 256, 24, 64, 64, activ, pool) + self.mixed4c_relu = activ() + self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64, activ, pool) + self.mixed4d_relu = activ() + + if self.aux_logits: + self.aux2 = AuxBranch(528, out_features, activ) + + self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128, activ, pool) + self.mixed4e_relu = activ() + self.pool4 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.mixed5a = InceptionModule(832, 256, 160, 320, 32, 128, 128, activ, pool) + self.mixed5a_relu = activ() + self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128, activ, pool) + self.mixed5b_relu = activ() + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.drop = nn.Dropout(0.4000000059604645) + self.fc = nn.Linear(1024, out_features) + + def _transform_input(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to normalize and scale the values of. + + Returns: + x (torch.Tensor): A transformed tensor. + """ + if self.transform_input: + assert x.dim() == 3 or x.dim() == 4 + if x.min() < 0.0 or x.max() > 1.0: + warn("Model input has values outside of the range [0, 1].") + x = x.unsqueeze(0) if x.dim() == 3 else x + x = x * 255 - torch.tensor( + [116.7894, 112.6004, 104.0437], device=x.device + ).view(3, 1, 1) + x = x[:, [2, 1, 0]] # RGB to BGR + return x + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Args: + + x (torch.Tensor): An input tensor to normalize and scale the values of. + + Returns: + x (torch.Tensor or tuple of torch.Tensor): A single or multiple output + tensors from the model. 
+ """ + x = self._transform_input(x) + x = self.conv1(x) + x = self.conv1_relu(x) + x = self.pool1(x) + x = self.local_response_norm1(x) + + x = self.conv2(x) + x = self.conv2_relu(x) + x = self.conv3(x) + x = self.conv3_relu(x) + x = self.local_response_norm2(x) + + x = self.pool2(x) + x = self.mixed3a_relu(self.mixed3a(x)) + x = self.mixed3b_relu(self.mixed3b(x)) + x = self.pool3(x) + x = self.mixed4a_relu(self.mixed4a(x)) + + if self.aux_logits: + aux1_output = self.aux1(x) + + x = self.mixed4b_relu(self.mixed4b(x)) + x = self.mixed4c_relu(self.mixed4c(x)) + x = self.mixed4d_relu(self.mixed4d(x)) + + if self.aux_logits: + aux2_output = self.aux2(x) + + x = self.mixed4e_relu(self.mixed4e(x)) + x = self.pool4(x) + x = self.mixed5a_relu(self.mixed5a(x)) + x = self.mixed5b_relu(self.mixed5b(x)) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.drop(x) + x = self.fc(x) + if not self.aux_logits: + return x + else: + return x, aux1_output, aux2_output + + +class InceptionModule(nn.Module): + def __init__( + self, + in_channels: int, + c1x1: int, + c3x3reduce: int, + c3x3: int, + c5x5reduce: int, + c5x5: int, + pool_proj: int, + activ: Type[nn.Module] = nn.ReLU, + p_layer: Type[nn.Module] = nn.MaxPool2d, + ) -> None: + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + inception module. + c1x1 (int, optional): + c3x3reduce (int, optional): + c3x3 (int, optional): + c5x5reduce (int, optional): + c5x5 (int, optional): + pool_proj (int, optional): + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. + Default: nn.ReLU + p_layer (type of nn.Module, optional): The nn.Module class type to use for + pooling layers. + Default: nn.MaxPool2d + """ + super().__init__() + self.conv_1x1 = nn.Conv2d( + in_channels=in_channels, + out_channels=c1x1, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + + self.conv_3x3_reduce = nn.Conv2d( + in_channels=in_channels, + out_channels=c3x3reduce, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv_3x3_reduce_relu = activ() + self.conv_3x3 = Conv2dSame( + in_channels=c3x3reduce, + out_channels=c3x3, + kernel_size=(3, 3), + stride=(1, 1), + groups=1, + bias=True, + ) + + self.conv_5x5_reduce = nn.Conv2d( + in_channels=in_channels, + out_channels=c5x5reduce, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv_5x5_reduce_relu = activ() + self.conv_5x5 = Conv2dSame( + in_channels=c5x5reduce, + out_channels=c5x5, + kernel_size=(5, 5), + stride=(1, 1), + groups=1, + bias=True, + ) + + self.pool = p_layer(kernel_size=3, stride=1, padding=1, ceil_mode=True) + self.pool_proj = nn.Conv2d( + in_channels=in_channels, + out_channels=pool_proj, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to pass through the Inception Module. + + Returns: + x (torch.Tensor): The output tensor of the Inception Module. 
+ """ + c1x1 = self.conv_1x1(x) + + c3x3 = self.conv_3x3_reduce(x) + c3x3 = self.conv_3x3_reduce_relu(c3x3) + c3x3 = self.conv_3x3(c3x3) + + c5x5 = self.conv_5x5_reduce(x) + c5x5 = self.conv_5x5_reduce_relu(c5x5) + c5x5 = self.conv_5x5(c5x5) + + px = self.pool(x) + px = self.pool_proj(px) + return torch.cat([c1x1, c3x3, c5x5, px], dim=1) + + +class AuxBranch(nn.Module): + def __init__( + self, + in_channels: int = 512, + out_features: int = 365, + activ: Type[nn.Module] = nn.ReLU, + ) -> None: + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + auxiliary branch. + Default: 508 + out_features (int, optional): The number of output features to use for the + auxiliary branch. + Default: 1008 + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. + Default: nn.ReLU + """ + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d((4, 4)) + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=128, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv_relu = activ() + self.fc1 = nn.Linear(in_features=2048, out_features=1024, bias=True) + self.fc1_relu = activ() + self.dropout = nn.Dropout(0.699999988079071) + self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to pass through the auxiliary branch + module. + + Returns: + x (torch.Tensor): The output tensor of the auxiliary branch module. + """ + x = self.avg_pool(x) + x = self.conv(x) + x = self.conv_relu(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = self.fc1_relu(x) + x = self.dropout(x) + x = self.fc2(x) + return x diff --git a/captum/optim/models/_image/inception_v1_places365_classes.py b/captum/optim/models/_image/inception_v1_places365_classes.py new file mode 100644 index 0000000000..350f3b784d --- /dev/null +++ b/captum/optim/models/_image/inception_v1_places365_classes.py @@ -0,0 +1,375 @@ +from typing import List + +""" +List of classes for the MIT Places365 GoogleNet model trained using the Places365 +Standard dataset. 
Class list created from the Places365 GitHub repo class list: +https://github.com/CSAILVision/places365/blob/master/categories_places365.txt +""" + +INCEPTIONV1_PLACES365_CLASSES: List[str] = [ + "/a/airfield", + "/a/airplane_cabin", + "/a/airport_terminal", + "/a/alcove", + "/a/alley", + "/a/amphitheater", + "/a/amusement_arcade", + "/a/amusement_park", + "/a/apartment_building/outdoor", + "/a/aquarium", + "/a/aqueduct", + "/a/arcade", + "/a/arch", + "/a/archaelogical_excavation", + "/a/archive", + "/a/arena/hockey", + "/a/arena/performance", + "/a/arena/rodeo", + "/a/army_base", + "/a/art_gallery", + "/a/art_school", + "/a/art_studio", + "/a/artists_loft", + "/a/assembly_line", + "/a/athletic_field/outdoor", + "/a/atrium/public", + "/a/attic", + "/a/auditorium", + "/a/auto_factory", + "/a/auto_showroom", + "/b/badlands", + "/b/bakery/shop", + "/b/balcony/exterior", + "/b/balcony/interior", + "/b/ball_pit", + "/b/ballroom", + "/b/bamboo_forest", + "/b/bank_vault", + "/b/banquet_hall", + "/b/bar", + "/b/barn", + "/b/barndoor", + "/b/baseball_field", + "/b/basement", + "/b/basketball_court/indoor", + "/b/bathroom", + "/b/bazaar/indoor", + "/b/bazaar/outdoor", + "/b/beach", + "/b/beach_house", + "/b/beauty_salon", + "/b/bedchamber", + "/b/bedroom", + "/b/beer_garden", + "/b/beer_hall", + "/b/berth", + "/b/biology_laboratory", + "/b/boardwalk", + "/b/boat_deck", + "/b/boathouse", + "/b/bookstore", + "/b/booth/indoor", + "/b/botanical_garden", + "/b/bow_window/indoor", + "/b/bowling_alley", + "/b/boxing_ring", + "/b/bridge", + "/b/building_facade", + "/b/bullring", + "/b/burial_chamber", + "/b/bus_interior", + "/b/bus_station/indoor", + "/b/butchers_shop", + "/b/butte", + "/c/cabin/outdoor", + "/c/cafeteria", + "/c/campsite", + "/c/campus", + "/c/canal/natural", + "/c/canal/urban", + "/c/candy_store", + "/c/canyon", + "/c/car_interior", + "/c/carrousel", + "/c/castle", + "/c/catacomb", + "/c/cemetery", + "/c/chalet", + "/c/chemistry_lab", + "/c/childs_room", + "/c/church/indoor", + "/c/church/outdoor", + "/c/classroom", + "/c/clean_room", + "/c/cliff", + "/c/closet", + "/c/clothing_store", + "/c/coast", + "/c/cockpit", + "/c/coffee_shop", + "/c/computer_room", + "/c/conference_center", + "/c/conference_room", + "/c/construction_site", + "/c/corn_field", + "/c/corral", + "/c/corridor", + "/c/cottage", + "/c/courthouse", + "/c/courtyard", + "/c/creek", + "/c/crevasse", + "/c/crosswalk", + "/d/dam", + "/d/delicatessen", + "/d/department_store", + "/d/desert/sand", + "/d/desert/vegetation", + "/d/desert_road", + "/d/diner/outdoor", + "/d/dining_hall", + "/d/dining_room", + "/d/discotheque", + "/d/doorway/outdoor", + "/d/dorm_room", + "/d/downtown", + "/d/dressing_room", + "/d/driveway", + "/d/drugstore", + "/e/elevator/door", + "/e/elevator_lobby", + "/e/elevator_shaft", + "/e/embassy", + "/e/engine_room", + "/e/entrance_hall", + "/e/escalator/indoor", + "/e/excavation", + "/f/fabric_store", + "/f/farm", + "/f/fastfood_restaurant", + "/f/field/cultivated", + "/f/field/wild", + "/f/field_road", + "/f/fire_escape", + "/f/fire_station", + "/f/fishpond", + "/f/flea_market/indoor", + "/f/florist_shop/indoor", + "/f/food_court", + "/f/football_field", + "/f/forest/broadleaf", + "/f/forest_path", + "/f/forest_road", + "/f/formal_garden", + "/f/fountain", + "/g/galley", + "/g/garage/indoor", + "/g/garage/outdoor", + "/g/gas_station", + "/g/gazebo/exterior", + "/g/general_store/indoor", + "/g/general_store/outdoor", + "/g/gift_shop", + "/g/glacier", + "/g/golf_course", + "/g/greenhouse/indoor", + 
"/g/greenhouse/outdoor", + "/g/grotto", + "/g/gymnasium/indoor", + "/h/hangar/indoor", + "/h/hangar/outdoor", + "/h/harbor", + "/h/hardware_store", + "/h/hayfield", + "/h/heliport", + "/h/highway", + "/h/home_office", + "/h/home_theater", + "/h/hospital", + "/h/hospital_room", + "/h/hot_spring", + "/h/hotel/outdoor", + "/h/hotel_room", + "/h/house", + "/h/hunting_lodge/outdoor", + "/i/ice_cream_parlor", + "/i/ice_floe", + "/i/ice_shelf", + "/i/ice_skating_rink/indoor", + "/i/ice_skating_rink/outdoor", + "/i/iceberg", + "/i/igloo", + "/i/industrial_area", + "/i/inn/outdoor", + "/i/islet", + "/j/jacuzzi/indoor", + "/j/jail_cell", + "/j/japanese_garden", + "/j/jewelry_shop", + "/j/junkyard", + "/k/kasbah", + "/k/kennel/outdoor", + "/k/kindergarden_classroom", + "/k/kitchen", + "/l/lagoon", + "/l/lake/natural", + "/l/landfill", + "/l/landing_deck", + "/l/laundromat", + "/l/lawn", + "/l/lecture_room", + "/l/legislative_chamber", + "/l/library/indoor", + "/l/library/outdoor", + "/l/lighthouse", + "/l/living_room", + "/l/loading_dock", + "/l/lobby", + "/l/lock_chamber", + "/l/locker_room", + "/m/mansion", + "/m/manufactured_home", + "/m/market/indoor", + "/m/market/outdoor", + "/m/marsh", + "/m/martial_arts_gym", + "/m/mausoleum", + "/m/medina", + "/m/mezzanine", + "/m/moat/water", + "/m/mosque/outdoor", + "/m/motel", + "/m/mountain", + "/m/mountain_path", + "/m/mountain_snowy", + "/m/movie_theater/indoor", + "/m/museum/indoor", + "/m/museum/outdoor", + "/m/music_studio", + "/n/natural_history_museum", + "/n/nursery", + "/n/nursing_home", + "/o/oast_house", + "/o/ocean", + "/o/office", + "/o/office_building", + "/o/office_cubicles", + "/o/oilrig", + "/o/operating_room", + "/o/orchard", + "/o/orchestra_pit", + "/p/pagoda", + "/p/palace", + "/p/pantry", + "/p/park", + "/p/parking_garage/indoor", + "/p/parking_garage/outdoor", + "/p/parking_lot", + "/p/pasture", + "/p/patio", + "/p/pavilion", + "/p/pet_shop", + "/p/pharmacy", + "/p/phone_booth", + "/p/physics_laboratory", + "/p/picnic_area", + "/p/pier", + "/p/pizzeria", + "/p/playground", + "/p/playroom", + "/p/plaza", + "/p/pond", + "/p/porch", + "/p/promenade", + "/p/pub/indoor", + "/r/racecourse", + "/r/raceway", + "/r/raft", + "/r/railroad_track", + "/r/rainforest", + "/r/reception", + "/r/recreation_room", + "/r/repair_shop", + "/r/residential_neighborhood", + "/r/restaurant", + "/r/restaurant_kitchen", + "/r/restaurant_patio", + "/r/rice_paddy", + "/r/river", + "/r/rock_arch", + "/r/roof_garden", + "/r/rope_bridge", + "/r/ruin", + "/r/runway", + "/s/sandbox", + "/s/sauna", + "/s/schoolhouse", + "/s/science_museum", + "/s/server_room", + "/s/shed", + "/s/shoe_shop", + "/s/shopfront", + "/s/shopping_mall/indoor", + "/s/shower", + "/s/ski_resort", + "/s/ski_slope", + "/s/sky", + "/s/skyscraper", + "/s/slum", + "/s/snowfield", + "/s/soccer_field", + "/s/stable", + "/s/stadium/baseball", + "/s/stadium/football", + "/s/stadium/soccer", + "/s/stage/indoor", + "/s/stage/outdoor", + "/s/staircase", + "/s/storage_room", + "/s/street", + "/s/subway_station/platform", + "/s/supermarket", + "/s/sushi_bar", + "/s/swamp", + "/s/swimming_hole", + "/s/swimming_pool/indoor", + "/s/swimming_pool/outdoor", + "/s/synagogue/outdoor", + "/t/television_room", + "/t/television_studio", + "/t/temple/asia", + "/t/throne_room", + "/t/ticket_booth", + "/t/topiary_garden", + "/t/tower", + "/t/toyshop", + "/t/train_interior", + "/t/train_station/platform", + "/t/tree_farm", + "/t/tree_house", + "/t/trench", + "/t/tundra", + "/u/underwater/ocean_deep", + "/u/utility_room", 
+ "/v/valley", + "/v/vegetable_garden", + "/v/veterinarians_office", + "/v/viaduct", + "/v/village", + "/v/vineyard", + "/v/volcano", + "/v/volleyball_court/outdoor", + "/w/waiting_room", + "/w/water_park", + "/w/water_tower", + "/w/waterfall", + "/w/watering_hole", + "/w/wave", + "/w/wet_bar", + "/w/wheat_field", + "/w/wind_farm", + "/w/windmill", + "/y/yard", + "/y/youth_hostel", + "/z/zen_garden", +] diff --git a/tests/optim/models/test_models.py b/tests/optim/models/test_models.py index 509009141a..635b331094 100644 --- a/tests/optim/models/test_models.py +++ b/tests/optim/models/test_models.py @@ -4,7 +4,7 @@ import torch -from captum.optim.models import googlenet +from captum.optim.models import googlenet, googlenet_places365 from captum.optim.models._common import RedirectedReluLayer, SkipLayer from tests.helpers.basic import BaseTest, assertTensorAlmostEqual @@ -39,8 +39,8 @@ class TestInceptionV1(BaseTest): def test_load_inceptionv1_with_redirected_relu(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping load pretrained inception" - + " due to insufficient Torch version." + "Skipping load pretrained InceptionV1 test due to insufficient Torch" + + " version." ) model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=True) _check_layer_in_model(self, model, RedirectedReluLayer) @@ -48,8 +48,8 @@ def test_load_inceptionv1_with_redirected_relu(self) -> None: def test_load_inceptionv1_no_redirected_relu(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping load pretrained inception RedirectedRelu" - + " due to insufficient Torch version." + "Skipping load pretrained InceptionV1 RedirectedRelu test due to" + + " insufficient Torch version." ) model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False) _check_layer_not_in_model(self, model, RedirectedReluLayer) @@ -58,8 +58,8 @@ def test_load_inceptionv1_no_redirected_relu(self) -> None: def test_load_inceptionv1_linear(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping load pretrained inception linear" - + " due to insufficient Torch version." + "Skipping load pretrained InceptionV1 linear test due to insufficient" + + " Torch version." ) model = googlenet(pretrained=True, use_linear_modules_only=True) _check_layer_not_in_model(self, model, RedirectedReluLayer) @@ -68,11 +68,11 @@ def test_load_inceptionv1_linear(self) -> None: _check_layer_in_model(self, model, SkipLayer) _check_layer_in_model(self, model, torch.nn.AvgPool2d) - def test_transform_inceptionv1(self) -> None: + def test_inceptionv1_transform(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping inceptionV1 internal transform" - + " due to insufficient Torch version." + "Skipping InceptionV1 internal transform test due to insufficient" + + " Torch version." ) x = torch.randn(1, 3, 224, 224).clamp(0, 1) model = googlenet(pretrained=True) @@ -80,11 +80,24 @@ def test_transform_inceptionv1(self) -> None: expected_output = x * 255 - 117 assertTensorAlmostEqual(self, output, expected_output, 0) - def test_transform_bgr_inceptionv1(self) -> None: + def test_inceptionv1_transform_warning(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping inceptionV1 internal transform" - + " BGR due to insufficient Torch version." + "Skipping InceptionV1 internal transform warning test due to" + + " insufficient Torch version." 
+ ) + x = torch.stack( + [torch.ones(3, 112, 112) * -1, torch.ones(3, 112, 112) * 2], dim=0 + ) + model = googlenet(pretrained=True) + with self.assertWarns(UserWarning): + model._transform_input(x) + + def test_inceptionv1_transform_bgr(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping InceptionV1 internal transform BGR test due to insufficient" + + " Torch version." ) x = torch.randn(1, 3, 224, 224).clamp(0, 1) model = googlenet(pretrained=True, bgr_transform=True) @@ -92,45 +105,229 @@ def test_transform_bgr_inceptionv1(self) -> None: expected_output = x[:, [2, 1, 0]] * 255 - 117 assertTensorAlmostEqual(self, output, expected_output, 0) - def test_load_and_forward_basic_inceptionv1(self) -> None: + def test_inceptionv1_forward(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping basic pretrained inceptionV1 forward" - + " due to insufficient Torch version." + "Skipping pretrained InceptionV1 forward test due to insufficient" + + " Torch version." ) - x = torch.randn(1, 3, 224, 224).clamp(0, 1) + x = torch.zeros(1, 3, 224, 224) model = googlenet(pretrained=True) - try: - model(x) - test = True - except Exception: - test = False - self.assertTrue(test) - - def test_load_and_forward_diff_sizes_inceptionv1(self) -> None: + outputs = model(x) + self.assertEqual(list(outputs.shape), [1, 1008]) + + def test_inceptionv1_load_and_forward_diff_sizes(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping pretrained inceptionV1 forward with different sized inputs" + "Skipping pretrained InceptionV1 forward with different sized inputs" + " due to insufficient Torch version." ) - x = torch.randn(1, 3, 512, 512).clamp(0, 1) - x2 = torch.randn(1, 3, 383, 511).clamp(0, 1) + x = torch.zeros(1, 3, 512, 512) + x2 = torch.zeros(1, 3, 383, 511) model = googlenet(pretrained=True) - try: - model(x) - model(x2) - test = True - except Exception: - test = False - self.assertTrue(test) - - def test_forward_aux_inceptionv1(self) -> None: + outputs = model(x) + outputs2 = model(x2) + self.assertEqual(list(outputs.shape), [1, 1008]) + self.assertEqual(list(outputs2.shape), [1, 1008]) + + def test_inceptionv1_forward_aux(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping pretrained inceptionV1 with aux logits forward" + "Skipping pretrained InceptionV1 with aux logits forward due to" + + " insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet(pretrained=False, aux_logits=True) + outputs = model(x) + self.assertEqual([list(o.shape) for o in outputs], [[1, 1008]] * 3) + + def test_inceptionv1_forward_cuda(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 forward CUDA test due to insufficient" + + " Torch version." + ) + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 forward CUDA test due to not" + + " supporting CUDA." + ) + x = torch.zeros(1, 3, 224, 224).cuda() + model = googlenet(pretrained=True).cuda() + outputs = model(x) + self.assertTrue(outputs.is_cuda) + self.assertEqual(list(outputs.shape), [1, 1008]) + + def test_inceptionv1_load_and_jit_module(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 load & JIT test" + + " due to insufficient Torch version." 
+ ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet(pretrained=True) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual(list(outputs.shape), [1, 1008]) + + def test_inceptionv1_load_and_jit_module_no_redirected_relu(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 load & JIT with no" + + " redirected relu test due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual(list(outputs.shape), [1, 1008]) + + +class TestInceptionV1Places365(BaseTest): + def test_load_inceptionv1_places365_with_redirected_relu(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping load pretrained InceptionV1 Places365 due to insufficient" + + " Torch version." + ) + model = googlenet_places365( + pretrained=True, replace_relus_with_redirectedrelu=True + ) + _check_layer_in_model(self, model, RedirectedReluLayer) + + def test_load_inceptionv1_places365_no_redirected_relu(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping load pretrained InceptionV1 Places365 RedirectedRelu test" + " due to insufficient Torch version." ) + model = googlenet_places365( + pretrained=True, replace_relus_with_redirectedrelu=False + ) + _check_layer_not_in_model(self, model, RedirectedReluLayer) + _check_layer_in_model(self, model, torch.nn.ReLU) + + def test_load_inceptionv1_places365_linear(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping load pretrained InceptionV1 Places365 linear test due to" + + " insufficient Torch version." + ) + model = googlenet_places365(pretrained=True, use_linear_modules_only=True) + _check_layer_not_in_model(self, model, RedirectedReluLayer) + _check_layer_not_in_model(self, model, torch.nn.ReLU) + _check_layer_not_in_model(self, model, torch.nn.MaxPool2d) + _check_layer_in_model(self, model, SkipLayer) + _check_layer_in_model(self, model, torch.nn.AvgPool2d) + + def test_inceptionv1_places365_transform(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping InceptionV1 Places365 internal transform test due to" + + " insufficient Torch version." + ) x = torch.randn(1, 3, 224, 224).clamp(0, 1) - model = googlenet(pretrained=False, aux_logits=True) + model = googlenet_places365(pretrained=True) + output = model._transform_input(x) + expected_output = x * 255 - torch.tensor( + [116.7894, 112.6004, 104.0437], device=x.device + ).view(3, 1, 1) + expected_output = expected_output[:, [2, 1, 0]] + assertTensorAlmostEqual(self, output, expected_output, 0) + + def test_inceptionv1_places365_transform_warning(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping InceptionV1 Places365 internal transform warning test due" + + " to insufficient Torch version." + ) + x = torch.stack( + [torch.ones(3, 112, 112) * -1, torch.ones(3, 112, 112) * 2], dim=0 + ) + model = googlenet_places365(pretrained=True) + with self.assertWarns(UserWarning): + model._transform_input(x) + + def test_inceptionv1_places365_load_and_forward(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping basic pretrained InceptionV1 Places365 forward test due to" + + " insufficient Torch version." 
+ ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365(pretrained=True) outputs = model(x) - self.assertEqual(len(outputs), 3) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + + def test_inceptionv1_places365_load_and_forward_diff_sizes(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 forward with different" + + " sized inputs test due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 512, 512) + x2 = torch.zeros(1, 3, 383, 511) + model = googlenet_places365(pretrained=True) + + outputs = model(x) + outputs2 = model(x2) + + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + self.assertEqual([list(o.shape) for o in outputs2], [[1, 365]] * 3) + + def test_inceptionv1_places365_forward_no_aux(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 with aux logits forward" + + " test due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365(pretrained=False, aux_logits=False) + outputs = model(x) + self.assertEqual(list(outputs.shape), [1, 365]) + + def test_inceptionv1_places365_forward_cuda(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 forward CUDA test due to" + + " insufficient Torch version." + ) + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 forward CUDA test due to" + + " not supporting CUDA." + ) + x = torch.zeros(1, 3, 224, 224).cuda() + model = googlenet_places365(pretrained=True).cuda() + outputs = model(x) + + self.assertTrue(outputs[0].is_cuda) + self.assertTrue(outputs[1].is_cuda) + self.assertTrue(outputs[2].is_cuda) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + + def test_inceptionv1_places365_load_and_jit_module(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 load & JIT module test" + + " due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365(pretrained=True) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + + def test_inceptionv1_places365_jit_module_no_redirected_relu(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 load & JIT module with no" + + " redirected relu test due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365( + pretrained=True, replace_relus_with_redirectedrelu=False + ) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3)
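
For reference, a minimal usage sketch of the new API added in this diff, assuming the pretrained weight download from GS_SAVED_WEIGHTS_URL succeeds; all names are taken from the code above:

    import torch

    from captum.optim.models import INCEPTIONV1_PLACES365_CLASSES, googlenet_places365

    # aux_logits defaults to True for the pretrained Places365 model, so the
    # forward pass returns a (main, aux1, aux2) tuple of logits.
    model = googlenet_places365(pretrained=True).eval()

    # The internal transform expects inputs in the range [0, 1]; out-of-range
    # values now only trigger a UserWarning instead of an assertion failure.
    x = torch.rand(1, 3, 224, 224)

    with torch.no_grad():
        logits, aux1, aux2 = model(x)

    # Map the top logit index to its Places365 category name.
    print(INCEPTIONV1_PLACES365_CLASSES[logits.argmax(dim=1).item()])

As the tests above verify, each of the three outputs has shape [1, 365], and constructing the model with aux_logits=False instead returns a single tensor.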