Merged
49 commits
6892f3c
Add model linearization, and expanded weights spatial positions
ProGamerGov Dec 27, 2020
336c62f
Spelling fixes and add top channels to tutorial
ProGamerGov Dec 28, 2020
dba1e48
Optionally remove nonlinear ReLU & full expanded weights test
ProGamerGov Dec 29, 2020
83bc5b8
Use torch.norm for PyTorch 1.3.0 only
ProGamerGov Dec 29, 2020
94fe929
Improve weight vis tutorial descriptions
ProGamerGov Dec 30, 2020
690fc79
Move ignore_layer & max2avg_pool2d to utils/models
ProGamerGov Dec 31, 2020
caf0cc7
Improve weight vis tutorial
ProGamerGov Jan 1, 2021
620f1ee
Improve intro & expanded weight description
ProGamerGov Jan 1, 2021
7515fd2
Improvements
ProGamerGov Jan 2, 2021
2e86d98
Merge branch 'optim-wip' into optim-wip-circuits
ProGamerGov Jan 4, 2021
752186a
Replace round with math.ceil in CenterCrop
ProGamerGov Jan 6, 2021
4116ecd
Add optional center crop offset parameter for uneven sides
ProGamerGov Jan 6, 2021
f1f73b0
Improve center crop parameter description
ProGamerGov Jan 8, 2021
ebaacbd
Update weight vis tutorial for new center crop
ProGamerGov Jan 8, 2021
891fd97
Remove suppression of PyTorch UserWarnings
ProGamerGov Jan 9, 2021
18a9d8c
Changes based on feedback
ProGamerGov Jan 9, 2021
49bee39
Update model factory tests
ProGamerGov Jan 9, 2021
70bd89c
Merge branch 'optim-wip' of https://github.com/pytorch/captum into op…
Jan 11, 2021
899ee6f
Update weight vis tutorial with colorspace fix
ProGamerGov Jan 11, 2021
bf5b991
Improve weight vis tutorial
ProGamerGov Jan 15, 2021
9a59014
Minor fixes & improvements to tutorial notebook
ProGamerGov Jan 16, 2021
790066b
Link spatial positions back to weight heatmap
ProGamerGov Jan 16, 2021
441b751
Test version check improvements
ProGamerGov Jan 16, 2021
7212dd8
Changes based on feedback
ProGamerGov Jan 20, 2021
7071396
Fix tests and InceptionV1 model
ProGamerGov Jan 20, 2021
c141bd7
Remove non-working check
ProGamerGov Jan 20, 2021
4237205
Changes based on feedback
ProGamerGov Jan 21, 2021
58a8c3b
Remove redundant skip_layer function
ProGamerGov Jan 21, 2021
4d3c686
Re-add skip_layer function
ProGamerGov Jan 21, 2021
2570eed
Improve replace layers
ProGamerGov Jan 22, 2021
f447533
Remove param transfer from skip_layers
ProGamerGov Jan 22, 2021
d472a91
Changes based on feedback part 1
ProGamerGov Jan 24, 2021
88a88ed
Remove placeholder Any type hints & fix instance creation
ProGamerGov Jan 24, 2021
908bec0
Add type hints for layers, model, and layer instances
ProGamerGov Jan 24, 2021
075ac59
max2avg_pool2d -> replace_max_with_avgconst_pool2d
ProGamerGov Jan 24, 2021
73175c0
Improve _check_layer_in_model test
ProGamerGov Jan 24, 2021
537fe79
Remove unused type hint import
ProGamerGov Jan 24, 2021
0a8b6e2
Revert layer check test and add type hints
ProGamerGov Jan 24, 2021
866faac
Add number of expected layers to tests
ProGamerGov Jan 24, 2021
ba07685
Change _check_layer_in_model based on feedback
ProGamerGov Jan 24, 2021
130513b
Changes to tutorial based on feedback
ProGamerGov Jan 25, 2021
fae868b
Address notebook feedback - part 2
ProGamerGov Jan 25, 2021
b9cbf02
Better wording for new sentence
ProGamerGov Jan 25, 2021
4d85d13
get_expanded_weights -> extract_expanded_weights
ProGamerGov Jan 26, 2021
83921ed
Improve weight vis notebook introduction
ProGamerGov Jan 26, 2021
f45b208
Change NMF link location
ProGamerGov Jan 26, 2021
bc7fa97
Add return to _check_layer_in_model test
ProGamerGov Jan 26, 2021
162e47a
Remove return
ProGamerGov Jan 26, 2021
2fc46d0
GPU Test Fix
ProGamerGov Jan 26, 2021
6 changes: 3 additions & 3 deletions .circleci/config.yml
@@ -102,7 +102,9 @@ commands:
steps:
- run:
name: "Simple PIP install"
command: python -m pip install -e .[dev]
command: |
python -m pip install --upgrade pip
python -m pip install -e .[dev]

py_3_7_setup:
description: "Set python version to 3.7 and install pip and pytest"
@@ -112,8 +114,6 @@ commands:
command: |
pyenv versions
pyenv global 3.7.0
sudo python -m pip install --upgrade pip
sudo python -m pip install pytest

install_cuda:
description: "Install CUDA for GPU Machine"
127 changes: 82 additions & 45 deletions captum/optim/_models/inception_v1.py
@@ -1,10 +1,16 @@
from typing import Tuple, Union, cast
from typing import Optional, Tuple, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F

import captum.optim._utils.models as model_utils
from captum.optim._utils.models import (
AvgPool2dConstrained,
LocalResponseNormLayer,
RedirectedReluLayer,
ReluLayer,
SkipLayer,
)

GS_SAVED_WEIGHTS_URL = (
"https://github.com/pytorch/captum/raw/"
@@ -13,29 +19,42 @@


def googlenet(
pretrained: bool = False, progress: bool = True, model_path: str = None, **kwargs
pretrained: bool = False,
progress: bool = True,
model_path: Optional[str] = None,
**kwargs
):
r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
model_path (str): Optional path for InceptionV1 model file
aux_logits (bool): If True, adds two auxiliary branches that can improve
training. Default: *False* when pretrained is True otherwise *True*
out_features (int): Number of output features in the model used for
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet.
progress (bool, optional): If True, displays a progress bar of the download to
stderr.
model_path (str, optional): Optional path for InceptionV1 model file.
replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained
model with Redirected ReLU in place of ReLU layers.
use_linear_modules_only (bool, optional): If True, return pretrained
model with all nonlinear layers replaced with linear equivalents.
aux_logits (bool, optional): If True, adds two auxiliary branches that can
improve training. Default: *False* when pretrained is True, otherwise *True*.
out_features (int, optional): Number of output features in the model used for
training. Default: 1008 when pretrained is True.
transform_input (bool): If True, preprocesses the input according to
transform_input (bool, optional): If True, preprocesses the input according to
the method with which it was trained on ImageNet. Default: *False*
bgr_transform (bool): If True and transform_input is True, perform an
bgr_transform (bool, optional): If True and transform_input is True, perform an
RGB to BGR transform in the internal preprocessing.
Default: *False*
"""

if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "bgr_transform" not in kwargs:
kwargs["bgr_transform"] = False
if "replace_relus_with_redirectedrelu" not in kwargs:
kwargs["replace_relus_with_redirectedrelu"] = True
if "use_linear_modules_only" not in kwargs:
kwargs["use_linear_modules_only"] = False
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if "out_features" not in kwargs:
@@ -50,27 +69,38 @@ def googlenet(
else:
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
model_utils.replace_layers(model)
return model

return InceptionV1(**kwargs)
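
For reference, a minimal usage sketch of the updated entry point. This is not part of the PR; it assumes the import path of this file and the keyword defaults set in the `pretrained` branch above.

```python
# Hedged sketch: assumes captum.optim._models.inception_v1 as the module path.
from captum.optim._models.inception_v1 import googlenet

# With pretrained=True, transform_input=True, aux_logits=False, and
# replace_relus_with_redirectedrelu=True are filled in automatically.
model = googlenet(pretrained=True)

# Opt out of RedirectedReLU, or request a fully linear copy of the model
# (ReLUs skipped, max pooling replaced) for expanded-weight extraction.
plain = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False)
linear = googlenet(pretrained=True, use_linear_modules_only=True)
```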


# Better version of Inception V1/GoogleNet for Inception5h
# Better version of Inception V1 / GoogleNet for Inception5h
class InceptionV1(nn.Module):
def __init__(
self,
out_features: int = 1008,
aux_logits: bool = False,
transform_input: bool = False,
bgr_transform: bool = False,
replace_relus_with_redirectedrelu: bool = False,
use_linear_modules_only: bool = False,
) -> None:
super(InceptionV1, self).__init__()
self.aux_logits = aux_logits
self.transform_input = transform_input
self.bgr_transform = bgr_transform
lrn_vals = (9, 9.99999974738e-05, 0.5, 1.0)

if use_linear_modules_only:
activ = SkipLayer
pool = AvgPool2dConstrained
else:
if replace_relus_with_redirectedrelu:
activ = RedirectedReluLayer
else:
activ = ReluLayer
pool = nn.MaxPool2d

self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=64,
@@ -79,9 +109,9 @@ def __init__(
groups=1,
bias=True,
)
self.conv1_relu = model_utils.ReluLayer()
self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
self.localresponsenorm1 = model_utils.LocalResponseNormLayer(*lrn_vals)
self.conv1_relu = activ()
self.pool1 = pool(kernel_size=3, stride=2, padding=0)
self.local_response_norm1 = LocalResponseNormLayer(*lrn_vals)

self.conv2 = nn.Conv2d(
in_channels=64,
@@ -91,7 +121,7 @@ def __init__(
groups=1,
bias=True,
)
self.conv2_relu = model_utils.ReluLayer()
self.conv2_relu = activ()
self.conv3 = nn.Conv2d(
in_channels=64,
out_channels=192,
@@ -100,29 +130,29 @@
groups=1,
bias=True,
)
self.conv3_relu = model_utils.ReluLayer()
self.localresponsenorm2 = model_utils.LocalResponseNormLayer(*lrn_vals)
self.conv3_relu = activ()
self.local_response_norm2 = LocalResponseNormLayer(*lrn_vals)

self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32)
self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64)
self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
self.mixed4a = InceptionModule(480, 192, 96, 204, 16, 48, 64)
self.pool2 = pool(kernel_size=3, stride=2, padding=0)
self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32, activ, pool)
self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64, activ, pool)
self.pool3 = pool(kernel_size=3, stride=2, padding=0)
self.mixed4a = InceptionModule(480, 192, 96, 204, 16, 48, 64, activ, pool)

if self.aux_logits:
self.aux1 = AuxBranch(508, out_features)
self.aux1 = AuxBranch(508, out_features, activ)

self.mixed4b = InceptionModule(508, 160, 112, 224, 24, 64, 64)
self.mixed4c = InceptionModule(512, 128, 128, 256, 24, 64, 64)
self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64)
self.mixed4b = InceptionModule(508, 160, 112, 224, 24, 64, 64, activ, pool)
self.mixed4c = InceptionModule(512, 128, 128, 256, 24, 64, 64, activ, pool)
self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64, activ, pool)

if self.aux_logits:
self.aux2 = AuxBranch(528, out_features)
self.aux2 = AuxBranch(528, out_features, activ)

self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128)
self.pool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
self.mixed5a = InceptionModule(832, 256, 160, 320, 48, 128, 128)
self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128)
self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128, activ, pool)
self.pool4 = pool(kernel_size=3, stride=2, padding=0)
self.mixed5a = InceptionModule(832, 256, 160, 320, 48, 128, 128, activ, pool)
self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128, activ, pool)

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.drop = nn.Dropout(0.4000000059604645)
@@ -146,14 +176,14 @@ def forward(
x = self.conv1_relu(x)
x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))
x = self.pool1(x)
x = self.localresponsenorm1(x)
x = self.local_response_norm1(x)

x = self.conv2(x)
x = self.conv2_relu(x)
x = F.pad(x, (1, 1, 1, 1))
x = self.conv3(x)
x = self.conv3_relu(x)
x = self.localresponsenorm2(x)
x = self.local_response_norm2(x)

x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))
x = self.pool2(x)
@@ -199,6 +229,8 @@ def __init__(
c5x5reduce: int,
c5x5: int,
pool_proj: int,
activ=ReluLayer,
p_layer=nn.MaxPool2d,
) -> None:
super(InceptionModule, self).__init__()
self.conv_1x1 = nn.Conv2d(
@@ -209,7 +241,7 @@ def __init__(
groups=1,
bias=True,
)
self.conv_1x1_relu = model_utils.ReluLayer()
self.conv_1x1_relu = activ()

self.conv_3x3_reduce = nn.Conv2d(
in_channels=in_channels,
@@ -219,7 +251,7 @@ def __init__(
groups=1,
bias=True,
)
self.conv_3x3_reduce_relu = model_utils.ReluLayer()
self.conv_3x3_reduce_relu = activ()
self.conv_3x3 = nn.Conv2d(
in_channels=c3x3reduce,
out_channels=c3x3,
@@ -228,7 +260,7 @@ def __init__(
groups=1,
bias=True,
)
self.conv_3x3_relu = model_utils.ReluLayer()
self.conv_3x3_relu = activ()

self.conv_5x5_reduce = nn.Conv2d(
in_channels=in_channels,
@@ -238,7 +270,7 @@ def __init__(
groups=1,
bias=True,
)
self.conv_5x5_reduce_relu = model_utils.ReluLayer()
self.conv_5x5_reduce_relu = activ()
self.conv_5x5 = nn.Conv2d(
in_channels=c5x5reduce,
out_channels=c5x5,
@@ -247,9 +279,9 @@ def __init__(
groups=1,
bias=True,
)
self.conv_5x5_relu = model_utils.ReluLayer()
self.conv_5x5_relu = activ()

self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=0)
self.pool = p_layer(kernel_size=3, stride=1, padding=0)
self.pool_proj = nn.Conv2d(
in_channels=in_channels,
out_channels=pool_proj,
@@ -258,7 +290,7 @@ def __init__(
groups=1,
bias=True,
)
self.pool_proj_relu = model_utils.ReluLayer()
self.pool_proj_relu = activ()

def forward(self, x: torch.Tensor) -> torch.Tensor:
c1x1 = self.conv_1x1(x)
@@ -284,7 +316,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class AuxBranch(nn.Module):
def __init__(self, in_channels: int = 508, out_features: int = 1008) -> None:
def __init__(
self,
in_channels: int = 508,
out_features: int = 1008,
activ=ReluLayer,
) -> None:
super(AuxBranch, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))
self.loss_conv = nn.Conv2d(
Expand All @@ -295,9 +332,9 @@ def __init__(self, in_channels: int = 508, out_features: int = 1008) -> None:
groups=1,
bias=True,
)
self.loss_conv_relu = model_utils.ReluLayer()
self.loss_conv_relu = activ()
self.loss_fc = nn.Linear(in_features=2048, out_features=1024, bias=True)
self.loss_fc_relu = model_utils.ReluLayer()
self.loss_fc_relu = activ()
self.loss_dropout = nn.Dropout(0.699999988079071)
self.loss_classifier = nn.Linear(
in_features=1024, out_features=out_features, bias=True
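
To make the `activ`/`pool` substitution above concrete, a small sketch follows. It is an illustration rather than PR code; the class names come from the `captum.optim._utils.models` imports at the top of this file, and the constructor flags from the diff above.

```python
import torch.nn as nn
from captum.optim._models.inception_v1 import InceptionV1

# All nonlinear layers replaced with linear equivalents.
linear = InceptionV1(use_linear_modules_only=True)
print(type(linear.conv1_relu).__name__)  # SkipLayer
print(type(linear.pool1).__name__)       # AvgPool2dConstrained

# ReLUs swapped for RedirectedReLU; pooling left as nn.MaxPool2d.
redirected = InceptionV1(replace_relus_with_redirectedrelu=True)
print(type(redirected.conv1_relu).__name__)  # RedirectedReluLayer
assert isinstance(redirected.pool1, nn.MaxPool2d)
```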
31 changes: 26 additions & 5 deletions captum/optim/_param/image/transform.py
@@ -135,14 +135,22 @@ class CenterCrop(torch.nn.Module):
pixels_from_edges (bool, optional): Whether to treat crop size
values as the number of pixels from the tensor's edge, or an
exact shape in the center.
offset_left (bool, optional): If the amounts cropped from opposite
sides are unequal, remove the extra pixel from the left and/or top
rather than the right and/or bottom. Default: False. This parameter
is only valid when pixels_from_edges is False.
"""

def __init__(
self, size: IntSeqOrIntType = 0, pixels_from_edges: bool = False
self,
size: IntSeqOrIntType = 0,
pixels_from_edges: bool = False,
offset_left: bool = False,
) -> None:
super(CenterCrop, self).__init__()
self.crop_vals = size
self.pixels_from_edges = pixels_from_edges
self.offset_left = offset_left

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
@@ -153,11 +161,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
tensor (torch.Tensor): A center cropped tensor.
"""

return center_crop(input, self.crop_vals, self.pixels_from_edges)
return center_crop(
input, self.crop_vals, self.pixels_from_edges, self.offset_left
)
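
A quick sketch of the module in use (assuming this file's import path, `captum.optim._param.image.transform`):

```python
import torch
from captum.optim._param.image.transform import CenterCrop

# Even 8x8 input, odd 5x5 crop: the sides are uneven, so offset_left matters.
crop = CenterCrop(size=(5, 5), offset_left=True)
print(crop(torch.ones(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 5, 5])
```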


def center_crop(
input: torch.Tensor, crop_vals: IntSeqOrIntType, pixels_from_edges: bool = False
input: torch.Tensor,
crop_vals: IntSeqOrIntType,
pixels_from_edges: bool = False,
offset_left: bool = False,
) -> torch.Tensor:
"""
Center crop a specified amount from a tensor.
@@ -167,6 +180,10 @@ def center_crop(
pixels_from_edges (bool, optional): Whether to treat crop size
values as the number of pixels from the tensor's edge, or an
exact shape in the center.
offset_left (bool, optional): If the amounts cropped from opposite
sides are unequal, remove the extra pixel from the left and/or top
rather than the right and/or bottom. Default: False. This parameter
is only valid when pixels_from_edges is False.
Returns:
*tensor*: A center cropped tensor.
"""
@@ -188,8 +205,12 @@
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
x = input[..., sh : sh + h_crop, sw : sw + w_crop]
else:
h_crop = h - int(round((h - crop_vals[0]) / 2.0))
w_crop = w - int(round((w - crop_vals[1]) / 2.0))
h_crop = h - int(math.ceil((h - crop_vals[0]) / 2.0))
w_crop = w - int(math.ceil((w - crop_vals[1]) / 2.0))
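# Input and crop sizes of different parity leave uneven sides; offset_left
# decides whether the extra pixel comes off the right/bottom (False) or
# the left/top (True).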
if h % 2 == 0 and crop_vals[0] % 2 != 0 or h % 2 != 0 and crop_vals[0] % 2 == 0:
h_crop = h_crop + 1 if offset_left else h_crop
if w % 2 == 0 and crop_vals[1] % 2 != 0 or w % 2 != 0 and crop_vals[1] % 2 == 0:
w_crop = w_crop + 1 if offset_left else w_crop
x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
return x
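
A worked example of the `math.ceil` change plus the parity handling above (a sketch under the same import-path assumption; the commented values follow the arithmetic in the function):

```python
import torch
from captum.optim._param.image.transform import center_crop

x = torch.arange(36.0).reshape(1, 1, 6, 6)

# h=6, crop=5: h_crop = 6 - ceil(1/2) = 5; the parities differ, so
# offset_left=True bumps h_crop to 6 and the kept window shifts by one.
right = center_crop(x, (5, 5), offset_left=False)  # keeps rows/cols 0:5
left = center_crop(x, (5, 5), offset_left=True)    # keeps rows/cols 1:6
print(right[0, 0, 0, 0].item(), left[0, 0, 0, 0].item())  # 0.0 7.0
```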
