Skip to content

Commit 6892f3c

Browse files
authored
Add model linearization, and expanded weights spatial positions
* Optionally replace non-linear MaxPool2d layers with their linear AvgPool2d equivalents. * Add instructions to the expanded weights / weight visualization tutorial on how to visualize the spatial positions of expanded weights.
1 parent c8d90b4 commit 6892f3c

File tree

4 files changed

+194
-21
lines changed

4 files changed

+194
-21
lines changed

captum/optim/_utils/circuits.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import Optional, Tuple, Union
1+
from typing import Any, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
55

66
from captum.optim._param.image.transform import center_crop_shape
77
from captum.optim._utils.models import collect_activations
8-
from captum.optim._utils.typing import ModelInputType, TransformSize
8+
from captum.optim._utils.typing import ModelInputType, PoolParam, TransformSize
99

1010

1111
def get_expanded_weights(
@@ -56,8 +56,57 @@ def get_expanded_weights(
5656
retain_graph=True,
5757
)[0]
5858
A.append(x.squeeze(0))
59-
exapnded_weights = torch.stack(A, 0)
59+
expanded_weights = torch.stack(A, 0)
6060

6161
if crop_shape is not None:
62-
exapnded_weights = center_crop_shape(exapnded_weights, crop_shape)
63-
return exapnded_weights
62+
expanded_weights = center_crop_shape(expanded_weights, crop_shape)
63+
return expanded_weights
64+
65+
66+
def max2avg_pool2d(model, value: Optional[Any] = float("-inf")) -> None:
    """
    Replace all non-linear MaxPool2d layers with their linear AvgPool2d equivalents.
    This allows us to ignore non-linear values when calculating expanded weights.

    The replacement is performed in-place on the given model and recurses through
    all of its child modules. Only the kernel_size, stride, padding and ceil_mode
    settings are carried over from each MaxPool2d layer.

    Args:
        model (nn.Module): A PyTorch model instance.
        value (Any, optional): Any element of an average pooling output that is
            equal to this value is set to zero. This returns padding that is
            meant to be ignored by pooling layers (such as -inf padding) back to
            zero. Set to None to leave the pooling outputs untouched.
            Default: float("-inf")
    """

    class AvgPool2dInf(torch.nn.Module):
        """AvgPool2d wrapper that zeroes every output element equal to value."""

        def __init__(
            self,
            # PoolParam is referenced by name only, so the alias does not have
            # to be resolved when this nested class is created.
            kernel_size: "PoolParam" = 2,
            stride: Optional["PoolParam"] = 2,
            padding: "PoolParam" = 0,
            ceil_mode: bool = False,
            value: Optional[Any] = None,
        ) -> None:
            super().__init__()
            self.avgpool = torch.nn.AvgPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                ceil_mode=ceil_mode,
            )
            self.value = value

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.avgpool(x)
            if self.value is not None:
                # Reset outputs that match the sentinel value back to zero.
                x[x == self.value] = 0.0
            return x

    def _replace_layers(module) -> None:
        """Swap MaxPool2d children of module in-place, descending recursively."""
        # Iterate over a snapshot so that swapping children cannot disturb the
        # iteration over the module's child dictionary.
        for name, child in list(module._modules.items()):
            if isinstance(child, torch.nn.MaxPool2d):
                new_layer = AvgPool2dInf(
                    kernel_size=child.kernel_size,
                    stride=child.stride,
                    padding=child.padding,
                    ceil_mode=child.ceil_mode,
                    value=value,
                )
                setattr(module, name, new_layer)
            elif child is not None:
                # The helper closes over `value`, so nested submodules receive
                # the same replacement value as the top-level model.
                _replace_layers(child)

    _replace_layers(model)

captum/optim/_utils/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ def cleanup(self):
3939
TransformVal = Union[int, float, Tensor]
4040
TransformSize = Union[List[int], Tuple[int], int]
4141
ModelInputType = Union[Tuple[Tensor], Tensor]
42+
PoolParam = Union[int, Tuple[int, ...]]

tests/optim/utils/test_circuits.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import unittest
33

44
import torch
5+
import torch.nn.functional as F
56

67
import captum.optim._utils.circuits as circuits
78
from captum.optim._models.inception_v1 import googlenet
8-
from tests.helpers.basic import BaseTest
9+
from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
910

1011

1112
class TestGetExpandedWeights(BaseTest):
@@ -46,5 +47,24 @@ def test_get_expanded_weights_crop_two_int(self) -> None:
4647
self.assertEqual(list(output_tensor.shape), [480, 256, 5, 5])
4748

4849

50+
class TestMax2AvgPool2d(BaseTest):
    """Tests for circuits.max2avg_pool2d."""

    def test_max2avg_pool2d(self) -> None:
        """
        A converted MaxPool2d layer should produce the same output as a plain
        AvgPool2d layer whose -inf outputs have been set to zero.
        """
        pool_kwargs = {"kernel_size": 3, "stride": 2, "padding": 0}

        net = torch.nn.Sequential(torch.nn.MaxPool2d(**pool_kwargs))
        circuits.max2avg_pool2d(net)

        # Pad the right and bottom edges with -inf so that some pooling windows
        # contain values that the converted layer has to zero out.
        inputs = F.pad(
            torch.randn(128, 32, 16, 16), (0, 1, 0, 1), value=float("-inf")
        )
        actual = net(inputs)

        reference = torch.nn.AvgPool2d(**pool_kwargs)(inputs)
        reference[reference == float("-inf")] = 0.0

        assertTensorAlmostEqual(self, actual, reference, 0)
67+
68+
4969
if __name__ == "__main__":
5070
unittest.main()

tutorials/optimviz/WeightVisualization_OptimViz.ipynb

Lines changed: 118 additions & 15 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)