
Optim-wip: Miscellaneous Changes & Fixes #827


Merged
merged 12 commits on May 15, 2022
2 changes: 2 additions & 0 deletions captum/optim/__init__.py
@@ -8,6 +8,7 @@
from captum.optim._utils import circuits, reducer # noqa: F401
from captum.optim._utils.image import atlas # noqa: F401
from captum.optim._utils.image.common import ( # noqa: F401
hue_to_rgb,
nchannels_to_rgb,
save_tensor_as_image,
show,
@@ -25,6 +26,7 @@
"models",
"reducer",
"atlas",
"hue_to_rgb",
"nchannels_to_rgb",
"save_tensor_as_image",
"show",
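With the export in place, `hue_to_rgb` can be imported directly from `captum.optim`, the same way `nchannels_to_rgb` already is. A minimal sketch of the new import path (assuming this branch is installed):

import torch
from captum.optim import hue_to_rgb

# 120 degrees falls exactly on the green anchor color, so this returns
# a unit-norm tensor([0., 1., 0.]) (see the hue_to_rgb diff further down).
color = hue_to_rgb(120.0)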
19 changes: 10 additions & 9 deletions captum/optim/_core/loss.py
@@ -452,14 +452,14 @@ def __init__(
batch_index: Optional[int] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.direction = vec.reshape((1, -1, 1, 1))
self.vec = vec.reshape((1, -1, 1, 1))
self.cossim_pow = cossim_pow

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
assert activations.size(1) == self.direction.size(1)
assert activations.size(1) == self.vec.size(1)
activations = activations[self.batch_index[0] : self.batch_index[1]]
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)


@loss_wrapper
@@ -481,7 +481,7 @@ def __init__(
batch_index: Optional[int] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.direction = vec.reshape((1, -1, 1, 1))
self.vec = vec.reshape((1, -1, 1, 1))
self.x = x
self.y = y
self.channel_index = channel_index
@@ -500,7 +500,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
]
if self.channel_index is not None:
activations = activations[:, self.channel_index, ...][:, None, ...]
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)


@loss_wrapper
@@ -607,16 +607,17 @@ def __init__(
batch_index: Optional[int] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.direction = vec
assert vec.dim() == 4
self.vec = vec
self.cossim_pow = cossim_pow

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]

assert activations.dim() == 4

H_direction, W_direction = self.direction.size(2), self.direction.size(3)
H_activ, W_activ = activations.size(2), activations.size(3)
H_direction, W_direction = self.vec.shape[2:]
H_activ, W_activ = activations.shape[2:]

H = (H_activ - H_direction) // 2
W = (W_activ - W_direction) // 2
@@ -627,7 +628,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
H : H + H_direction,
W : W + W_direction,
]
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)


@loss_wrapper
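Beyond the `self.direction` → `self.vec` rename (which makes the stored attribute match the `vec` constructor argument), the last hunk adds an explicit `vec.dim() == 4` check and reads spatial sizes via `shape[2:]`. A small sketch of the center-crop alignment that `__call__` performs, with illustrative shapes:

import torch

activations = torch.randn(1, 8, 12, 12)  # NCHW activations
vec = torch.randn(1, 8, 4, 4)            # 4D direction tensor

H_direction, W_direction = vec.shape[2:]
H_activ, W_activ = activations.shape[2:]

# Offsets that center the crop: (12 - 4) // 2 == 4 on both axes.
H = (H_activ - H_direction) // 2
W = (W_activ - W_direction) // 2
cropped = activations[:, :, H : H + H_direction, W : W + W_direction]
assert cropped.shape[2:] == vec.shape[2:]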
9 changes: 8 additions & 1 deletion captum/optim/_param/image/transforms.py
@@ -89,9 +89,16 @@ def klt_transform() -> torch.Tensor:
**transform** (torch.Tensor): A Karhunen-Loève transform (KLT) measured on
the ImageNet dataset.
"""
# Handle older versions of PyTorch
torch_norm = (
torch.linalg.norm
if version.parse(torch.__version__) >= version.parse("1.7.0")
else torch.norm
)

KLT = [[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]]
transform = torch.Tensor(KLT).float()
transform = transform / torch.max(torch.norm(transform, dim=0))
transform = transform / torch.max(torch_norm(transform, dim=0))
return transform

@staticmethod
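`torch.linalg.norm` was only added in PyTorch 1.7.0, so the gate falls back to the older `torch.norm` on earlier releases. For the column-norm call used here the two APIs agree; a quick standalone check (a sketch, not part of this PR):

import torch
from packaging import version

klt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
if version.parse(torch.__version__) >= version.parse("1.7.0"):
    # Column-wise L2 norms from the new and old APIs match.
    assert torch.allclose(torch.linalg.norm(klt, dim=0), torch.norm(klt, dim=0))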
123 changes: 83 additions & 40 deletions captum/optim/_utils/image/common.py
@@ -5,6 +5,7 @@
import numpy as np
import torch
from captum.optim._utils.reducer import posneg
from packaging import version

try:
from PIL import Image
@@ -64,6 +65,21 @@ def save_tensor_as_image(x: torch.Tensor, filename: str, scale: float = 255.0) -
def get_neuron_pos(
H: int, W: int, x: Optional[int] = None, y: Optional[int] = None
) -> Tuple[int, int]:
"""
Args:

H (int): The height.
W (int): The width.
x (int, optional): Optionally specify an exact x location of the neuron. If
set to None, then the center x location will be used.
Default: None
y (int, optional): Optionally specify an exact y location of the neuron. If
set to None, then the center y location will be used.
Default: None

Returns:
Tuple[_x, _y] (Tuple[int, int]): The x and y positions of the neuron.
"""
if x is None:
_x = W // 2
else:
@@ -109,66 +125,93 @@ def _dot_cossim(
return dot * torch.clamp(torch.cosine_similarity(x, y, eps=eps), 0.1) ** cossim_pow


@torch.jit.ignore
def nchannels_to_rgb(x: torch.Tensor, warp: bool = True) -> torch.Tensor:
"""
Convert an NCHW image with n channels into a 3 channel RGB image.
# Handle older versions of PyTorch
# Defined outside of function in order to support JIT
_torch_norm = (
torch.linalg.norm
if version.parse(torch.__version__) >= version.parse("1.7.0")
else torch.norm
)


def hue_to_rgb(
angle: float, device: torch.device = torch.device("cpu"), warp: bool = True
) -> torch.Tensor:
"""
Create an RGB unit vector based on the hue of the input angle.
Args:
x (torch.Tensor): Image tensor to transform into RGB image.
warp (bool, optional): Whether or not to make colors more distinguishable.
angle (float): The hue angle to create an RGB color for.
device (torch.device, optional): The device to create the angle color tensor
on.
Default: torch.device("cpu")
warp (bool, optional): Whether or not to make colors more distinguishable.
Default: True
Returns:
*tensor* RGB image
color_vec (torch.Tensor): An RGB unit vector of shape [3].
"""

def hue_to_rgb(angle: float) -> torch.Tensor:
"""
Create an RGB unit vector based on a hue of the input angle.
"""

angle = angle - 360 * (angle // 360)
colors = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.7071, 0.7071, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.7071, 0.7071],
[0.0, 0.0, 1.0],
[0.7071, 0.0, 0.7071],
]
angle = angle - 360 * (angle // 360)
colors = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.7071, 0.7071, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.7071, 0.7071],
[0.0, 0.0, 1.0],
[0.7071, 0.0, 0.7071],
],
device=device,
)

idx = math.floor(angle / 60)
d = (angle - idx * 60) / 60

if warp:
# Idea from: https://github.com/tensorflow/lucid/pull/193
d = (
math.sin(d * math.pi / 2)
if idx % 2 == 0
else 1 - math.sin((1 - d) * math.pi / 2)
)

idx = math.floor(angle / 60)
d = (angle - idx * 60) / 60
vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6]
return vec / _torch_norm(vec)

if warp:

def adj(x: float) -> float:
return math.sin(x * math.pi / 2)
def nchannels_to_rgb(
x: torch.Tensor, warp: bool = True, eps: float = 1e-4
) -> torch.Tensor:
"""
Convert an NCHW image with n channels into a 3 channel RGB image.

d = adj(d) if idx % 2 == 0 else 1 - adj(1 - d)
Args:

vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6]
return vec / torch.norm(vec)
x (torch.Tensor): NCHW image tensor to transform into an RGB image.
warp (bool, optional): Whether or not to make colors more distinguishable.
Default: True
eps (float, optional): An optional epsilon value.
Default: 1e-4
Returns:
tensor (torch.Tensor): An NCHW RGB image tensor.
"""

assert x.dim() == 4

if (x < 0).any():
x = posneg(x.permute(0, 2, 3, 1), -1).permute(0, 3, 1, 2)

rgb = torch.zeros(1, 3, x.size(2), x.size(3), device=x.device)
nc = x.size(1)
for i in range(nc):
rgb = rgb + x[:, i][:, None, :, :]
rgb = rgb * hue_to_rgb(360 * i / nc).to(device=x.device)[None, :, None, None]

rgb = rgb + torch.ones(x.size(2), x.size(3))[None, None, :, :] * (
torch.sum(x, 1)[:, None] - torch.max(x, 1)[0][:, None]
)
return (rgb / (1e-4 + torch.norm(rgb, dim=1, keepdim=True))) * torch.norm(
x, dim=1, keepdim=True)
num_channels = x.size(1)
for i in range(num_channels):
rgb_angle = hue_to_rgb(360 * i / num_channels, device=x.device, warp=warp)
rgb = rgb + (x[:, i][:, None, :, :] * rgb_angle[None, :, None, None])

rgb = rgb + (
torch.ones(1, 1, x.size(2), x.size(3), device=x.device)
* (torch.sum(x, 1) - torch.max(x, 1)[0])[:, None]
)
rgb = rgb / (eps + _torch_norm(rgb, dim=1, keepdim=True))
return rgb * _torch_norm(x, dim=1, keepdim=True)


def weights_to_heatmap_2d(
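Taken together, the `common.py` changes hoist `hue_to_rgb` out of `nchannels_to_rgb` as a standalone, JIT-friendlier helper with explicit `device` and `warp` arguments, and turn the hard-coded `1e-4` into an `eps` parameter. A minimal usage sketch of the rewritten `nchannels_to_rgb` (shapes illustrative):

import torch
from captum.optim._utils.image.common import nchannels_to_rgb

x = torch.randn(1, 6, 32, 32)   # 6-channel NCHW activation image
rgb = nchannels_to_rgb(x, warp=True)
assert rgb.shape == (1, 3, 32, 32)  # collapsed to 3 RGB channels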