diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py index 2fa5ee043d..647e2ba721 100644 --- a/captum/optim/__init__.py +++ b/captum/optim/__init__.py @@ -8,6 +8,7 @@ from captum.optim._utils import circuits, reducer # noqa: F401 from captum.optim._utils.image import atlas # noqa: F401 from captum.optim._utils.image.common import ( # noqa: F401 + hue_to_rgb, nchannels_to_rgb, save_tensor_as_image, show, @@ -25,6 +26,7 @@ "models", "reducer", "atlas", + "hue_to_rgb", "nchannels_to_rgb", "save_tensor_as_image", "show", diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index db9066ceba..f110887406 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -452,14 +452,14 @@ def __init__( batch_index: Optional[int] = None, ) -> None: BaseLoss.__init__(self, target, batch_index) - self.direction = vec.reshape((1, -1, 1, 1)) + self.vec = vec.reshape((1, -1, 1, 1)) self.cossim_pow = cossim_pow def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: activations = targets_to_values[self.target] - assert activations.size(1) == self.direction.size(1) + assert activations.size(1) == self.vec.size(1) activations = activations[self.batch_index[0] : self.batch_index[1]] - return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow) + return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) @loss_wrapper @@ -481,7 +481,7 @@ def __init__( batch_index: Optional[int] = None, ) -> None: BaseLoss.__init__(self, target, batch_index) - self.direction = vec.reshape((1, -1, 1, 1)) + self.vec = vec.reshape((1, -1, 1, 1)) self.x = x self.y = y self.channel_index = channel_index @@ -500,7 +500,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: ] if self.channel_index is not None: activations = activations[:, self.channel_index, ...][:, None, ...] - return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow) + return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) @loss_wrapper @@ -607,7 +607,8 @@ def __init__( batch_index: Optional[int] = None, ) -> None: BaseLoss.__init__(self, target, batch_index) - self.direction = vec + assert vec.dim() == 4 + self.vec = vec self.cossim_pow = cossim_pow def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: @@ -615,8 +616,8 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: assert activations.dim() == 4 - H_direction, W_direction = self.direction.size(2), self.direction.size(3) - H_activ, W_activ = activations.size(2), activations.size(3) + H_direction, W_direction = self.vec.shape[2:] + H_activ, W_activ = activations.shape[2:] H = (H_activ - H_direction) // 2 W = (W_activ - W_direction) // 2 @@ -627,7 +628,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: H : H + H_direction, W : W + W_direction, ] - return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow) + return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) @loss_wrapper diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index fb0dffe294..4ec8762637 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -89,9 +89,16 @@ def klt_transform() -> torch.Tensor: **transform** (torch.Tensor): A Karhunen-Loève transform (KLT) measured on the ImageNet dataset. 
""" + # Handle older versions of PyTorch + torch_norm = ( + torch.linalg.norm + if version.parse(torch.__version__) >= version.parse("1.7.0") + else torch.norm + ) + KLT = [[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]] transform = torch.Tensor(KLT).float() - transform = transform / torch.max(torch.norm(transform, dim=0)) + transform = transform / torch.max(torch_norm(transform, dim=0)) return transform @staticmethod diff --git a/captum/optim/_utils/image/common.py b/captum/optim/_utils/image/common.py index 613e9ea8ea..f4c5d9d3fa 100644 --- a/captum/optim/_utils/image/common.py +++ b/captum/optim/_utils/image/common.py @@ -5,6 +5,7 @@ import numpy as np import torch from captum.optim._utils.reducer import posneg +from packaging import version try: from PIL import Image @@ -64,6 +65,21 @@ def save_tensor_as_image(x: torch.Tensor, filename: str, scale: float = 255.0) - def get_neuron_pos( H: int, W: int, x: Optional[int] = None, y: Optional[int] = None ) -> Tuple[int, int]: + """ + Args: + + H (int) The height + W (int) The width + x (int, optional): Optionally specify and exact x location of the neuron. If + set to None, then the center x location will be used. + Default: None + y (int, optional): Optionally specify and exact y location of the neuron. If + set to None, then the center y location will be used. + Default: None + + Return: + Tuple[_x, _y] (Tuple[int, int]): The x and y dimensions of the neuron. + """ if x is None: _x = W // 2 else: @@ -109,48 +125,75 @@ def _dot_cossim( return dot * torch.clamp(torch.cosine_similarity(x, y, eps=eps), 0.1) ** cossim_pow -@torch.jit.ignore -def nchannels_to_rgb(x: torch.Tensor, warp: bool = True) -> torch.Tensor: - """ - Convert an NCHW image with n channels into a 3 channel RGB image. +# Handle older versions of PyTorch +# Defined outside of function in order to support JIT +_torch_norm = ( + torch.linalg.norm + if version.parse(torch.__version__) >= version.parse("1.7.0") + else torch.norm +) + +def hue_to_rgb( + angle: float, device: torch.device = torch.device("cpu"), warp: bool = True +) -> torch.Tensor: + """ + Create an RGB unit vector based on a hue of the input angle. Args: - x (torch.Tensor): Image tensor to transform into RGB image. - warp (bool, optional): Whether or not to make colors more distinguishable. + angle (float): The hue angle to create an RGB color for. + device (torch.device, optional): The device to create the angle color tensor + on. + Default: torch.device("cpu") + warp (bool, optional): Whether or not to make colors more distinguishable. Default: True Returns: - *tensor* RGB image + color_vec (torch.Tensor): A color vector. """ - def hue_to_rgb(angle: float) -> torch.Tensor: - """ - Create an RGB unit vector based on a hue of the input angle. 
- """ - - angle = angle - 360 * (angle // 360) - colors = torch.tensor( - [ - [1.0, 0.0, 0.0], - [0.7071, 0.7071, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.7071, 0.7071], - [0.0, 0.0, 1.0], - [0.7071, 0.0, 0.7071], - ] + angle = angle - 360 * (angle // 360) + colors = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.7071, 0.7071, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.7071, 0.7071], + [0.0, 0.0, 1.0], + [0.7071, 0.0, 0.7071], + ], + device=device, + ) + + idx = math.floor(angle / 60) + d = (angle - idx * 60) / 60 + + if warp: + # Idea from: https://github.com/tensorflow/lucid/pull/193 + d = ( + math.sin(d * math.pi / 2) + if idx % 2 == 0 + else 1 - math.sin((1 - d) * math.pi / 2) ) - idx = math.floor(angle / 60) - d = (angle - idx * 60) / 60 + vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6] + return vec / _torch_norm(vec) - if warp: - def adj(x: float) -> float: - return math.sin(x * math.pi / 2) +def nchannels_to_rgb( + x: torch.Tensor, warp: bool = True, eps: float = 1e-4 +) -> torch.Tensor: + """ + Convert an NCHW image with n channels into a 3 channel RGB image. - d = adj(d) if idx % 2 == 0 else 1 - adj(1 - d) + Args: - vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6] - return vec / torch.norm(vec) + x (torch.Tensor): NCHW image tensor to transform into RGB image. + warp (bool, optional): Whether or not to make colors more distinguishable. + Default: True + eps (float, optional): An optional epsilon value. + Default: 1e-4 + Returns: + tensor (torch.Tensor): An NCHW RGB image tensor. + """ assert x.dim() == 4 @@ -158,17 +201,17 @@ def adj(x: float) -> float: x = posneg(x.permute(0, 2, 3, 1), -1).permute(0, 3, 1, 2) rgb = torch.zeros(1, 3, x.size(2), x.size(3), device=x.device) - nc = x.size(1) - for i in range(nc): - rgb = rgb + x[:, i][:, None, :, :] - rgb = rgb * hue_to_rgb(360 * i / nc).to(device=x.device)[None, :, None, None] - - rgb = rgb + torch.ones(x.size(2), x.size(3))[None, None, :, :] * ( - torch.sum(x, 1)[:, None] - torch.max(x, 1)[0][:, None] - ) - return (rgb / (1e-4 + torch.norm(rgb, dim=1, keepdim=True))) * torch.norm( - x, dim=1, keepdim=True + num_channels = x.size(1) + for i in range(num_channels): + rgb_angle = hue_to_rgb(360 * i / num_channels, device=x.device, warp=warp) + rgb = rgb + (x[:, i][:, None, :, :] * rgb_angle[None, :, None, None]) + + rgb = rgb + ( + torch.ones(1, 1, x.size(2), x.size(3), device=x.device) + * (torch.sum(x, 1) - torch.max(x, 1)[0])[:, None] ) + rgb = rgb / (eps + _torch_norm(rgb, dim=1, keepdim=True)) + return rgb * _torch_norm(x, dim=1, keepdim=True) def weights_to_heatmap_2d( diff --git a/captum/optim/_utils/image/dataset.py b/captum/optim/_utils/image/dataset.py index fcc6d03742..c894173990 100644 --- a/captum/optim/_utils/image/dataset.py +++ b/captum/optim/_utils/image/dataset.py @@ -1,38 +1,75 @@ +from typing import cast + import torch +try: + from tqdm.auto import tqdm +except (ImportError, AssertionError): + print( + "The tqdm package is required to use captum.optim's" + + " image dataset functions with progress bar" + ) + -def image_cov(tensor: torch.Tensor) -> torch.Tensor: +def image_cov(x: torch.Tensor) -> torch.Tensor: """ - Calculate a tensor's RGB covariance matrix. + Calculate the average NCHW image tensor color channel covariance matrix for all + tensors in the stack. Args: - tensor (tensor): An NCHW image tensor. + + x (torch.Tensor): One or more NCHW image tensors stacked across the batch + dimension. + Returns: - *tensor*: An RGB covariance matrix for the specified tensor. 
+ *tensor* (torch.Tensor): The summed color channel covariance matrix for the + input tensors, with a shape of [n_channels, n_channels]. """ - tensor = tensor.reshape(-1, 3) - tensor = tensor - tensor.mean(0, keepdim=True) - return 1 / (tensor.size(0) - 1) * tensor.T @ tensor + assert x.dim() == 4 + x = x.reshape(x.shape[0], -1, x.shape[1]) + x = x - x.mean(1, keepdim=True) + b_cov_mtx = 1.0 / (x.shape[1] - 1) * x.permute(0, 2, 1) @ x + return torch.sum(b_cov_mtx, dim=0) -def dataset_cov_matrix(loader: torch.utils.data.DataLoader) -> torch.Tensor: +def dataset_cov_matrix( + loader: torch.utils.data.DataLoader, + show_progress: bool = False, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: """ Calculate the covariance matrix for an image dataset. Args: + loader (torch.utils.data.DataLoader): The reference to a PyTorch dataloader instance. + show_progress (bool, optional): Whether or not to display a tqdm progress bar. + Default: False + device (torch.device, optional): The PyTorch device to use for calculating + the cov matrix. + Default: torch.device("cpu") + Returns: *tensor*: A covariance matrix for the specified dataset. """ - cov_mtx = torch.zeros(3, 3) + if show_progress: + pbar = tqdm(total=len(loader.dataset), unit=" images") # type: ignore + + cov_mtx = torch.zeros([], device=device).float() for images, _ in loader: - assert images.dim() == 4 - for b in range(images.size(0)): - cov_mtx = cov_mtx + image_cov(images[b].permute(1, 2, 0)) - cov_mtx = cov_mtx / len(loader.dataset) # type: ignore + assert images.dim() > 1 + images = images.to(device) + cov_mtx = cov_mtx + image_cov(images) + if show_progress: + pbar.update(images.size(0)) + + if show_progress: + pbar.close() + + cov_mtx = cov_mtx / cast(int, len(loader.dataset)) return cov_mtx @@ -43,22 +80,31 @@ def cov_matrix_to_klt( Convert a cov matrix to a klt matrix. Args: + cov_mtx (tensor): A 3 by 3 covariance matrix generated from a dataset. normalize (bool): Whether or not to normalize the resulting KLT matrix. + Default: False - epsilon (float): + epsilon (float): A small value added to the eigenvalues before the square + root, for numerical stability. + Returns: *tensor*: A KLT matrix for the specified covariance matrix. """ + # Handle older versions of PyTorch + torch_norm = ( + torch.linalg.norm + if version.parse(torch.__version__) >= version.parse("1.7.0") + else torch.norm + ) + U, S, V = torch.svd(cov_mtx) svd_sqrt = U @ torch.diag(torch.sqrt(S + epsilon)) if normalize: - svd_sqrt / torch.max(torch.norm(svd_sqrt, dim=0)) + svd_sqrt = svd_sqrt / torch.max(torch_norm(svd_sqrt, dim=0)) return svd_sqrt def dataset_klt_matrix( - loader: torch.utils.data.DataLoader, normalize: bool = False + loader: torch.utils.data.DataLoader, + normalize: bool = False, + show_progress: bool = False, + device: torch.device = torch.device("cpu"), ) -> torch.Tensor: """ Calculate the color correlation matrix, also known as @@ -67,12 +113,20 @@ def dataset_klt_matrix( transforms for models trained on the dataset. Args: + loader (torch.utils.data.DataLoader): The reference to a PyTorch dataloader instance. normalize (bool): Whether or not to normalize the resulting KLT matrix. + Default: False + show_progress (bool, optional): Whether or not to display a tqdm progress bar. + Default: False + device (torch.device, optional): The PyTorch device to use for calculating + the cov matrix. + Default: torch.device("cpu") + Returns: *tensor*: A KLT matrix for the specified dataset.
""" - cov_mtx = dataset_cov_matrix(loader) + cov_mtx = dataset_cov_matrix(loader, show_progress=show_progress, device=device) return cov_matrix_to_klt(cov_mtx, normalize) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index 3de775157b..85196fcd39 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -12,6 +12,10 @@ def get_model_layers(model: nn.Module) -> List[str]: """ Return a list of hookable layers for the target model. + + Args: + + model (nn.Module): A PyTorch model or module instance to collect layers from. """ layers = [] @@ -155,14 +159,59 @@ def __init__( groups: int = 1, bias: bool = True, ) -> None: + """ + See nn.Conv2d for more details on the possible arguments: + https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + Args: + + in_channels (int): The expected number of channels in the input tensor. + out_channels (int): The desired number of channels in the output tensor. + kernel_size (int or tuple of int): The desired kernel size to use. + stride (int or tuple of int, optional): The desired stride for the + cross-correlation. + Default: 1 + padding (int or tuple of int, optional): This value is always set to 0. + Default: 0 + dilation (int or tuple of int, optional): The desired spacing between the + kernel points. + Default: 1 + groups (int, optional): Number of blocked connections from input channels + to output channels. Both in_channels and out_channels must be divisable + by groups. + Default: 1 + bias (bool, optional): Whether or not to apply a learnable bias to the + output. + """ super().__init__( in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias ) def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: + """ + Calculate the required padding for a dimension. + + Args: + + i (int): The specific size of the tensor dimension requiring padding. + k (int): The size of the Conv2d weight dimension. + s (int): The Conv2d stride value for the dimension. + d (int): The Conv2d dilation value for the dimension. + + Returns: + padding_vale (int): The calculated padding value. + """ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.tensor): The input tensor to apply 2D convolution to. + + Returns + x (torch.Tensor): The input tensor after the 2D convolution was applied. + """ ih, iw = x.size()[-2:] kh, kw = self.weight.size()[-2:] pad_h = self.calc_same_pad(i=ih, k=kh, s=self.stride[0], d=self.dilation[0]) @@ -190,12 +239,25 @@ def collect_activations( ) -> ModuleOutputMapping: """ Collect target activations for a model. + + Args: + + model (nn.Module): A PyTorch model instance. + targets (nn.Module or list of nn.Module): One or more layer targets for the + given model. + model_input (torch.Tensor or tuple of torch.Tensor, optional): Optionally + provide an input tensor to use when collecting the target activations. + Default: torch.zeros(1, 3, 224, 224) + + Returns: + activ_dict (ModuleOutputMapping): A dictionary of collected activations where + the keys are the target layers. 
""" if not isinstance(targets, list): targets = [targets] catch_activ = ActivationFetcher(model, targets) - activ_out = catch_activ(model_input) - return activ_out + activ_dict = catch_activ(model_input) + return activ_dict class SkipLayer(torch.nn.Module): diff --git a/tests/optim/helpers/image_dataset.py b/tests/optim/helpers/image_dataset.py index a8cef03b87..edced0bd75 100644 --- a/tests/optim/helpers/image_dataset.py +++ b/tests/optim/helpers/image_dataset.py @@ -1,6 +1,5 @@ from typing import List, Tuple -import numpy as np import torch @@ -27,39 +26,3 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: def __len__(self) -> int: return len(self.tensors) - - -def image_cov_np(array: np.ndarray) -> np.ndarray: - """ - Calculate an array's RGB covariance matrix. - - Args: - array (array): An NCHW image array. - Returns: - *array*: An RGB covariance matrix for the specified array. - """ - - array = array.reshape(-1, 3) - array = array - array.mean(0, keepdims=True) - return 1 / (array.shape[0] - 1) * array.T @ array - - -def cov_matrix_to_klt_np( - cov_mtx: np.ndarray, normalize: bool = False, epsilon: float = 1e-10 -) -> np.ndarray: - """ - Convert a cov matrix to a klt matrix. - - Args: - cov_mtx (array): A 3 by 3 covariance matrix generated from a dataset. - normalize (bool): Whether or not to normalize the resulting KLT matrix. - epsilon (float): - Returns: - *array*: A KLT matrix for the specified covariance matrix. - """ - - U, S, V = np.linalg.svd(cov_mtx) - svd_sqrt = U @ np.diag(np.sqrt(S + epsilon)) - if normalize: - svd_sqrt / np.linalg.norm(svd_sqrt, axis=0).max() - return svd_sqrt diff --git a/tests/optim/utils/image/common.py b/tests/optim/utils/image/common.py deleted file mode 100644 index 617e7c0b4a..0000000000 --- a/tests/optim/utils/image/common.py +++ /dev/null @@ -1,129 +0,0 @@ -#!/usr/bin/env python3 -import unittest - -import captum.optim._utils.image.common as common -import torch -from tests.helpers.basic import BaseTest, assertTensorAlmostEqual - - -class TestGetNeuronPos(unittest.TestCase): - def test_get_neuron_pos_hw(self) -> None: - W, H = 128, 128 - x, y = common.get_neuron_pos(H, W) - - self.assertEqual(x, W // 2) - self.assertEqual(y, H // 2) - - def test_get_neuron_pos_xy(self) -> None: - W, H = 128, 128 - x, y = common.get_neuron_pos(H, W, 5, 5) - - self.assertEqual(x, 5) - self.assertEqual(y, 5) - - def test_get_neuron_pos_x_none(self) -> None: - W, H = 128, 128 - x, y = common.get_neuron_pos(H, W, 5, None) - - self.assertEqual(x, 5) - self.assertEqual(y, H // 2) - - def test_get_neuron_pos_none_y(self) -> None: - W, H = 128, 128 - x, y = common.get_neuron_pos(H, W, None, 5) - - self.assertEqual(x, W // 2) - self.assertEqual(y, 5) - - -class TestNChannelsToRGB(BaseTest): - def test_nchannels_to_rgb_collapse(self) -> None: - test_input = torch.randn(1, 6, 224, 224) - test_output = common.nchannels_to_rgb(test_input) - self.assertEqual(list(test_output.size()), [1, 3, 224, 224]) - - def test_nchannels_to_rgb_increase(self) -> None: - test_input = torch.randn(1, 2, 224, 224) - test_output = common.nchannels_to_rgb(test_input) - self.assertEqual(list(test_output.size()), [1, 3, 224, 224]) - - -class TestWeightsToHeatmap2D(BaseTest): - def test_weights_to_heatmap_2d(self) -> None: - x = torch.ones(5, 4) - x[0:1, 0:4] = x[0:1, 0:4] * 0.2 - x[1:2, 0:4] = x[1:2, 0:4] * 0.8 - x[2:3, 0:4] = x[2:3, 0:4] * 0.0 - x[3:4, 0:4] = x[3:4, 0:4] * -0.2 - x[4:5, 0:4] = x[4:5, 0:4] * -0.8 - - x_out = common.weights_to_heatmap_2d(x) - - x_out_expected = 
torch.tensor( - [ - [ - [0.9639, 0.9639, 0.9639, 0.9639], - [0.8580, 0.8580, 0.8580, 0.8580], - [0.9686, 0.9686, 0.9686, 0.9686], - [0.8102, 0.8102, 0.8102, 0.8102], - [0.2408, 0.2408, 0.2408, 0.2408], - ], - [ - [0.8400, 0.8400, 0.8400, 0.8400], - [0.2588, 0.2588, 0.2588, 0.2588], - [0.9686, 0.9686, 0.9686, 0.9686], - [0.8902, 0.8902, 0.8902, 0.8902], - [0.5749, 0.5749, 0.5749, 0.5749], - ], - [ - [0.7851, 0.7851, 0.7851, 0.7851], - [0.2792, 0.2792, 0.2792, 0.2792], - [0.9686, 0.9686, 0.9686, 0.9686], - [0.9294, 0.9294, 0.9294, 0.9294], - [0.7624, 0.7624, 0.7624, 0.7624], - ], - ] - ) - assertTensorAlmostEqual(self, x_out, x_out_expected, delta=0.01) - - def test_weights_to_heatmap_2d_cuda(self) -> None: - if not torch.cuda.is_available(): - raise unittest.SkipTest( - "Skipping weights_to_heatmap_2d CUDA test due to not supporting CUDA." - ) - x = torch.ones(5, 4) - x[0:1, 0:4] = x[0:1, 0:4] * 0.2 - x[1:2, 0:4] = x[1:2, 0:4] * 0.8 - x[2:3, 0:4] = x[2:3, 0:4] * 0.0 - x[3:4, 0:4] = x[3:4, 0:4] * -0.2 - x[4:5, 0:4] = x[4:5, 0:4] * -0.8 - - x_out = common.weights_to_heatmap_2d(x.cuda()) - - x_out_expected = torch.tensor( - [ - [ - [0.9639, 0.9639, 0.9639, 0.9639], - [0.8580, 0.8580, 0.8580, 0.8580], - [0.9686, 0.9686, 0.9686, 0.9686], - [0.8102, 0.8102, 0.8102, 0.8102], - [0.2408, 0.2408, 0.2408, 0.2408], - ], - [ - [0.8400, 0.8400, 0.8400, 0.8400], - [0.2588, 0.2588, 0.2588, 0.2588], - [0.9686, 0.9686, 0.9686, 0.9686], - [0.8902, 0.8902, 0.8902, 0.8902], - [0.5749, 0.5749, 0.5749, 0.5749], - ], - [ - [0.7851, 0.7851, 0.7851, 0.7851], - [0.2792, 0.2792, 0.2792, 0.2792], - [0.9686, 0.9686, 0.9686, 0.9686], - [0.9294, 0.9294, 0.9294, 0.9294], - [0.7624, 0.7624, 0.7624, 0.7624], - ], - ] - ) - assertTensorAlmostEqual(self, x_out, x_out_expected, delta=0.01) - self.assertTrue(x_out.is_cuda) diff --git a/tests/optim/utils/image/dataset.py b/tests/optim/utils/image/dataset.py deleted file mode 100644 index 2d54c1be4a..0000000000 --- a/tests/optim/utils/image/dataset.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -import captum.optim._utils.image.dataset as dataset_utils -import torch -from tests.helpers.basic import BaseTest, assertTensorAlmostEqual -from tests.optim.helpers import image_dataset as dataset_helpers - - -class TestImageCov(BaseTest): - def test_image_cov(self) -> None: - test_tensor = torch.cat( - [ - torch.ones(1, 4, 4) * 0.1, - torch.ones(1, 4, 4) * 0.2, - torch.ones(1, 4, 4) * 0.3, - ], - 0, - ) - - output_tensor = dataset_utils.image_cov(test_tensor) - expected_output = dataset_helpers.image_cov_np(test_tensor.numpy()) - expected_output = torch.as_tensor(expected_output) - assertTensorAlmostEqual(self, output_tensor, expected_output, 0.01, mode="max") - - -class TestDatasetCovMatrix(BaseTest): - def test_dataset_cov_matrix(self) -> None: - num_tensors = 100 - - def create_tensor() -> torch.Tensor: - return torch.cat( - [ - torch.ones(1, 224, 224) * 0.1, - torch.ones(1, 224, 224) * 0.2, - torch.ones(1, 224, 224) * 0.3, - ], - 0, - ) - - dataset_tensors = [create_tensor() for x in range(num_tensors)] - test_dataset = dataset_helpers.ImageTestDataset(dataset_tensors) - dataset_loader = torch.utils.data.DataLoader( - test_dataset, batch_size=10, num_workers=0, shuffle=False - ) - output_mtx = dataset_utils.dataset_cov_matrix(dataset_loader) - expected_mtx = torch.tensor( - [ - [4.9961e-14, 9.9922e-14, -6.6615e-14], - [9.9922e-14, 1.9984e-13, -1.3323e-13], - [-6.6615e-14, -1.3323e-13, 8.8820e-14], - ] - ) - assertTensorAlmostEqual(self, output_mtx, expected_mtx) - - -class 
TestCovMatrixToKLT(BaseTest): - def test_cov_matrix_to_klt(self) -> None: - test_input = torch.tensor( - [ - [0.0477, 0.0415, 0.0280], - [0.0415, 0.0425, 0.0333], - [0.0280, 0.0333, 0.0419], - ] - ) - output_mtx = dataset_utils.cov_matrix_to_klt(test_input) - expected_mtx = dataset_helpers.cov_matrix_to_klt_np(test_input.numpy()) - expected_mtx = torch.as_tensor(expected_mtx) - assertTensorAlmostEqual(self, output_mtx, expected_mtx, 0.0005, mode="max") - - -class TestDatasetKLTMatrix(BaseTest): - def test_dataset_klt_matrix(self) -> None: - num_tensors = 100 - - def create_tensor() -> torch.Tensor: - return torch.cat( - [ - torch.ones(1, 224, 224) * 0.1, - torch.ones(1, 224, 224) * 0.2, - torch.ones(1, 224, 224) * 0.3, - ], - 0, - ) - - dataset_tensors = [create_tensor() for x in range(num_tensors)] - test_dataset = dataset_helpers.ImageTestDataset(dataset_tensors) - dataset_loader = torch.utils.data.DataLoader( - test_dataset, batch_size=10, num_workers=0, shuffle=False - ) - - klt_transform = dataset_utils.dataset_klt_matrix(dataset_loader) - - expected_mtx = torch.tensor( - [ - [-3.8412e-06, 9.2125e-06, 6.1284e-07], - [-7.6823e-06, -3.5571e-06, 5.3226e-06], - [5.1216e-06, 1.5737e-06, 8.4436e-06], - ] - ) - - assertTensorAlmostEqual(self, klt_transform, expected_mtx) diff --git a/tests/optim/utils/image/test_common.py b/tests/optim/utils/image/test_common.py new file mode 100644 index 0000000000..14dc5bf1ba --- /dev/null +++ b/tests/optim/utils/image/test_common.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +import unittest + +import captum.optim._utils.image.common as common +import torch +from packaging import version +from tests.helpers.basic import BaseTest, assertTensorAlmostEqual + + +class TestGetNeuronPos(unittest.TestCase): + def test_get_neuron_pos_hw(self) -> None: + W, H = 128, 128 + x, y = common.get_neuron_pos(H, W) + + self.assertEqual(x, W // 2) + self.assertEqual(y, H // 2) + + def test_get_neuron_pos_xy(self) -> None: + W, H = 128, 128 + x, y = common.get_neuron_pos(H, W, 5, 5) + + self.assertEqual(x, 5) + self.assertEqual(y, 5) + + def test_get_neuron_pos_x_none(self) -> None: + W, H = 128, 128 + x, y = common.get_neuron_pos(H, W, 5, None) + + self.assertEqual(x, 5) + self.assertEqual(y, H // 2) + + def test_get_neuron_pos_none_y(self) -> None: + W, H = 128, 128 + x, y = common.get_neuron_pos(H, W, None, 5) + + self.assertEqual(x, W // 2) + self.assertEqual(y, 5) + + +class TestDotCossim(BaseTest): + def test_dot_cossim_cossim_pow_0(self) -> None: + x = torch.arange(0, 1 * 3 * 4 * 4).view(1, 3, 4, 4).float() + y = torch.roll(x.clone(), shifts=(1, 2, 2, 1), dims=(0, 1, 2, 3)) + test_output = common._dot_cossim(x, y, cossim_pow=0.0) + + expected_output = torch.tensor( + [ + [ + [1040.0, 968.0, 1094.0, 1226.0], + [1604.0, 1508.0, 1658.0, 1814.0], + [1112.0, 944.0, 1070.0, 1202.0], + [1676.0, 1484.0, 1634.0, 1790.0], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output) + + def test_dot_cossim_cossim_pow_4(self) -> None: + x = torch.arange(0, 1 * 3 * 4 * 4).view(1, 3, 4, 4).float() + y = torch.roll(x.clone(), shifts=(1, 2, 2, 1), dims=(0, 1, 2, 3)) + test_output = common._dot_cossim(x, y, cossim_pow=4.0) + + expected_output = torch.tensor( + [ + [ + [101.9391, 89.0743, 124.8861, 168.7577], + [314.2930, 282.3505, 352.6324, 432.1260], + [133.2007, 80.3036, 114.3202, 156.4043], + [365.9309, 266.5905, 335.3027, 413.3354], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0.001) + + +class TestHueToRGB(BaseTest): + def 
test_hue_to_rgb_n_groups_4_warp_true(self) -> None: + n_groups = 4 + channels = list(range(n_groups)) + test_outputs = [] + for ch in channels: + output = common.hue_to_rgb(360 * ch / n_groups) + test_outputs.append(output) + test_outputs = torch.stack(test_outputs) + expected_outputs = torch.tensor( + [ + [1.0000, 0.0000, 0.0000], + [0.5334, 0.8459, 0.0000], + [0.0000, 0.7071, 0.7071], + [0.5334, 0.0000, 0.8459], + ] + ) + assertTensorAlmostEqual(self, test_outputs, expected_outputs) + + def test_hue_to_rgb_n_groups_4_warp_false(self) -> None: + n_groups = 4 + channels = list(range(n_groups)) + test_outputs = [] + for ch in channels: + output = common.hue_to_rgb(360 * ch / n_groups, warp=False) + test_outputs.append(output) + test_outputs = torch.stack(test_outputs) + expected_outputs = torch.tensor( + [ + [1.0000, 0.0000, 0.0000], + [0.3827, 0.9239, 0.0000], + [0.0000, 0.7071, 0.7071], + [0.3827, 0.0000, 0.9239], + ] + ) + assertTensorAlmostEqual(self, test_outputs, expected_outputs) + + def test_hue_to_rgb_n_groups_3_warp_true(self) -> None: + n_groups = 3 + channels = list(range(n_groups)) + test_outputs = [] + for ch in channels: + output = common.hue_to_rgb(360 * ch / n_groups) + test_outputs.append(output) + test_outputs = torch.stack(test_outputs) + expected_outputs = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + ) + assertTensorAlmostEqual(self, test_outputs, expected_outputs, delta=0.0) + + def test_hue_to_rgb_n_groups_2_warp_true(self) -> None: + n_groups = 2 + channels = list(range(n_groups)) + test_outputs = [] + for ch in channels: + output = common.hue_to_rgb(360 * ch / n_groups) + test_outputs.append(output) + test_outputs = torch.stack(test_outputs) + expected_outputs = torch.tensor( + [[1.0000, 0.0000, 0.0000], [0.0000, 0.7071, 0.7071]] + ) + assertTensorAlmostEqual(self, test_outputs, expected_outputs) + + def test_hue_to_rgb_n_groups_2_warp_false(self) -> None: + n_groups = 2 + channels = list(range(n_groups)) + test_outputs = [] + for ch in channels: + output = common.hue_to_rgb(360 * ch / n_groups, warp=False) + test_outputs.append(output) + test_outputs = torch.stack(test_outputs) + expected_outputs = torch.tensor( + [[1.0000, 0.0000, 0.0000], [0.0000, 0.7071, 0.7071]] + ) + assertTensorAlmostEqual(self, test_outputs, expected_outputs) + + +class TestNChannelsToRGB(BaseTest): + def test_nchannels_to_rgb_collapse(self) -> None: + test_input = torch.arange(0, 1 * 4 * 4 * 4).view(1, 4, 4, 4).float() + test_output = common.nchannels_to_rgb(test_input, warp=True) + expected_output = torch.tensor( + [ + [ + [ + [30.3782, 31.5489, 32.7147, 33.8773], + [35.0379, 36.1975, 37.3568, 38.5163], + [39.6765, 40.8378, 42.0003, 43.1642], + [44.3296, 45.4967, 46.6655, 47.8360], + ], + [ + [31.1266, 32.0951, 33.0678, 34.0451], + [35.0270, 36.0137, 37.0051, 38.0011], + [39.0015, 40.0063, 41.0152, 42.0282], + [43.0449, 44.0654, 45.0894, 46.1167], + ], + [ + [41.1375, 41.8876, 42.6646, 43.4656], + [44.2882, 45.1304, 45.9901, 46.8658], + [47.7561, 48.6597, 49.5754, 50.5023], + [51.4394, 52.3859, 53.3411, 54.3044], + ], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0.005) + + def test_nchannels_to_rgb_collapse_warp_false(self) -> None: + test_input = torch.arange(0, 1 * 4 * 4 * 4).view(1, 4, 4, 4).float() + test_output = common.nchannels_to_rgb(test_input, warp=False) + expected_output = torch.tensor( + [ + [ + [ + [27.0349, 28.1947, 29.3453, 30.4887], + [31.6266, 32.7605, 33.8914, 35.0201], + [36.1474, 37.2737, 38.3995, 39.5252], + 
[40.6511, 41.7772, 42.9039, 44.0312], + ], + [ + [31.8525, 32.8600, 33.8708, 34.8851], + [35.9034, 36.9257, 37.9522, 38.9828], + [40.0175, 41.0561, 42.0987, 43.1451], + [44.1951, 45.2486, 46.3054, 47.3655], + ], + [ + [42.8781, 43.6494, 44.4480, 45.2710], + [46.1162, 46.9813, 47.8644, 48.7640], + [49.6786, 50.6069, 51.5477, 52.5000], + [53.4629, 54.4355, 55.4172, 56.4071], + ], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0.005) + + def test_nchannels_to_rgb_increase(self) -> None: + test_input = torch.arange(0, 1 * 2 * 4 * 4).view(1, 2, 4, 4).float() + test_output = common.nchannels_to_rgb(test_input, warp=True) + expected_output = torch.tensor( + [ + [ + [ + [0.0000, 1.8388, 3.4157, 4.8079], + [6.0713, 7.2442, 8.3524, 9.4137], + [10.4405, 11.4414, 12.4226, 13.3886], + [14.3428, 15.2878, 16.2253, 17.1568], + ], + [ + [11.3136, 11.9711, 12.5764, 13.1697], + [13.7684, 14.3791, 15.0039, 15.6425], + [16.2941, 16.9572, 17.6306, 18.3131], + [19.0037, 19.7013, 20.4051, 21.1145], + ], + [ + [11.3136, 11.9711, 12.5764, 13.1697], + [13.7684, 14.3791, 15.0039, 15.6425], + [16.2941, 16.9572, 17.6306, 18.3131], + [19.0037, 19.7013, 20.4051, 21.1145], + ], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0.005) + + def test_nchannels_to_rgb_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping nchannels_to_rgb CUDA test due to not supporting CUDA." + ) + test_input = torch.randn(1, 6, 224, 224).cuda() + test_output = common.nchannels_to_rgb(test_input) + self.assertTrue(test_output.is_cuda) + self.assertEqual(list(test_output.size()), [1, 3, 224, 224]) + + def test_nchannels_to_rgb_jit_module(self) -> None: + if version.parse(torch.__version__) <= version.parse("1.8.0"): + raise unittest.SkipTest( + "Skipping nchannels_to_rgb JIT module test due to insufficient Torch" + + " version." + ) + test_input = torch.randn(1, 6, 224, 224) + jit_nchannels_to_rgb = torch.jit.script(common.nchannels_to_rgb) + test_output = jit_nchannels_to_rgb(test_input) + self.assertEqual(list(test_output.size()), [1, 3, 224, 224]) + + +class TestWeightsToHeatmap2D(BaseTest): + def test_weights_to_heatmap_2d(self) -> None: + x = torch.ones(5, 4) + x[0:1, 0:4] = x[0:1, 0:4] * 0.2 + x[1:2, 0:4] = x[1:2, 0:4] * 0.8 + x[2:3, 0:4] = x[2:3, 0:4] * 0.0 + x[3:4, 0:4] = x[3:4, 0:4] * -0.2 + x[4:5, 0:4] = x[4:5, 0:4] * -0.8 + + x_out = common.weights_to_heatmap_2d(x) + + x_out_expected = torch.tensor( + [ + [ + [0.9639, 0.9639, 0.9639, 0.9639], + [0.8580, 0.8580, 0.8580, 0.8580], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8102, 0.8102, 0.8102, 0.8102], + [0.2408, 0.2408, 0.2408, 0.2408], + ], + [ + [0.8400, 0.8400, 0.8400, 0.8400], + [0.2588, 0.2588, 0.2588, 0.2588], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8902, 0.8902, 0.8902, 0.8902], + [0.5749, 0.5749, 0.5749, 0.5749], + ], + [ + [0.7851, 0.7851, 0.7851, 0.7851], + [0.2792, 0.2792, 0.2792, 0.2792], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.9294, 0.9294, 0.9294, 0.9294], + [0.7624, 0.7624, 0.7624, 0.7624], + ], + ] + ) + assertTensorAlmostEqual(self, x_out, x_out_expected, delta=0.01) + + def test_weights_to_heatmap_2d_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping weights_to_heatmap_2d CUDA test due to not supporting CUDA." 
+ ) + x = torch.ones(5, 4) + x[0:1, 0:4] = x[0:1, 0:4] * 0.2 + x[1:2, 0:4] = x[1:2, 0:4] * 0.8 + x[2:3, 0:4] = x[2:3, 0:4] * 0.0 + x[3:4, 0:4] = x[3:4, 0:4] * -0.2 + x[4:5, 0:4] = x[4:5, 0:4] * -0.8 + + x_out = common.weights_to_heatmap_2d(x.cuda()) + + x_out_expected = torch.tensor( + [ + [ + [0.9639, 0.9639, 0.9639, 0.9639], + [0.8580, 0.8580, 0.8580, 0.8580], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8102, 0.8102, 0.8102, 0.8102], + [0.2408, 0.2408, 0.2408, 0.2408], + ], + [ + [0.8400, 0.8400, 0.8400, 0.8400], + [0.2588, 0.2588, 0.2588, 0.2588], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8902, 0.8902, 0.8902, 0.8902], + [0.5749, 0.5749, 0.5749, 0.5749], + ], + [ + [0.7851, 0.7851, 0.7851, 0.7851], + [0.2792, 0.2792, 0.2792, 0.2792], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.9294, 0.9294, 0.9294, 0.9294], + [0.7624, 0.7624, 0.7624, 0.7624], + ], + ] + ) + assertTensorAlmostEqual(self, x_out, x_out_expected, delta=0.01) + self.assertTrue(x_out.is_cuda) diff --git a/tests/optim/utils/image/test_dataset.py b/tests/optim/utils/image/test_dataset.py new file mode 100644 index 0000000000..e793577d15 --- /dev/null +++ b/tests/optim/utils/image/test_dataset.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +import captum.optim._utils.image.dataset as dataset_utils +import torch +from tests.helpers.basic import BaseTest, assertTensorAlmostEqual +from tests.optim.helpers.image_dataset import ImageTestDataset + + +class TestImageCov(BaseTest): + def test_image_cov_3_channels(self) -> None: + test_input = torch.cat( + [ + torch.ones(1, 1, 4, 4) * 0.1, + torch.ones(1, 1, 4, 4) * 0.2, + torch.ones(1, 1, 4, 4) * 0.3, + ], + 1, + ) + + test_output = dataset_utils.image_cov(test_input) + expected_output = torch.tensor( + [ + [ + [0.0073, 0.0067, 0.0067], + [0.0067, 0.0067, 0.0067], + [0.0067, 0.0067, 0.0073], + ] + ] + ) + self.assertEqual(list(test_output.shape), [3, 3]) + assertTensorAlmostEqual(self, test_output, expected_output[0], delta=0.001) + + def test_image_cov_3_channels_batch_5(self) -> None: + test_input = torch.cat( + [ + torch.ones(5, 1, 4, 4) * 0.1, + torch.ones(5, 1, 4, 4) * 0.2, + torch.ones(5, 1, 4, 4) * 0.3, + ], + 1, + ) + + test_output = dataset_utils.image_cov(test_input) + expected_output = torch.tensor( + [ + [0.0365, 0.0333, 0.0335], + [0.0333, 0.0333, 0.0333], + [0.0335, 0.0333, 0.0365], + ] + ) + self.assertEqual(list(test_output.shape), [3, 3]) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0.001) + + def test_image_cov_2_channels(self) -> None: + test_input = torch.randn(1, 2, 5, 5) + test_output = dataset_utils.image_cov(test_input) + self.assertEqual(list(test_output.shape), [2, 2]) + + def test_image_cov_4_channels(self) -> None: + test_input = torch.randn(1, 4, 5, 5) + test_output = dataset_utils.image_cov(test_input) + self.assertEqual(list(test_output.shape), [4, 4]) + + +class TestDatasetCovMatrix(BaseTest): + def test_dataset_cov_matrix(self) -> None: + num_tensors = 100 + + def create_tensor() -> torch.Tensor: + return torch.cat( + [ + torch.ones(1, 224, 224) * 0.9, + torch.ones(1, 224, 224) * 0.5, + torch.ones(1, 224, 224) * 0.4, + ], + 0, + ) + + dataset_tensors = [create_tensor() for x in range(num_tensors)] + test_dataset = ImageTestDataset(dataset_tensors) + dataset_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=10, num_workers=0, shuffle=False + ) + output_mtx = dataset_utils.dataset_cov_matrix(dataset_loader) + expected_mtx = torch.tensor( + [ + [0.0467, 0.0467, 0.0467], + [0.0467, 0.0467, 0.0467], + [0.0467, 0.0467, 0.0467], + ] 
+ ) + assertTensorAlmostEqual(self, output_mtx, expected_mtx, delta=0.001) + + +class TestCovMatrixToKLT(BaseTest): + def test_cov_matrix_to_klt(self) -> None: + test_input = torch.tensor( + [ + [0.0477, 0.0415, 0.0280], + [0.0415, 0.0425, 0.0333], + [0.0280, 0.0333, 0.0419], + ] + ) + output_mtx = dataset_utils.cov_matrix_to_klt(test_input) + expected_mtx = torch.tensor( + [ + [-0.2036, 0.0750, 0.0249], + [-0.2024, 0.0158, -0.0358], + [-0.1749, -0.1056, 0.0124], + ] + ) + assertTensorAlmostEqual(self, output_mtx, expected_mtx, delta=0.001) + + +class TestDatasetKLTMatrix(BaseTest): + def test_dataset_klt_matrix(self) -> None: + num_tensors = 100 + + def create_tensor() -> torch.Tensor: + return torch.cat( + [ + torch.ones(1, 224, 224) * 0.2, + torch.ones(1, 224, 224) * 0.9, + torch.ones(1, 224, 224) * 0.3, + ], + 0, + ) + + dataset_tensors = [create_tensor() for x in range(num_tensors)] + test_dataset = ImageTestDataset(dataset_tensors) + dataset_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=10, num_workers=0, shuffle=False + ) + + klt_transform = dataset_utils.dataset_klt_matrix(dataset_loader) + + expected_mtx = torch.tensor( + [ + [-0.3091, 0.0023, 0.0004], + [-0.3091, -0.0005, -0.0012], + [-0.3091, -0.0018, 0.0008], + ] + ) + assertTensorAlmostEqual(self, klt_transform, expected_mtx, delta=0.001) + + def test_dataset_klt_matrix_randn(self) -> None: + num_tensors = 100 + + def create_tensor() -> torch.Tensor: + return torch.randn(1, 3, 224, 224).clamp(0, 1) + + dataset_tensors = [create_tensor() for x in range(num_tensors)] + test_dataset = ImageTestDataset(dataset_tensors) + dataset_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=10, num_workers=0, shuffle=False + ) + + klt_transform = dataset_utils.dataset_klt_matrix(dataset_loader) + self.assertEqual(list(klt_transform.shape), [3, 3])
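
Reviewer note: below is a minimal, hypothetical usage sketch of the public APIs this diff touches — the newly exported `hue_to_rgb`, the reworked `nchannels_to_rgb`, and `dataset_klt_matrix` with its new `show_progress` and `device` arguments. It assumes a captum build that includes these changes; the random `TensorDataset` loader is a stand-in for a real image dataset and is not part of the diff.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from captum.optim import hue_to_rgb, nchannels_to_rgb
from captum.optim._utils.image.dataset import dataset_klt_matrix

# hue_to_rgb is now exported from captum.optim: it maps a hue angle in
# degrees to a unit-norm RGB vector, optionally on a target device.
red = hue_to_rgb(0.0)  # tensor([1., 0., 0.])
cyan = hue_to_rgb(180.0, device=torch.device("cpu"), warp=False)

# nchannels_to_rgb collapses an NCHW tensor with any number of channels
# down to a 3 channel RGB image for visualization.
activations = torch.randn(1, 6, 64, 64)
rgb_image = nchannels_to_rgb(activations, warp=True, eps=1e-4)
assert list(rgb_image.shape) == [1, 3, 64, 64]

# dataset_klt_matrix now takes show_progress and device arguments, and runs
# image_cov on whole batches instead of one image at a time.
images = torch.rand(16, 3, 32, 32)  # stand-in for a real image dataset
labels = torch.zeros(16, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=4)
klt_mtx = dataset_klt_matrix(loader, normalize=True, show_progress=False)
assert list(klt_mtx.shape) == [3, 3]
```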