
Commit 2e86d98

Merge branch 'optim-wip' into optim-wip-circuits
2 parents 7515fd2 + 885ea4b

File tree: 12 files changed (+472, -210 lines)

captum/optim/_core/output_hook.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 
-from captum.optim._utils.typing import ModelInputType, ModuleOutputMapping
+from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType
 
 
 class AbortForwardException(Exception):

@@ -101,7 +101,7 @@ def __init__(self, model, targets: Union[nn.Module, List[nn.Module]]) -> None:
         self.model = model
         self.layers = ModuleOutputsHook(targets)
 
-    def __call__(self, input_t: ModelInputType) -> ModuleOutputMapping:
+    def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
         try:
             with suppress(AbortForwardException):
                 self.model(input_t)
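
The typing alias itself is not part of this diff; below is a minimal sketch of what the renamed alias plausibly looks like and how a tensor-or-tuple argument would be handled. The alias definition and the describe_input helper are assumptions, not taken from captum/optim/_utils/typing.py.

# Sketch only: assumed definition of the renamed alias.
from typing import Tuple, Union

import torch

TupleOfTensorsOrTensorType = Union[torch.Tensor, Tuple[torch.Tensor, ...]]


def describe_input(input_t: TupleOfTensorsOrTensorType) -> str:
    # Mirrors the annotation change: __call__ accepts a single tensor or a tuple.
    if isinstance(input_t, torch.Tensor):
        return f"tensor {tuple(input_t.shape)}"
    return f"tuple of {len(input_t)} tensors"


print(describe_input(torch.zeros(1, 3, 224, 224)))         # tensor (1, 3, 224, 224)
print(describe_input((torch.zeros(2, 4), torch.ones(3))))  # tuple of 2 tensors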

captum/optim/_param/image/images.py

Lines changed: 34 additions & 14 deletions

@@ -13,7 +13,6 @@
     print("The Pillow/PIL library is required to use Captum's Optim library")
 
 from captum.optim._param.image.transform import SymmetricPadding, ToRGB
-from captum.optim._utils.typing import InitSize, SquashFunc
 
 
 class ImageTensor(torch.Tensor):

@@ -183,7 +182,7 @@ class FFTImage(ImageParameterization):
 
     def __init__(
         self,
-        size: InitSize = None,
+        size: Tuple[int, int] = None,
         channels: int = 3,
         batch: int = 1,
         init: Optional[torch.Tensor] = None,

@@ -271,7 +270,7 @@ def forward(self) -> torch.Tensor:
 class PixelImage(ImageParameterization):
     def __init__(
         self,
-        size: InitSize = None,
+        size: Tuple[int, int] = None,
         channels: int = 3,
         batch: int = 1,
         init: Optional[torch.Tensor] = None,

@@ -292,7 +291,7 @@ def forward(self) -> torch.Tensor:
 class LaplacianImage(ImageParameterization):
     def __init__(
        self,
-        size: InitSize = None,
+        size: Tuple[int, int] = None,
         channels: int = 3,
         batch: int = 1,
         init: Optional[torch.Tensor] = None,

@@ -318,7 +317,7 @@ def __init__(
 
     def setup_input(
         self,
-        size: InitSize,
+        size: Tuple[int, int],
         channels: int,
         power: float = 0.1,
         init: Optional[torch.Tensor] = None,

@@ -470,49 +469,70 @@ class NaturalImage(ImageParameterization):
     r"""Outputs an optimizable input image.
 
     By convention, single images are CHW and float32s in [0,1].
-    The underlying parameterization is decorrelated via a ToRGB transform.
+    The underlying parameterization can be decorrelated via a ToRGB transform.
     When used with the (default) FFT parameterization, this results in a fully
     uncorrelated image parameterization. :-)
 
     If a model requires a normalization step, such as normalizing imagenet RGB values,
     or rescaling to [0,255], it can perform those steps with the provided transforms or
     inside its computation.
-    For example, our GoogleNet factory function has a `transform_input=True` argument.
+
+    Arguments:
+        size (Tuple[int, int]): The height and width to use for the nn.Parameter image
+            tensor.
+        channels (int): The number of channels to use when creating the
+            nn.Parameter tensor.
+        batch (int): The number of images to use when creating the
+            nn.Parameter tensor, or stacking init images.
+        parameterization (ImageParameterization, optional): An image parameterization
+            class.
+        squash_func (Callable[[torch.Tensor], torch.Tensor], optional): The squash
+            function to use after color recorrelation. A function or lambda function.
+        decorrelation_module (nn.Module, optional): A ToRGB instance.
+        decorrelate_init (bool, optional): Whether or not to apply color decorrelation
+            to the init tensor input.
     """
 
     def __init__(
         self,
-        size: InitSize = None,
+        size: Tuple[int, int] = None,
         channels: int = 3,
         batch: int = 1,
-        parameterization: ImageParameterization = FFTImage,
         init: Optional[torch.Tensor] = None,
+        parameterization: ImageParameterization = FFTImage,
+        squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+        decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"),
         decorrelate_init: bool = True,
-        squash_func: Optional[SquashFunc] = None,
     ) -> None:
         super().__init__()
-        self.decorrelate = ToRGB(transform_name="klt")
+        self.decorrelate = decorrelation_module
         if init is not None:
             assert init.dim() == 3 or init.dim() == 4
             if decorrelate_init:
+                assert self.decorrelate is not None
                 init = (
                     init.refine_names("B", "C", "H", "W")
                     if init.dim() == 4
                     else init.refine_names("C", "H", "W")
                 )
                 init = self.decorrelate(init, inverse=True).rename(None)
             if squash_func is None:
-                squash_func: SquashFunc = lambda x: x.clamp(0, 1)
+                squash_func: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.clamp(
+                    0, 1
+                )
         else:
             if squash_func is None:
-                squash_func: SquashFunc = lambda x: torch.sigmoid(x)
+                squash_func: Callable[
+                    [torch.Tensor], torch.Tensor
+                ] = lambda x: torch.sigmoid(x)
         self.squash_func = squash_func
         self.parameterization = parameterization(
             size=size, channels=channels, batch=batch, init=init
         )
 
     def forward(self) -> torch.Tensor:
         image = self.parameterization()
-        image = self.decorrelate(image)
+        if self.decorrelate is not None:
+            image = self.decorrelate(image)
         image = image.rename(None)  # TODO: the world is not yet ready
         return CudaImageTensor(self.squash_func(image))
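
NaturalImage now exposes the squash function and the decorrelation module as constructor arguments instead of hard-coding ToRGB(transform_name="klt") and a clamp. A rough usage sketch follows; the import paths are taken from the file paths in this commit, and the keyword behavior is inferred from the diff above rather than from documentation.

import torch
from captum.optim._param.image.images import NaturalImage, PixelImage
from captum.optim._param.image.transform import ToRGB

# Defaults: FFT parameterization with KLT decorrelation. With no init image,
# the default squash function is a sigmoid (the else-branch above).
param = NaturalImage(size=(224, 224), channels=3, batch=1)
out = param()  # decorrelated -> RGB -> squashed into [0, 1]

# Custom setup: pixel parameterization, a caller-supplied 3x3 decorrelation
# matrix (identity here, purely illustrative), and an explicit clamp squash.
param = NaturalImage(
    size=(128, 128),
    parameterization=PixelImage,
    squash_func=lambda x: x.clamp(0, 1),
    decorrelation_module=ToRGB(transform=torch.eye(3)),
)

# decorrelation_module=None skips the ToRGB step in forward() entirely.
param = NaturalImage(size=(64, 64), decorrelation_module=None, decorrelate_init=False)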

captum/optim/_param/image/transform.py

Lines changed: 76 additions & 49 deletions

@@ -1,14 +1,14 @@
 import math
 import numbers
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from captum.optim._utils.image.common import nchannels_to_rgb
-from captum.optim._utils.typing import TransformSize, TransformVal, TransformValList
+from captum.optim._utils.typing import IntSeqOrIntType, NumSeqOrTensorType
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 

@@ -46,14 +46,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class ToRGB(nn.Module):
     """Transforms arbitrary channels to RGB. We use this to ensure our
-    image parameteriaztion itself can be decorrelated. So this goes between
-    the image parameterization and the normalization/sigmoid step.
-    We offer two transforms: Karhunen-Loève (KLT) and I1I2I3.
+    image parametrization itself can be decorrelated. So this goes between
+    the image parametrization and the normalization/sigmoid step.
+    We offer two precalculated transforms: Karhunen-Loève (KLT) and I1I2I3.
     KLT corresponds to the empirically measured channel correlations on imagenet.
-    I1I2I3 corresponds to an aproximation for natural images from Ohta et al.[0]
+    I1I2I3 corresponds to an approximation for natural images from Ohta et al.[0]
     [0] Y. Ohta, T. Kanade, and T. Sakai, "Color information for region segmentation,"
     Computer Graphics and Image Processing, vol. 13, no. 3, pp. 222–241, 1980
     https://www.sciencedirect.com/science/article/pii/0146664X80900477
+
+    Arguments:
+        transform (str or tensor): Either a string for one of the precalculated
+            transform matrices, or a 3x3 matrix for the 3 RGB channels of input
+            tensors.
     """
 
     @staticmethod

@@ -73,15 +78,21 @@ def i1i2i3_transform() -> torch.Tensor:
         ]
         return torch.Tensor(i1i2i3_matrix)
 
-    def __init__(self, transform_name: str = "klt") -> None:
+    def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None:
         super().__init__()
-
-        if transform_name == "klt":
+        assert isinstance(transform, str) or torch.is_tensor(transform)
+        if torch.is_tensor(transform):
+            transform = cast(torch.Tensor, transform)
+            assert list(transform.shape) == [3, 3]
+            self.register_buffer("transform", transform)
+        elif transform == "klt":
             self.register_buffer("transform", ToRGB.klt_transform())
-        elif transform_name == "i1i2i3":
+        elif transform == "i1i2i3":
             self.register_buffer("transform", ToRGB.i1i2i3_transform())
         else:
-            raise ValueError("transform_name has to be either 'klt' or 'i1i2i3'")
+            raise ValueError(
+                "transform has to be either 'klt', 'i1i2i3'," + " or a matrix tensor."
+            )
 
     def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
         assert x.dim() == 3 or x.dim() == 4

@@ -118,60 +129,74 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
 
 class CenterCrop(torch.nn.Module):
     """
-    Center crop the specified amount of pixels from the edges.
+    Center crop a specified amount from a tensor.
     Arguments:
-        size (int, sequence) or (int): Number of pixels to center crop away.
+        size (int, sequence, int): Number of pixels to center crop away.
+        pixels_from_edges (bool, optional): Whether to treat crop size
+            values as the number of pixels from the tensor's edge, or an
+            exact shape in the center.
     """
 
-    def __init__(self, size: TransformSize = 0) -> None:
+    def __init__(
+        self, size: IntSeqOrIntType = 0, pixels_from_edges: bool = False
+    ) -> None:
         super(CenterCrop, self).__init__()
-        if type(size) is list or type(size) is tuple:
-            assert len(size) == 2, (
-                "CenterCrop requires a single crop value or a tuple of (height,width)"
-                + "in pixels for cropping."
-            )
-            self.crop_val = size
-        else:
-            self.crop_val = [size] * 2
+        self.crop_vals = size
+        self.pixels_from_edges = pixels_from_edges
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        assert (
-            input.dim() == 3 or input.dim() == 4
-        ), "Input to CenterCrop must be 3D or 4D"
-        if input.dim() == 4:
-            h, w = input.size(2), input.size(3)
-        elif input.dim() == 3:
-            h, w = input.size(1), input.size(2)
-        h_crop = h - self.crop_val[0]
-        w_crop = w - self.crop_val[1]
-        sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
-        return input[..., sh : sh + h_crop, sw : sw + w_crop]
+        """
+        Center crop an input.
+        Arguments:
+            input (torch.Tensor): Input to center crop.
+        Returns:
+            tensor (torch.Tensor): A center cropped tensor.
+        """
+
+        return center_crop(input, self.crop_vals, self.pixels_from_edges)
 
 
-def center_crop_shape(input: torch.Tensor, output_size: List[int]) -> torch.Tensor:
+def center_crop(
+    input: torch.Tensor, crop_vals: IntSeqOrIntType, pixels_from_edges: bool = False
+) -> torch.Tensor:
     """
-    Crop NCHW & CHW outputs by specifying the desired output shape.
+    Center crop a specified amount from a tensor.
+    Arguments:
+        input (tensor): A CHW or NCHW image tensor to center crop.
+        size (int, sequence, int): Number of pixels to center crop away.
+        pixels_from_edges (bool, optional): Whether to treat crop size
+            values as the number of pixels from the tensor's edge, or an
+            exact shape in the center.
+    Returns:
+        *tensor*: A center cropped tensor.
     """
 
-    assert input.dim() == 4 or input.dim() == 3
-    output_size = [output_size] if not hasattr(output_size, "__iter__") else output_size
-    assert len(output_size) == 1 or len(output_size) == 2
-    output_size = output_size * 2 if len(output_size) == 1 else output_size
+    assert input.dim() == 3 or input.dim() == 4
+    crop_vals = [crop_vals] if not hasattr(crop_vals, "__iter__") else crop_vals
+    crop_vals = cast(Union[List[int], Tuple[int], Tuple[int, int]], crop_vals)
+    assert len(crop_vals) == 1 or len(crop_vals) == 2
+    crop_vals = crop_vals * 2 if len(crop_vals) == 1 else crop_vals
 
     if input.dim() == 4:
         h, w = input.size(2), input.size(3)
     if input.dim() == 3:
         h, w = input.size(1), input.size(2)
 
-    h_crop = h - int(round((h - output_size[0]) / 2.0))
-    w_crop = w - int(round((w - output_size[1]) / 2.0))
-
-    return input[
-        ..., h_crop - output_size[0] : h_crop, w_crop - output_size[1] : w_crop
-    ]
+    if pixels_from_edges:
+        h_crop = h - crop_vals[0]
+        w_crop = w - crop_vals[1]
+        sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
+        x = input[..., sh : sh + h_crop, sw : sw + w_crop]
+    else:
+        h_crop = h - int(round((h - crop_vals[0]) / 2.0))
+        w_crop = w - int(round((w - crop_vals[1]) / 2.0))
+        x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
+    return x
 
 
-def rand_select(transform_values: TransformValList) -> TransformVal:
+def rand_select(
+    transform_values: NumSeqOrTensorType,
+) -> Union[int, float, torch.Tensor]:
     """
     Randomly return a value from the provided tuple or list
     """

@@ -186,19 +211,21 @@ class RandomScale(nn.Module):
         scale (float, sequence): Tuple of rescaling values to randomly select from.
     """
 
-    def __init__(self, scale: TransformValList) -> None:
+    def __init__(self, scale: NumSeqOrTensorType) -> None:
         super(RandomScale, self).__init__()
         self.scale = scale
 
     def get_scale_mat(
-        self, m: TransformVal, device: torch.device, dtype: torch.dtype
+        self, m: IntSeqOrIntType, device: torch.device, dtype: torch.dtype
     ) -> torch.Tensor:
         scale_mat = torch.tensor(
             [[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype
         )
         return scale_mat
 
-    def scale_tensor(self, x: torch.Tensor, scale: TransformVal) -> torch.Tensor:
+    def scale_tensor(
+        self, x: torch.Tensor, scale: Union[int, float, torch.Tensor]
+    ) -> torch.Tensor:
         scale_matrix = self.get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat(
             x.shape[0], 1, 1
         )
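
Two API changes come together here: ToRGB accepts either a preset name or a user-supplied 3x3 matrix, and CenterCrop plus the old center_crop_shape are unified into center_crop, whose pixels_from_edges flag switches between cropping to an exact center shape (the default) and trimming a pixel count off the edges. A small sketch with illustrative values; the named-dimension call on ToRGB follows how NaturalImage.forward uses it and is an assumption about its forward implementation.

import torch
from captum.optim._param.image.transform import CenterCrop, ToRGB, center_crop

# ToRGB from a preset or from any 3x3 tensor (identity used only for illustration).
to_rgb = ToRGB(transform="klt")
to_rgb_custom = ToRGB(transform=torch.eye(3))

# NaturalImage refines dimension names before calling ToRGB, so we do the same here.
x = torch.randn(1, 3, 32, 32).refine_names("B", "C", "H", "W")
rgb = to_rgb(x)
back = to_rgb(rgb, inverse=True)  # approximate inverse of the decorrelation

img = torch.randn(1, 3, 224, 224)

# Default mode: crop_vals is the exact output height/width kept in the center.
assert center_crop(img, (200, 200)).shape[-2:] == (200, 200)

# pixels_from_edges=True: crop_vals is the total number of pixels trimmed per
# spatial dimension, matching the old CenterCrop behavior.
assert center_crop(img, 24, pixels_from_edges=True).shape[-2:] == (200, 200)

# The CenterCrop module now simply forwards to center_crop.
assert CenterCrop(size=(200, 200))(img).shape[-2:] == (200, 200)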

captum/optim/_utils/circuits.py

Lines changed: 10 additions & 7 deletions

@@ -1,19 +1,20 @@
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
-from captum.optim._param.image.transform import center_crop_shape
+from captum.optim._param.image.transform import center_crop
 from captum.optim._utils.models import collect_activations
-from captum.optim._utils.typing import ModelInputType, TransformSize
+from captum.optim._utils.typing import IntSeqOrIntType, TupleOfTensorsOrTensorType
 
 
 def get_expanded_weights(
     model,
     target1: nn.Module,
     target2: nn.Module,
-    crop_shape: Optional[Union[Tuple[int, int], TransformSize]] = None,
-    model_input: ModelInputType = torch.zeros(1, 3, 224, 224),
+    crop_shape: Optional[Union[Tuple[int, int], IntSeqOrIntType]] = None,
+    model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224),
+    crop_func: Optional[Callable] = center_crop,
 ) -> torch.Tensor:
     """
     Extract meaningful weight interactions from between neurons which aren’t

@@ -32,6 +33,8 @@ def get_expanded_weights(
             size to enter crop away padding.
         model_input (tensor or tuple of tensors, optional): The input to use
             with the specified model.
+        crop_func (Callable, optional): Specify a function to crop away the padding
+            from the output weights.
     Returns:
         *tensor*: A tensor containing the expanded weights in the form of:
             (target2 output channels, target1 output channels, y, x)

@@ -58,6 +61,6 @@ def get_expanded_weights(
         A.append(x.squeeze(0))
     expanded_weights = torch.stack(A, 0)
 
-    if crop_shape is not None:
-        expanded_weights = center_crop_shape(expanded_weights, crop_shape)
+    if crop_shape is not None and crop_func is not None:
+        expanded_weights = crop_func(expanded_weights, crop_shape)
     return expanded_weights
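
The cropping step in get_expanded_weights is now pluggable: crop_shape is handed to crop_func (center_crop by default), and the step is skipped when either is None. A hedged usage sketch follows; the torchvision GoogLeNet model and the inception4c/inception4d targets are placeholders for illustration, not something this commit prescribes.

from torchvision import models  # assumption: any CNN whose layers can be hooked

from captum.optim._param.image.transform import center_crop
from captum.optim._utils.circuits import get_expanded_weights

model = models.googlenet(pretrained=True).eval()  # placeholder model

# Expanded weights between two layers, center-cropped to a 5x5 spatial extent.
weights = get_expanded_weights(
    model,
    model.inception4c,
    model.inception4d,
    crop_shape=(5, 5),
    crop_func=center_crop,  # the default; shown explicitly
)

# Passing crop_func=None (or leaving crop_shape=None) skips cropping and returns
# the full (target2 channels, target1 channels, y, x) tensor.
full = get_expanded_weights(model, model.inception4c, model.inception4d, crop_func=None)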
