Skip to content

Optim-wip: Add the pre-trained InceptionV1 Places365 model #935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions captum/optim/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
)
from ._image.inception5h_classes import INCEPTION5H_CLASSES # noqa: F401
from ._image.inception_v1 import InceptionV1, googlenet # noqa: F401
from ._image.inception_v1_places365 import ( # noqa: F401
InceptionV1Places365,
googlenet_places365,
)
from ._image.inception_v1_places365_classes import ( # noqa: F401
INCEPTIONV1_PLACES365_CLASSES,
)

__all__ = [
"RedirectedReluLayer",
Expand All @@ -19,4 +26,7 @@
"InceptionV1",
"googlenet",
"INCEPTION5H_CLASSES",
"InceptionV1Places365",
"googlenet_places365",
"INCEPTIONV1_PLACES365_CLASSES",
]
1 change: 1 addition & 0 deletions captum/optim/models/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class RedirectedReluLayer(nn.Module):
Class for applying RedirectedReLU
"""

@torch.jit.ignore
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """
    Args:

        input (torch.Tensor): An input tensor to apply RedirectedReLU to.

    Returns:
        x (torch.Tensor): The result of RedirectedReLU.apply on the input.
    """
    return RedirectedReLU.apply(input)

Expand Down
154 changes: 126 additions & 28 deletions captum/optim/models/_image/inception_v1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Tuple, Type, Union, cast
from typing import Optional, Tuple, Type, Union
from warnings import warn

import torch
import torch.nn as nn
Expand All @@ -19,24 +20,37 @@ def googlenet(
) -> "InceptionV1":
r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

Args:

pretrained (bool, optional): If True, returns a model pre-trained on ImageNet.
Default: False
progress (bool, optional): If True, displays a progress bar of the download to
stderr
Default: True
model_path (str, optional): Optional path for InceptionV1 model file.
Default: None
replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained
model with Redirected ReLU in place of ReLU layers.
Default: *True* when pretrained is True otherwise *False*
use_linear_modules_only (bool, optional): If True, return pretrained
model with all nonlinear layers replaced with linear equivalents.
Default: False
aux_logits (bool, optional): If True, adds two auxiliary branches that can
improve training. Default: *False* when pretrained is True otherwise *True*
improve training.
Default: False
out_features (int, optional): Number of output features in the model used for
training. Default: 1008 when pretrained is True.
training.
Default: 1008
transform_input (bool, optional): If True, preprocesses the input according to
the method with which it was trained on ImageNet. Default: *False*
the method with which it was trained on ImageNet.
Default: False
bgr_transform (bool, optional): If True and transform_input is True, perform an
RGB to BGR transform in the internal preprocessing.
Default: *False*
Default: False

Returns:
**InceptionV1** (InceptionV1): An Inception5h model.
"""

if pretrained:
Expand Down Expand Up @@ -69,6 +83,8 @@ def googlenet(

# Better version of Inception V1 / GoogleNet for Inception5h
class InceptionV1(nn.Module):
__constants__ = ["aux_logits", "transform_input", "bgr_transform"]

def __init__(
self,
out_features: int = 1008,
Expand All @@ -78,7 +94,29 @@ def __init__(
replace_relus_with_redirectedrelu: bool = False,
use_linear_modules_only: bool = False,
) -> None:
super(InceptionV1, self).__init__()
"""
Args:

replace_relus_with_redirectedrelu (bool, optional): If True, return
pretrained model with Redirected ReLU in place of ReLU layers.
Default: False
use_linear_modules_only (bool, optional): If True, return pretrained
model with all nonlinear layers replaced with linear equivalents.
Default: False
aux_logits (bool, optional): If True, adds two auxiliary branches that can
improve training.
Default: False
out_features (int, optional): Number of output features in the model used
for training.
Default: 1008
transform_input (bool, optional): If True, preprocesses the input according
to the method with which it was trained on ImageNet.
Default: False
bgr_transform (bool, optional): If True and transform_input is True,
perform an RGB to BGR transform in the internal preprocessing.
Default: False
"""
super().__init__()
self.aux_logits = aux_logits
self.transform_input = transform_input
self.bgr_transform = bgr_transform
Expand All @@ -99,7 +137,6 @@ def __init__(
out_channels=64,
kernel_size=(7, 7),
stride=(2, 2),
padding=3,
groups=1,
bias=True,
)
Expand All @@ -121,7 +158,6 @@ def __init__(
out_channels=192,
kernel_size=(3, 3),
stride=(1, 1),
padding=1,
groups=1,
bias=True,
)
Expand Down Expand Up @@ -163,9 +199,18 @@ def __init__(
self.fc = nn.Linear(1024, out_features)

def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:

x (torch.Tensor): An input tensor to normalize and scale the values of.

Returns:
x (torch.Tensor): A transformed tensor.
"""
if self.transform_input:
assert x.dim() == 3 or x.dim() == 4
assert x.min() >= 0.0 and x.max() <= 1.0
if x.min() < 0.0 or x.max() > 1.0:
warn("Model input has values outside of the range [0, 1].")
x = x.unsqueeze(0) if x.dim() == 3 else x
x = x * 255 - 117
x = x[:, [2, 1, 0]] if self.bgr_transform else x
Expand All @@ -174,6 +219,15 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self, x: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Args:

x (torch.Tensor): An input tensor to pass through the model.

Returns:
x (torch.Tensor or tuple of torch.Tensor): A single or multiple output
tensors from the model.
"""
x = self._transform_input(x)
x = self.conv1(x)
x = self.conv1_relu(x)
Expand Down Expand Up @@ -212,7 +266,7 @@ def forward(
x = self.drop(x)
x = self.fc(x)
if not self.aux_logits:
return cast(torch.Tensor, x)
return x
else:
return x, aux1_output, aux2_output

Expand All @@ -230,7 +284,25 @@ def __init__(
activ: Type[nn.Module] = nn.ReLU,
p_layer: Type[nn.Module] = nn.MaxPool2d,
) -> None:
super(InceptionModule, self).__init__()
"""
Args:

in_channels (int, optional): The number of input channels to use for the
inception module.
c1x1 (int, optional): The number of output channels for the 1x1 convolution.
c3x3reduce (int, optional):
c3x3 (int, optional): The number of output channels for the 3x3 convolution.
c5x5reduce (int, optional):
c5x5 (int, optional): The number of output channels for the 5x5 convolution.
pool_proj (int, optional):
activ (type of nn.Module, optional): The nn.Module class type to use for
activation layers.
Default: nn.ReLU
p_layer (type of nn.Module, optional): The nn.Module class type to use for
pooling layers.
Default: nn.MaxPool2d
"""
super().__init__()
self.conv_1x1 = nn.Conv2d(
in_channels=in_channels,
out_channels=c1x1,
Expand All @@ -254,7 +326,6 @@ def __init__(
out_channels=c3x3,
kernel_size=(3, 3),
stride=(1, 1),
padding=1,
groups=1,
bias=True,
)
Expand All @@ -273,7 +344,6 @@ def __init__(
out_channels=c5x5,
kernel_size=(5, 5),
stride=(1, 1),
padding=1,
groups=1,
bias=True,
)
Expand All @@ -289,6 +359,14 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:

x (torch.Tensor): An input tensor to pass through the Inception Module.

Returns:
x (torch.Tensor): The output tensor of the Inception Module.
"""
c1x1 = self.conv_1x1(x)

c3x3 = self.conv_3x3_reduce(x)
Expand All @@ -311,31 +389,51 @@ def __init__(
out_features: int = 1008,
activ: Type[nn.Module] = nn.ReLU,
) -> None:
super(AuxBranch, self).__init__()
"""
Args:

in_channels (int, optional): The number of input channels to use for the
auxiliary branch.
Default: 508
out_features (int, optional): The number of output features to use for the
auxiliary branch.
Default: 1008
activ (type of nn.Module, optional): The nn.Module class type to use for
activation layers.
Default: nn.ReLU
"""
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))
self.loss_conv = nn.Conv2d(
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=128,
kernel_size=(1, 1),
stride=(1, 1),
groups=1,
bias=True,
)
self.loss_conv_relu = activ()
self.loss_fc = nn.Linear(in_features=2048, out_features=1024, bias=True)
self.loss_fc_relu = activ()
self.loss_dropout = nn.Dropout(0.699999988079071)
self.loss_classifier = nn.Linear(
in_features=1024, out_features=out_features, bias=True
)
self.conv_relu = activ()
self.fc1 = nn.Linear(in_features=2048, out_features=1024, bias=True)
self.fc1_relu = activ()
self.dropout = nn.Dropout(0.699999988079071)
self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:

x (torch.Tensor): An input tensor to pass through the auxiliary branch
module.

Returns:
x (torch.Tensor): The output tensor of the auxiliary branch module.
"""
x = self.avg_pool(x)
x = self.loss_conv(x)
x = self.loss_conv_relu(x)
x = self.conv(x)
x = self.conv_relu(x)
x = torch.flatten(x, 1)
x = self.loss_fc(x)
x = self.loss_fc_relu(x)
x = self.loss_dropout(x)
x = self.loss_classifier(x)
x = self.fc1(x)
x = self.fc1_relu(x)
x = self.dropout(x)
x = self.fc2(x)
return x
Loading