diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py old mode 100755 new mode 100644 diff --git a/captum/optim/_core/optimization.py b/captum/optim/_core/optimization.py index 27c5bf3162..c251dfc8ec 100644 --- a/captum/optim/_core/optimization.py +++ b/captum/optim/_core/optimization.py @@ -46,6 +46,7 @@ def __init__( ) -> None: r""" Args: + model (nn.Module): The reference to PyTorch model instance. input_param (nn.Module, optional): A module that generates an input, consumed by the model. @@ -71,6 +72,7 @@ def __init__( def loss(self) -> torch.Tensor: r"""Compute loss value for current iteration. + Returns: *tensor* representing **loss**: - **loss** (*tensor*): @@ -115,18 +117,26 @@ def optimize( lr: float = 0.025, ) -> torch.Tensor: r"""Optimize input based on loss function and objectives. + Args: + stop_criteria (StopCriteria, optional): A function that is called every iteration and returns a bool that determines whether to stop the optimization. See captum.optim.typing.StopCriteria for details. optimizer (Optimizer, optional): An torch.optim.Optimizer used to optimize the input based on the loss function. + loss_summarize_fn (Callable, optional): The function to use for summarizing + tensor outputs from loss functions. + Default: default_loss_summarize + lr: (float, optional): If no optimizer is given, then lr is used as the + learning rate for the Adam optimizer. + Default: 0.025 + Returns: - *list* of *np.arrays* representing the **history**: - - **history** (*list*): - A list of loss values per iteration. - Length of the list corresponds to the number of iterations + history (torch.Tensor): A stack of loss values per iteration. The size + of the dimension on which loss values are stacked corresponds to + the number of iterations. """ stop_criteria = stop_criteria or n_steps(512) optimizer = optimizer or optim.Adam(self.parameters(), lr=lr) @@ -150,10 +160,12 @@ def optimize( def n_steps(n: int, show_progress: bool = True) -> StopCriteria: """StopCriteria generator that uses number of steps as a stop criteria. + Args: n (int): Number of steps to run optimization. show_progress (bool, optional): Whether or not to show progress bar. Default: True + Returns: *StopCriteria* callable """ diff --git a/captum/optim/_core/output_hook.py b/captum/optim/_core/output_hook.py old mode 100755 new mode 100644 index 4bbf5c0fa3..6cfbc4ff2e --- a/captum/optim/_core/output_hook.py +++ b/captum/optim/_core/output_hook.py @@ -8,12 +8,13 @@ from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType -class ModuleReuseException(Exception): - pass - - class ModuleOutputsHook: def __init__(self, target_modules: Iterable[nn.Module]) -> None: + """ + Args: + + target_modules (Iterable of nn.Module): A list of nn.Module targets. + """ self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None) self.hooks = [ module.register_forward_hook(self._forward_hook()) @@ -21,6 +22,9 @@ def __init__(self, target_modules: Iterable[nn.Module]) -> None: ] def _reset_outputs(self) -> None: + """ + Delete captured activations. + """ self.outputs = dict.fromkeys(self.outputs.keys(), None) @property @@ -28,6 +32,13 @@ def is_ready(self) -> bool: return all(value is not None for value in self.outputs.values()) def _forward_hook(self) -> Callable: + """ + Return the forward_hook function. + + Returns: + forward_hook (Callable): The forward_hook function. 
+ """ + def forward_hook( module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor ) -> None: @@ -49,6 +60,12 @@ def forward_hook( return forward_hook def consume_outputs(self) -> ModuleOutputMapping: + """ + Collect target activations and return them. + + Returns: + outputs (ModuleOutputMapping): The captured outputs. + """ if not self.is_ready: warn( "Consume captured outputs, but not all requested target outputs " @@ -63,11 +80,16 @@ def forward_hook( return self.outputs.keys() def remove_hooks(self) -> None: + """ + Remove hooks. + """ for hook in self.hooks: hook.remove() def __del__(self) -> None: - # print(f"DEL HOOKS!: {list(self.outputs.keys())}") + """ + Ensure that using 'del' properly deletes hooks. + """ self.remove_hooks() @@ -77,16 +99,34 @@ class ActivationFetcher: """ def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None: + """ + Args: + + model (nn.Module): The reference to the PyTorch model instance. + targets (nn.Module or list of nn.Module): The target layers to + collect activations from. + """ super(ActivationFetcher, self).__init__() self.model = model self.layers = ModuleOutputsHook(targets) def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping: + """ + Args: + + input_t (tensor or tuple of tensors): The input to use + with the specified model. + + Returns: + activations_dict: A dict containing the collected activations. The keys + for the returned dictionary are the target layers. + """ + try: with warnings.catch_warnings(): warnings.simplefilter("ignore") self.model(input_t) - activations = self.layers.consume_outputs() + activations_dict = self.layers.consume_outputs() finally: self.layers.remove_hooks() - return activations + return activations_dict diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py old mode 100755 new mode 100644 index b0852a512c..cf4b01da0d --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -27,6 +27,15 @@ def __new__( *args, **kwargs, ) -> torch.Tensor: + """ + Args: + + x (list or np.ndarray or torch.Tensor): A list, NumPy array, or PyTorch + tensor to create an `ImageTensor` from. + + Returns: + x (ImageTensor): An `ImageTensor` instance. + """ if isinstance(x, torch.Tensor) and x.is_cuda: x.show = MethodType(cls.show, x) x.export = MethodType(cls.export, x) @@ -36,6 +45,20 @@ def __new__( @classmethod def open(cls, path: str, scale: float = 255.0, mode: str = "RGB") -> "ImageTensor": + """ + Load an image file from a URL or local filepath directly into an `ImageTensor`. + + Args: + + path (str): A URL or filepath to an image. + scale (float, optional): The image scale to use. + Default: 255.0 + mode (str, optional): The image loading mode to use. + Default: "RGB" + + Returns: + x (ImageTensor): An `ImageTensor` instance. + """ if path.startswith("https://") or path.startswith("http://"): response = requests.get(path, stream=True) img = Image.open(response.raw) @@ -73,9 +96,31 @@ def __torch_function__( def show( self, figsize: Optional[Tuple[int, int]] = None, scale: float = 255.0 ) -> None: + """ + Display an `ImageTensor`. + + Args: + + figsize (Tuple[int, int], optional): The height & width to use + for displaying the `ImageTensor` figure. + scale (float, optional): Value to multiply the `ImageTensor` by so that + its value range is [0-255] for display.
+ Default: 255.0 + """ show(self, figsize=figsize, scale=scale) def export(self, filename: str, scale: float = 255.0) -> None: + """ + Save an `ImageTensor` as an image file. + + Args: + + filename (str): The filename to use when saving the `ImageTensor` as an + image file. + scale (float, optional): Value to multiply the `ImageTensor` by so that + its value range is [0-255] for saving. + Default: 255.0 + """ save_tensor_as_image(self, filename=filename, scale=scale) @@ -89,7 +134,9 @@ class ImageParameterization(InputParameterization): class FFTImage(ImageParameterization): - """Parameterize an image using inverse real 2D FFT""" + """ + Parameterize an image using inverse real 2D FFT + """ def __init__( self, @@ -98,6 +145,20 @@ def __init__( batch: int = 1, init: Optional[torch.Tensor] = None, ) -> None: + """ + Args: + + size (Tuple[int, int]): The height & width dimensions to use for the + parameterized output image tensor. + channels (int, optional): The number of channels to use for each image. + Default: 3 + batch (int, optional): The number of images to stack along the batch + dimension. + Default: 1 + init (torch.tensor, optional): Optionally specify a tensor to + use instead of creating one. + Default: None + """ super().__init__() if init is None: assert len(size) == 2 @@ -137,13 +198,33 @@ def __init__( self.fourier_coeffs = nn.Parameter(fourier_coeffs) def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor: - """Computes 2D spectrum frequencies.""" + """ + Computes 2D spectrum frequencies. + + Args: + + height (int): The h dimension of the 2d frequency scale. + width (int): The w dimension of the 2d frequency scale. + + Returns: + **tensor** (tensor): A 2d frequency scale tensor. + """ + fy = self.torch_fftfreq(height)[:, None] fx = self.torch_fftfreq(width)[: width // 2 + 1] return torch.sqrt((fx * fx) + (fy * fy)) def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]: - """Support older versions of PyTorch""" + """ + Support older versions of PyTorch. This function ensures that the same FFT + operations are carried out regardless of whether your PyTorch version has the + torch.fft update. + + Returns: + fft functions (tuple of Callable): A tuple of FFT functions + to use for irfft, rfft, and fftfreq operations. + """ + if TORCH_VERSION >= "1.7.0": import torch.fft @@ -180,12 +261,21 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor: return torch_rfft, torch_irfft, torch_fftfreq def forward(self) -> torch.Tensor: + """ + Returns: + **output** (torch.tensor): A spatially recorrelated tensor. + """ + scaled_spectrum = self.fourier_coeffs * self.spectrum_scale output = self.torch_irfft(scaled_spectrum) return output.refine_names("B", "C", "H", "W") class PixelImage(ImageParameterization): + """ + Parameterize a simple pixel image tensor that requires no additional transforms. + """ + def __init__( self, size: Tuple[int, int] = None, @@ -193,6 +283,20 @@ def __init__( batch: int = 1, init: Optional[torch.Tensor] = None, ) -> None: + """ + Args: + + size (Tuple[int, int]): The height & width dimensions to use for the + parameterized output image tensor. + channels (int, optional): The number of channels to use for each image. + Default: 3 + batch (int, optional): The number of images to stack along the batch + dimension. + Default: 1 + init (torch.tensor, optional): Optionally specify a tensor to + use instead of creating one.
+ Default: None + """ super().__init__() if init is None: assert size is not None and channels is not None and batch is not None @@ -212,6 +316,7 @@ def forward(self) -> torch.Tensor: class LaplacianImage(ImageParameterization): """ TODO: Fix divison by 6 in setup_input when init is not None. + Parameterize an image tensor with a laplacian pyramid. """ def __init__( @@ -221,11 +326,25 @@ def __init__( batch: int = 1, init: Optional[torch.Tensor] = None, ) -> None: + """ + Args: + + size (Tuple[int, int]): The height & width dimensions to use for the + parameterized output image tensor. + channels (int, optional): The number of channels to use for each image. + Default: 3 + batch (int, optional): The number of images to stack along the batch + dimension. + Default: 1 + init (torch.tensor, optional): Optionally specify a tensor to + use instead of creating one. + Default: None + """ super().__init__() power = 0.1 if init is None: - tensor_params, self.scaler = self.setup_input(size, channels, power, init) + tensor_params, self.scaler = self._setup_input(size, channels, power, init) self.tensor_params = torch.nn.ModuleList( [deepcopy(tensor_params) for b in range(batch)] @@ -234,13 +353,13 @@ def __init__( init = init.unsqueeze(0) if init.dim() == 3 else init P = [] for b in range(init.size(0)): - tensor_params, self.scaler = self.setup_input( + tensor_params, self.scaler = self._setup_input( size, channels, power, init[b].unsqueeze(0) ) P.append(tensor_params) self.tensor_params = torch.nn.ModuleList(P) - def setup_input( + def _setup_input( self, size: Tuple[int, int], channels: int, @@ -264,16 +383,26 @@ def setup_input( tensor_params = torch.nn.ParameterList(tensor_params) return tensor_params, scaler - def create_tensor(self, params_list: torch.nn.ParameterList) -> torch.Tensor: - A = [] + def _create_tensor(self, params_list: torch.nn.ParameterList) -> torch.Tensor: + """ + Resize tensor parameters to the target size. + + Args: + + params_list (torch.nn.ParameterList): List of tensors to resize. + + Returns: + **tensor** (torch.Tensor): The sum of all tensor parameters. + """ + A: List[torch.Tensor] = [] for xi, upsamplei in zip(params_list, self.scaler): A.append(upsamplei(xi)) return torch.sum(torch.cat(A), 0) + 0.5 def forward(self) -> torch.Tensor: - A = [] + A: List[torch.Tensor] = [] for params_list in self.tensor_params: - tensor = self.create_tensor(params_list) + tensor = self._create_tensor(params_list) A.append(tensor) return torch.stack(A).refine_names("B", "C", "H", "W") @@ -297,6 +426,17 @@ def __init__( parameterization: ImageParameterization = None, offset: Union[int, Tuple[int], Tuple[Tuple[int]], None] = None, ) -> None: + """ + Args: + + shapes (list of int or list of list of ints): The shapes of the shared + tensors to use for creating the nn.Parameter tensors. + parameterization (ImageParameterization): An image parameterization + instance. + offset (int or list of int or list of list of ints , optional): The offsets + to use for the shared tensors. 
+ Default: None + """ super().__init__() assert shapes is not None A = [] @@ -308,9 +448,21 @@ def __init__( A.append(torch.nn.Parameter(torch.randn([batch, channels, height, width]))) self.shared_init = torch.nn.ParameterList(A) self.parameterization = parameterization - self.offset = self.get_offset(offset, len(A)) if offset is not None else None + self.offset = self._get_offset(offset, len(A)) if offset is not None else None + + def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]: + """ + Given offset values, return a list of offsets for _apply_offset to use. + + Args: + + offset (int or list of int or list of list of ints , optional): The offsets + to use for the shared tensors. + n (int): The number of tensors needing offset values. - def get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]: + Returns: + **offset** (list of list of int): A list of offset values. + """ if type(offset) is tuple or type(offset) is list: if type(offset[0]) is tuple or type(offset[0]) is list: assert len(offset) == n and all(len(t) == 4 for t in offset) @@ -323,8 +475,19 @@ def get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]: assert all([all([type(o) is int for o in v]) for v in offset]) return offset - def apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: - A = [] + def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Apply list of offsets to list of tensors. + + Args: + + x_list (list of torch.Tensor): list of tensors to offset. + + Returns: + **A** (list of torch.Tensor): list of offset tensors. + """ + + A: List[torch.Tensor] = [] for x, offset in zip(x_list, self.offset): assert x.dim() == 4 size = list(x.size()) @@ -345,13 +508,23 @@ def apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: A.append(x) return A - def interpolate_tensor( + def _interpolate_tensor( self, x: torch.Tensor, batch: int, channels: int, height: int, width: int ) -> torch.Tensor: """ - Linear interpolation for 4D, 5D, and 6D tensors. - If the batch dimension needs to be resized, - we move it's location temporarily for F.interpolate. + Linear interpolation for 4D, 5D, and 6D tensors. If the batch dimension needs + to be resized, we move it's location temporarily for F.interpolate. + + Args: + + x (torch.Tensor): The tensor to resize. + batch (int): The batch size to resize the tensor to. + channels (int): The channel size to resize the tensor to. + height (int): The height to resize the tensor to. + width (int): The width to resize the tensor to. + + Returns: + **tensor** (torch.Tensor): A resized tensor. """ if x.size(1) == channels: @@ -376,7 +549,7 @@ def interpolate_tensor( def forward(self) -> torch.Tensor: image = self.parameterization() x = [ - self.interpolate_tensor( + self._interpolate_tensor( shared_tensor, image.size(0), image.size(1), @@ -386,7 +559,7 @@ def forward(self) -> torch.Tensor: for shared_tensor in self.shared_init ] if self.offset is not None: - x = self.apply_offset(x) + x = self._apply_offset(x) return (image + sum(x)).refine_names("B", "C", "H", "W") @@ -401,21 +574,6 @@ class NaturalImage(ImageParameterization): If a model requires a normalization step, such as normalizing imagenet RGB values, or rescaling to [0,255], it can perform those steps with the provided transforms or inside its computation. - - Arguments: - size (Tuple[int, int]): The height and width to use for the nn.Parameter image - tensor. 
- channels (int): The number of channels to use when creating the - nn.Parameter tensor. - batch (int): The number of channels to use when creating the - nn.Parameter tensor, or stacking init images. - parameterization (ImageParameterization, optional): An image parameterization - class. - squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash - function to use after color recorrelation. A funtion or lambda function. - decorrelation_module (nn.Module, optional): A ToRGB instance. - decorrelate_init (bool, optional): Whether or not to apply color decorrelation - to the init tensor input. """ def __init__( @@ -429,6 +587,30 @@ def __init__( decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"), decorrelate_init: bool = True, ) -> None: + """ + Args: + + size (Tuple[int, int], optional): The height and width to use for the + nn.Parameter image tensor. + Default: (224, 224) + channels (int, optional): The number of channels to use when creating the + nn.Parameter tensor. + Default: 3 + batch (int, optional): The number of channels to use when creating the + nn.Parameter tensor, or stacking init images. + Default: 1 + parameterization (ImageParameterization, optional): An image + parameterization class. + Default: FFTImage + squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash + function to use after color recorrelation. A funtion or lambda function. + Default: None + decorrelation_module (nn.Module, optional): A ToRGB instance. + Default: ToRGB + decorrelate_init (bool, optional): Whether or not to apply color + decorrelation to the init tensor input. + Default: True + """ super().__init__() self.decorrelate = decorrelation_module if init is not None: diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index ba3c146e40..93df78243e 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -13,15 +13,31 @@ class BlendAlpha(nn.Module): r"""Blends a 4 channel input parameterization into an RGB image. - You can specify a fixed background, or a random one will be used by default. """ def __init__(self, background: Optional[torch.Tensor] = None) -> None: + """ + Args: + + background (tensor, optional): An NCHW image tensor to be used as the + Alpha channel's background. + Default: None + """ super().__init__() self.background = background def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Blend the Alpha channel into the RGB channels. + + Args: + + x (torch.Tensor): RGBA image tensor to blend into an RGB image tensor. + + Returns: + **blended** (torch.Tensor): RGB image tensor. + """ assert x.dim() == 4 assert x.size(1) == 4 rgb, alpha = x[:, :3, ...], x[:, 3:4, ...] @@ -36,6 +52,16 @@ class IgnoreAlpha(nn.Module): r"""Ignores a 4th channel""" def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Ignore the alpha channel. + + Args: + + x (torch.Tensor): RGBA image tensor. + + Returns: + **rgb** (torch.Tensor): RGB image tensor without the alpha channel. + """ assert x.dim() == 4 assert x.size(1) == 4 rgb = x[:, :3, ...] @@ -52,16 +78,17 @@ class ToRGB(nn.Module): [0] Y. Ohta, T. Kanade, and T. Sakai, "Color information for region segmentation," Computer Graphics and Image Processing, vol. 13, no. 3, pp. 
222–241, 1980 https://www.sciencedirect.com/science/article/pii/0146664X80900477 - - Arguments: - transform (str or tensor): Either a string for one of the precalculated - transform matrices, or a 3x3 matrix for the 3 RGB channels of input - tensors. """ @staticmethod def klt_transform() -> torch.Tensor: - """Karhunen-Loève transform (KLT) measured on ImageNet""" + """ + Karhunen-Loève transform (KLT) measured on ImageNet + + Returns: + **transform** (torch.Tensor): A Karhunen-Loève transform (KLT) measured on + the ImageNet dataset. + """ KLT = [[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]] transform = torch.Tensor(KLT).float() transform = transform / torch.max(torch.norm(transform, dim=0)) @@ -69,6 +96,11 @@ def klt_transform() -> torch.Tensor: @staticmethod def i1i2i3_transform() -> torch.Tensor: + """ + Returns: + **transform** (torch.Tensor): An approximation of natural colors transform + (i1i2i3). + """ i1i2i3_matrix = [ [1 / 3, 1 / 3, 1 / 3], [1 / 2, 0, -1 / 2], @@ -77,6 +109,13 @@ def i1i2i3_transform() -> torch.Tensor: return torch.Tensor(i1i2i3_matrix) def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None: + """ + Args: + + transform (str or tensor): Either a string for one of the precalculated + transform matrices, or a 3x3 matrix for the 3 RGB channels of input + tensors. + """ super().__init__() assert isinstance(transform, str) or torch.is_tensor(transform) if torch.is_tensor(transform): @@ -93,6 +132,18 @@ def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None: ) def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor: + """ + Args: + + x (torch.tensor): A CHW or NCHW RGB or RGBA image tensor. + inverse (bool, optional): Whether to recorrelate or decorrelate colors. + Default: False. + + Returns: + chw (torch.tensor): A tensor with it's colors recorrelated or + decorrelated. + """ + assert x.dim() == 3 or x.dim() == 4 # alpha channel is taken off... @@ -128,15 +179,6 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor: class CenterCrop(torch.nn.Module): """ Center crop a specified amount from a tensor. - Arguments: - size (int, sequence, int): Number of pixels to center crop away. - pixels_from_edges (bool, optional): Whether to treat crop size - values as the number of pixels from the tensor's edge, or an - exact shape in the center. - offset_left (bool, optional): If the cropped away sides are not - equal in size, offset center by +1 to the left and/or top. - Default is set to False. This parameter is only valid when - pixels_from_edges is False. """ def __init__( @@ -145,6 +187,18 @@ def __init__( pixels_from_edges: bool = False, offset_left: bool = False, ) -> None: + """ + Args: + + size (int, sequence, int): Number of pixels to center crop away. + pixels_from_edges (bool, optional): Whether to treat crop size + values as the number of pixels from the tensor's edge, or an + exact shape in the center. + offset_left (bool, optional): If the cropped away sides are not + equal in size, offset center by +1 to the left and/or top. + This parameter is only valid when `pixels_from_edges` is False. + Default: False + """ super().__init__() self.crop_vals = size self.pixels_from_edges = pixels_from_edges @@ -153,10 +207,12 @@ def __init__( def forward(self, input: torch.Tensor) -> torch.Tensor: """ Center crop an input. - Arguments: + + Args: input (torch.Tensor): Input to center crop. + Returns: - tensor (torch.Tensor): A center cropped tensor. 
+ **tensor** (torch.Tensor): A center cropped *tensor*. """ return center_crop( @@ -172,18 +228,22 @@ def center_crop( ) -> torch.Tensor: """ Center crop a specified amount from a tensor. - Arguments: + + Args: + input (tensor): A CHW or NCHW image tensor to center crop. size (int, sequence, int): Number of pixels to center crop away. pixels_from_edges (bool, optional): Whether to treat crop size values as the number of pixels from the tensor's edge, or an exact shape in the center. + Default: False offset_left (bool, optional): If the cropped away sides are not equal in size, offset center by +1 to the left and/or top. - Default is set to False. This parameter is only valid when - pixels_from_edges is False. + This parameter is only valid when `pixels_from_edges` is False. + Default: False + Returns: - *tensor*: A center cropped tensor. + **tensor**: A center cropped *tensor*. """ assert input.dim() == 3 or input.dim() == 4 @@ -218,6 +278,13 @@ def _rand_select( ) -> Union[int, float, torch.Tensor]: """ Randomly return a single value from the provided tuple, list, or tensor. + + Args: + + transform_values (sequence): A sequence of values to randomly select from. + + Returns: + **value**: A single value from the specified sequence. """ n = torch.randint(low=0, high=len(transform_values), size=[1]).item() return transform_values[n] @@ -226,11 +293,14 @@ def _rand_select( class RandomScale(nn.Module): """ Apply random rescaling on a NCHW tensor. - Arguments: - scale (float, sequence): Tuple of rescaling values to randomly select from. """ def __init__(self, scale: NumSeqOrTensorType) -> None: + """ + Args: + + scale (float, sequence): Tuple of rescaling values to randomly select from. + """ super().__init__() self.scale = scale @@ -258,6 +328,16 @@ def scale_tensor( return x def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Randomly scale / zoom in or out of a tensor. + + Args: + + input (torch.Tensor): Input to randomly scale. + + Returns: + **tensor** (torch.Tensor): Scaled *tensor*. + """ scale = _rand_select(self.scale) return self.scale_tensor(input, scale=scale) @@ -265,11 +345,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class RandomSpatialJitter(torch.nn.Module): """ Apply random spatial translations on a NCHW tensor. - Arguments: - translate (int): """ def __init__(self, translate: int) -> None: + """ + Args: + + translate (int): The max horizontal and vertical translation to use. + """ super().__init__() self.pad_range = 2 * translate self.pad = nn.ReflectionPad2d(translate) @@ -287,6 +370,16 @@ def translate_tensor(self, x: torch.Tensor, insets: torch.Tensor) -> torch.Tenso return cropped def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Randomly translate an input tensor's height and width dimensions. + + Args: + + input (torch.Tensor): Input to randomly translate. + + Returns: + **tensor** (torch.Tensor): A randomly translated *tensor*. + """ insets = torch.randint(high=self.pad_range, size=(2,)) return self.translate_tensor(input, insets) @@ -298,10 +391,25 @@ class ScaleInputRange(nn.Module): """ def __init__(self, multiplier: float = 1.0) -> None: + """ + Args: + + multiplier (float, optional): A float value used to scale the input. + """ super().__init__() self.multiplier = multiplier def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Scale an input tensor's values. + + Args: + + x (torch.Tensor): Input to scale values of. + + Returns: + **tensor** (torch.Tensor): tensor with it's values scaled. 
+ """ return x * self.multiplier @@ -311,6 +419,16 @@ class RGBToBGR(nn.Module): """ def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform RGB to BGR conversion on an input + + Args: + + x (torch.Tensor): RGB image tensor to convert to BGR. + + Returns: + **BGR tensor** (torch.Tensor): A BGR tensor. + """ assert x.dim() == 4 assert x.size(1) == 3 return x[:, [2, 1, 0]] @@ -354,13 +472,6 @@ class GaussianSmoothing(nn.Module): Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. - Arguments: - channels (int, sequence): Number of channels of the input tensors. Output will - have this number of channels as well. - kernel_size (int, sequence): Size of the gaussian kernel. - sigma (float, sequence): Standard deviation of the gaussian kernel. - dim (int, optional): The number of dimensions of the data. - Default value is 2 (spatial). """ def __init__( @@ -370,6 +481,16 @@ def __init__( sigma: Union[float, Sequence[float]], dim: int = 2, ) -> None: + """ + Args: + + channels (int, sequence): Number of channels of the input tensors. Output + will have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ super().__init__() if isinstance(kernel_size, numbers.Number): kernel_size = [kernel_size] * dim @@ -414,10 +535,13 @@ def __init__( def forward(self, input: torch.Tensor) -> torch.Tensor: """ Apply gaussian filter to input. - Arguments: + + Args: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: - filtered (torch.Tensor): Filtered output. + **filtered** (torch.Tensor): Filtered output. """ return self.conv(input, weight=self.weight, groups=self.groups) @@ -431,6 +555,16 @@ class SymmetricPadding(torch.autograd.Function): def forward( ctx: torch.autograd.Function, x: torch.Tensor, padding: List[List[int]] ) -> torch.Tensor: + """ + Apply NumPy symmetric padding to an input tensor while preserving the gradient. + + Args: + + x (torch.Tensor): Input to apply symmetric padding on. + + Returns: + **tensor** (torch.Tensor): Padded tensor. + """ ctx.padding = padding x_device = x.device x = x.cpu() @@ -444,6 +578,16 @@ def forward( def backward( ctx: torch.autograd.Function, grad_output: torch.Tensor ) -> Tuple[torch.Tensor, None]: + """ + Crop away symmetric padding. + + Args: + + grad_output (torch.Tensor): Input to remove symmetric padding from. + + Returns: + **grad_input** (torch.Tensor): Unpadded tensor. + """ grad_input = grad_output.clone() B, C, H, W = grad_input.size() b1, b2 = ctx.padding[0] @@ -460,26 +604,44 @@ class NChannelsToRGB(nn.Module): """ def __init__(self, warp: bool = False) -> None: + """ + Args: + + warp (bool, optional): Whether or not to make the resulting RGB colors more + distict from each other. Default is set to False. + """ super().__init__() self.warp = warp def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Reduce any number of channels down to 3. + + Args: + + x (torch.Tensor): Input to reduce channel dimensions on. + + Returns: + **3 channel RGB tensor** (torch.Tensor): RGB image tensor. + """ assert x.dim() == 4 return nchannels_to_rgb(x, self.warp) class RandomCrop(nn.Module): """ - Randomly crop out a specific size from an NCHW image tensor. - ​ - Args: - crop_size (int, sequence, int): The desired cropped output size. 
+ Randomly crop out a specific size from an NCHW image tensor. """ def __init__( self, crop_size: IntSeqOrIntType, ) -> None: + """ + Args: + + crop_size (int, sequence, int): The desired cropped output size. + """ super().__init__() crop_size = [crop_size] * 2 if not hasattr(crop_size, "__iter__") else crop_size crop_size = list(crop_size) * 2 if len(crop_size) == 1 else crop_size diff --git a/captum/optim/_utils/circuits.py b/captum/optim/_utils/circuits.py index dfe97b204e..3dc2f3e524 100644 --- a/captum/optim/_utils/circuits.py +++ b/captum/optim/_utils/circuits.py @@ -30,8 +30,8 @@ def extract_expanded_weights( specified for target2. target2 (nn.Module): The end target layer. Must be above the layer specified for target1. - crop_shape (int or tuple of ints, optional): Specify the output weight - size to enter crop away padding. + crop_shape (int or tuple of ints, optional): Specify the exact output size + to crop out. model_input (tensor or tuple of tensors, optional): The input to use with the specified model. crop_func (Callable, optional): Specify a function to crop away the padding diff --git a/captum/optim/_utils/image/dataset.py b/captum/optim/_utils/image/dataset.py index 69a2be3453..fcc6d03742 100644 --- a/captum/optim/_utils/image/dataset.py +++ b/captum/optim/_utils/image/dataset.py @@ -3,7 +3,12 @@ def image_cov(tensor: torch.Tensor) -> torch.Tensor: """ - Calculate a tensor's RGB covariance matrix + Calculate a tensor's RGB covariance matrix. + + Args: + tensor (tensor): An NCHW image tensor. + Returns: + *tensor*: An RGB covariance matrix for the specified tensor. """ tensor = tensor.reshape(-1, 3) @@ -14,6 +19,12 @@ def image_cov(tensor: torch.Tensor) -> torch.Tensor: def dataset_cov_matrix(loader: torch.utils.data.DataLoader) -> torch.Tensor: """ Calculate the covariance matrix for an image dataset. + + Args: + loader (torch.utils.data.DataLoader): The reference to a PyTorch + dataloader instance. + Returns: + *tensor*: A covariance matrix for the specified dataset. """ cov_mtx = torch.zeros(3, 3) @@ -30,6 +41,13 @@ def cov_matrix_to_klt( ) -> torch.Tensor: """ Convert a cov matrix to a klt matrix. + + Args: + cov_mtx (tensor): A 3 by 3 covariance matrix generated from a dataset. + normalize (bool): Whether or not to normalize the resulting KLT matrix. + epsilon (float): + Returns: + *tensor*: A KLT matrix for the specified covariance matrix. """ U, S, V = torch.svd(cov_mtx) @@ -47,6 +65,13 @@ def dataset_klt_matrix( a Karhunen-Loève transform (KLT) matrix, for a dataset. The color correlation matrix can then used in color decorrelation transforms for models trained on the dataset. + + Args: + loader (torch.utils.data.DataLoader): The reference to a PyTorch + dataloader instance. + normalize (bool): Whether or not to normalize the resulting KLT matrix. + Returns: + *tensor*: A KLT matrix for the specified dataset. """ cov_mtx = dataset_cov_matrix(loader) diff --git a/captum/optim/_utils/reducer.py b/captum/optim/_utils/reducer.py index 33b0fb13dd..2696d003d6 100644 --- a/captum/optim/_utils/reducer.py +++ b/captum/optim/_utils/reducer.py @@ -16,12 +16,20 @@ class ChannelReducer: """ - Dimensionality reduction for the channel dimension of an input. - The default reduction_alg is NMF from sklearn, which requires users - to put input on CPU before passing to fit_transform. - + Dimensionality reduction for the channel dimension of an input tensor. Olah, et al., "The Building Blocks of Interpretability", Distill, 2018. 
- See: https://distill.pub/2018/building-blocks/ + + See here for more information: https://distill.pub/2018/building-blocks/ + + Args: + n_components (int, optional): The number of channels to reduce the target + dimension to. + reduction_alg (str or callable, optional): The desired dimensionality + reduction algorithm to use. The default reduction_alg is set to NMF from + sklearn, which requires users to put inputs on CPU before passing them to + fit_transform. + **kwargs (optional): Arbitrary keyword arguments used by the specified + reduction_alg. """ def __init__( @@ -63,9 +71,13 @@ def fit_transform( ) -> torch.Tensor: """ Perform dimensionality reduction on an input tensor. - - If swap_2nd_and_last_dims is true, input channels are expected to be in the - second dimension unless the input tensor has a shape of CHW. + Args: + tensor (tensor): A tensor to perform dimensionality reduction on. + swap_2nd_and_last_dims (bool, optional): If true, input channels are + expected to be in the second dimension unless the input tensor has a + shape of CHW. Default is set to True. + Returns: + *tensor*: A tensor with one of it's dimensions reduced. """ if x.dim() == 3 and swap_2nd_and_last_dims: @@ -115,8 +127,15 @@ def __dir__(self) -> List: def posneg(x: torch.Tensor, dim: int = 0) -> torch.Tensor: """ - Hack that makes a matrix positive by concatination in order to simulate - one-sided NMF with regular NMF + Hack that makes a matrix positive by concatination in order to simulate one-sided + NMF with regular NMF. + + Args: + x (tensor): A tensor to make positive. + dim (int, optional): The dimension to concatinate the two tensor halves at. + Returns: + tensor (torch.tensor): A positive tensor for one-sided dimensionality + reduction. """ return torch.cat([F.relu(x), F.relu(-x)], dim=dim) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index e9fba1ba27..91757fd743 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -1,256 +1,308 @@ -import math -from inspect import signature -from typing import Dict, List, Tuple, Type, Union, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from captum.optim._core.output_hook import ActivationFetcher -from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType - - -def get_model_layers(model: nn.Module) -> List[str]: - """ - Return a list of hookable layers for the target model. - """ - layers = [] - - def get_layers(net: nn.Module, prefix: List = []) -> None: - if hasattr(net, "_modules"): - for name, layer in net._modules.items(): - if layer is None: - continue - separator = "" if str(name).isdigit() else "." - name = "[" + str(name) + "]" if str(name).isdigit() else name - layers.append(separator.join(prefix + [name])) - get_layers(layer, prefix=prefix + [name]) - - get_layers(model) - return layers - - -class RedirectedReLU(torch.autograd.Function): - """ - A workaround when there is no gradient flow from an initial random input. - ReLU layers will block the gradient flow during backpropagation when their - input is less than 0. This means that it can be impossible to visualize a - target without allowing negative values to pass through ReLU layers during - backpropagation. 
- See: - https://github.com/tensorflow/lucid/blob/master/lucid/misc/redirected_relu_grad.py - """ - - @staticmethod - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - self.save_for_backward(input_tensor) - return input_tensor.clamp(min=0) - - @staticmethod - def backward(self, grad_output: torch.Tensor) -> torch.Tensor: - (input_tensor,) = self.saved_tensors - relu_grad = grad_output.clone() - relu_grad[input_tensor < 0] = 0 - if torch.equal(relu_grad, torch.zeros_like(relu_grad)): - # Let "wrong" gradients flow if gradient is completely 0 - return grad_output.clone() - return relu_grad - - -class RedirectedReluLayer(nn.Module): - """ - Class for applying RedirectedReLU - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return RedirectedReLU.apply(input) - - -def replace_layers( - model: nn.Module, - layer1: Type[nn.Module], - layer2: Type[nn.Module], - transfer_vars: bool = False, - **kwargs -) -> None: - """ - Replace all target layers with new layers inside the specified model, - possibly with the same initialization variables. - - Args: - model: (nn.Module): A PyTorch model instance. - layer1: (Type[nn.Module]): The layer class that you want to transfer - initialization variables from. - layer2: (Type[nn.Module]): The layer class to create with the variables - from layer1. - transfer_vars (bool, optional): Wether or not to try and copy - initialization variables from layer1 instances to the replacement - layer2 instances. - kwargs: (Any, optional): Any additional variables to use when creating - the new layer. - """ - - for name, child in model._modules.items(): - if isinstance(child, layer1): - if transfer_vars: - new_layer = _transfer_layer_vars(child, layer2, **kwargs) - else: - new_layer = layer2(**kwargs) - setattr(model, name, new_layer) - elif child is not None: - replace_layers(child, layer1, layer2, transfer_vars, **kwargs) - - -def _transfer_layer_vars( - layer1: nn.Module, layer2: Type[nn.Module], **kwargs -) -> nn.Module: - """ - Given a layer instance, create a new layer instance of another class - with the same initialization variables as the original layer. - Args: - layer1: (nn.Module): A layer instance that you want to transfer - initialization variables from. - layer2: (nn.Module): The layer class to create with the variables - from of layer1. - kwargs: (Any, optional): Any additional variables to use when creating - the new layer. - Returns: - layer2 instance (nn.Module): An instance of layer2 with the initialization - variables that it shares with layer1, and any specified additional - initialization variables. - """ - - l2_vars = list(signature(layer2.__init__).parameters.values()) - l2_vars = [ - str(l2_vars[i]).split()[0] - for i in range(len(l2_vars)) - if str(l2_vars[i]) != "self" - ] - l2_vars = [p.split(":")[0] if ":" in p and "=" not in p else p for p in l2_vars] - l2_vars = [p.split("=")[0] if "=" in p and ":" not in p else p for p in l2_vars] - layer2_vars: Dict = {k: [] for k in dict.fromkeys(l2_vars).keys()} - - layer1_vars = {k: v for k, v in vars(layer1).items() if not k.startswith("_")} - shared_vars = {k: v for k, v in layer1_vars.items() if k in layer2_vars} - new_vars = dict(item for d in (shared_vars, kwargs) for item in d.items()) - return layer2(**new_vars) - - -class Conv2dSame(nn.Conv2d): - """ - Tensorflow like 'SAME' convolution wrapper for 2D convolutions. 
- TODO: Replace with torch.nn.Conv2d when support for padding='same' - is in stable version - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]] = 1, - padding: Union[int, Tuple[int, int]] = 0, - dilation: Union[int, Tuple[int, int]] = 1, - groups: int = 1, - bias: bool = True, - ) -> None: - super().__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias - ) - - def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - ih, iw = x.size()[-2:] - kh, kw = self.weight.size()[-2:] - pad_h = self.calc_same_pad(i=ih, k=kh, s=self.stride[0], d=self.dilation[0]) - pad_w = self.calc_same_pad(i=iw, k=kw, s=self.stride[1], d=self.dilation[1]) - - if pad_h > 0 or pad_w > 0: - x = F.pad( - x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] - ) - return F.conv2d( - x, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - -def collect_activations( - model: nn.Module, - targets: Union[nn.Module, List[nn.Module]], - model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224), -) -> ModuleOutputMapping: - """ - Collect target activations for a model. - """ - if not hasattr(targets, "__iter__"): - targets = [targets] - catch_activ = ActivationFetcher(model, targets) - activ_out = catch_activ(model_input) - return activ_out - - -class SkipLayer(torch.nn.Module): - """ - This layer is made to take the place of any layer that needs to be skipped over - during the forward pass. Use cases include removing nonlinear activation layers - like ReLU for circuits research. - - This layer works almost exactly the same way that nn.Indentiy does, except it also - ignores any additional arguments passed to the forward function. Any layer replaced - by SkipLayer must have the same input and output shapes. - - See nn.Identity for more details: - https://pytorch.org/docs/stable/generated/torch.nn.Identity.html - - Args: - args (Any): Any argument. Arguments will be safely ignored. - kwargs (Any) Any keyword argument. Arguments will be safely ignored. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__() - - def forward( - self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - """ - Args: - x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors. - args (Any): Any argument. Arguments will be safely ignored. - kwargs (Any) Any keyword argument. Arguments will be safely ignored. - Returns: - x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or - tensors. - """ - return x - - -def skip_layers( - model: nn.Module, layers: Union[List[Type[nn.Module]], Type[nn.Module]] -) -> None: - """ - This function is a wrapper function for - replace_layers and replaces the target layer - with layers that do nothing. - This is useful for removing the nonlinear ReLU - layers when creating expanded weights. - Args: - model (nn.Module): A PyTorch model instance. - layers (nn.Module or list of nn.Module): The layer - class type to replace in the model. 
- """ - if not hasattr(layers, "__iter__"): - layers = cast(Type[nn.Module], layers) - replace_layers(model, layers, SkipLayer) - else: - layers = cast(List[Type[nn.Module]], layers) - for target_layer in layers: - replace_layers(model, target_layer, SkipLayer) +import math +from inspect import signature +from typing import Dict, List, Optional, Tuple, Type, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from captum.optim._core.output_hook import ActivationFetcher +from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType + + +def get_model_layers(model: nn.Module) -> List[str]: + """ + Return a list of hookable layers for the target model. + """ + layers = [] + + def get_layers(net: nn.Module, prefix: List = []) -> None: + if hasattr(net, "_modules"): + for name, layer in net._modules.items(): + if layer is None: + continue + separator = "" if str(name).isdigit() else "." + name = "[" + str(name) + "]" if str(name).isdigit() else name + layers.append(separator.join(prefix + [name])) + get_layers(layer, prefix=prefix + [name]) + + get_layers(model) + return layers + + +class RedirectedReLU(torch.autograd.Function): + """ + A workaround when there is no gradient flow from an initial random input. + ReLU layers will block the gradient flow during backpropagation when their + input is less than 0. This means that it can be impossible to visualize a + target without allowing negative values to pass through ReLU layers during + backpropagation. + See: + https://github.com/tensorflow/lucid/blob/master/lucid/misc/redirected_relu_grad.py + """ + + @staticmethod + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + self.save_for_backward(input_tensor) + return input_tensor.clamp(min=0) + + @staticmethod + def backward(self, grad_output: torch.Tensor) -> torch.Tensor: + (input_tensor,) = self.saved_tensors + relu_grad = grad_output.clone() + relu_grad[input_tensor < 0] = 0 + if torch.equal(relu_grad, torch.zeros_like(relu_grad)): + # Let "wrong" gradients flow if gradient is completely 0 + return grad_output.clone() + return relu_grad + + +class RedirectedReluLayer(nn.Module): + """ + Class for applying RedirectedReLU + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return RedirectedReLU.apply(input) + + +def replace_layers( + model: nn.Module, + layer1: Type[nn.Module], + layer2: Type[nn.Module], + transfer_vars: bool = False, + **kwargs +) -> None: + """ + Replace all target layers with new layers inside the specified model, + possibly with the same initialization variables. + + Args: + model: (nn.Module): A PyTorch model instance. + layer1: (Type[nn.Module]): The layer class that you want to transfer + initialization variables from. + layer2: (Type[nn.Module]): The layer class to create with the variables + from layer1. + transfer_vars (bool, optional): Wether or not to try and copy + initialization variables from layer1 instances to the replacement + layer2 instances. + kwargs: (Any, optional): Any additional variables to use when creating + the new layer. 
+ """ + + for name, child in model._modules.items(): + if isinstance(child, layer1): + if transfer_vars: + new_layer = _transfer_layer_vars(child, layer2, **kwargs) + else: + new_layer = layer2(**kwargs) + setattr(model, name, new_layer) + elif child is not None: + replace_layers(child, layer1, layer2, transfer_vars, **kwargs) + + +def _transfer_layer_vars( + layer1: nn.Module, layer2: Type[nn.Module], **kwargs +) -> nn.Module: + """ + Given a layer instance, create a new layer instance of another class + with the same initialization variables as the original layer. + Args: + layer1: (nn.Module): A layer instance that you want to transfer + initialization variables from. + layer2: (nn.Module): The layer class to create with the variables + from of layer1. + kwargs: (Any, optional): Any additional variables to use when creating + the new layer. + Returns: + layer2 instance (nn.Module): An instance of layer2 with the initialization + variables that it shares with layer1, and any specified additional + initialization variables. + """ + + l2_vars = list(signature(layer2.__init__).parameters.values()) + l2_vars = [ + str(l2_vars[i]).split()[0] + for i in range(len(l2_vars)) + if str(l2_vars[i]) != "self" + ] + l2_vars = [p.split(":")[0] if ":" in p and "=" not in p else p for p in l2_vars] + l2_vars = [p.split("=")[0] if "=" in p and ":" not in p else p for p in l2_vars] + layer2_vars: Dict = {k: [] for k in dict.fromkeys(l2_vars).keys()} + + layer1_vars = {k: v for k, v in vars(layer1).items() if not k.startswith("_")} + shared_vars = {k: v for k, v in layer1_vars.items() if k in layer2_vars} + new_vars = dict(item for d in (shared_vars, kwargs) for item in d.items()) + return layer2(**new_vars) + + +class Conv2dSame(nn.Conv2d): + """ + Tensorflow like 'SAME' convolution wrapper for 2D convolutions. + TODO: Replace with torch.nn.Conv2d when support for padding='same' + is in stable version + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + ) -> None: + super().__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias + ) + + def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + pad_h = self.calc_same_pad(i=ih, k=kh, s=self.stride[0], d=self.dilation[0]) + pad_w = self.calc_same_pad(i=iw, k=kw, s=self.stride[1], d=self.dilation[1]) + + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +def collect_activations( + model: nn.Module, + targets: Union[nn.Module, List[nn.Module]], + model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224), +) -> ModuleOutputMapping: + """ + Collect target activations for a model. + """ + if not hasattr(targets, "__iter__"): + targets = [targets] + catch_activ = ActivationFetcher(model, targets) + activ_out = catch_activ(model_input) + return activ_out + + +class SkipLayer(torch.nn.Module): + """ + This layer is made to take the place of any layer that needs to be skipped over + during the forward pass. 
Use cases include removing nonlinear activation layers + like ReLU for circuits research. + + This layer works almost exactly the same way that nn.Identity does, except it also + ignores any additional arguments passed to the forward function. Any layer replaced + by SkipLayer must have the same input and output shapes. + + See nn.Identity for more details: + https://pytorch.org/docs/stable/generated/torch.nn.Identity.html + + Args: + args (Any): Any argument. Arguments will be safely ignored. + kwargs (Any): Any keyword argument. Arguments will be safely ignored. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def forward( + self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """ + Args: + x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors. + args (Any): Any argument. Arguments will be safely ignored. + kwargs (Any): Any keyword argument. Arguments will be safely ignored. + Returns: + x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or + tensors. + """ + return x + + + def skip_layers( + model: nn.Module, layers: Union[List[Type[nn.Module]], Type[nn.Module]] + ) -> None: + """ + This function is a wrapper function for + replace_layers and replaces the target layer + with layers that do nothing. + This is useful for removing the nonlinear ReLU + layers when creating expanded weights. + Args: + model (nn.Module): A PyTorch model instance. + layers (nn.Module or list of nn.Module): The layer + class type to replace in the model. + """ + if not hasattr(layers, "__iter__"): + layers = cast(Type[nn.Module], layers) + replace_layers(model, layers, SkipLayer) + else: + layers = cast(List[Type[nn.Module]], layers) + for target_layer in layers: + replace_layers(model, target_layer, SkipLayer) + + + class MaxPool2dRelaxed(torch.nn.Module): + """ + A relaxed pooling layer that's useful for calculating attributions of spatial + positions. This layer reduces noise in the gradient through the use of a + continuous relaxation of the gradient. + + This layer performs a MaxPool2d operation on the input, while using an equivalent + AvgPool2d layer to compute the gradient. + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, ...]], + stride: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Union[int, Tuple[int, ...]] = 0, + ceil_mode: bool = False, + ) -> None: + """ + Args: + + kernel_size (int or tuple of int): The size of the window to perform max & + average pooling with. + stride (int or tuple of int, optional): The stride window size to use. + Default: None + padding (int or tuple of int, optional): The amount of zero padding to add to both + sides in the nn.MaxPool2d & nn.AvgPool2d modules. + Default: 0 + ceil_mode (bool, optional): Whether to use ceil or floor for creating the + output shape. + Default: False + """ + super().__init__() + self.maxpool = torch.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + self.avgpool = torch.nn.AvgPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to run the pooling operations on. + + Returns: + x (torch.Tensor): A max pooled x tensor with the gradient of an equivalent avg + pooled tensor.
+ """ + return self.maxpool(x.detach()) + self.avgpool(x) - self.avgpool(x.detach()) diff --git a/tests/optim/helpers/image_dataset.py b/tests/optim/helpers/image_dataset.py index 9b5e73ad48..a8cef03b87 100644 --- a/tests/optim/helpers/image_dataset.py +++ b/tests/optim/helpers/image_dataset.py @@ -5,6 +5,14 @@ class ImageTestDataset(torch.utils.data.Dataset): + """ + Create a simple tensor dataset for testing image dataset classes + and functions. + + Args: + tensors (list): A list of tensors to use in the dataset. + """ + def __init__(self, tensors: List[torch.Tensor]) -> None: assert all(t.size(0) == 1 for t in tensors if t.dim() == 4) @@ -23,7 +31,12 @@ def __len__(self) -> int: def image_cov_np(array: np.ndarray) -> np.ndarray: """ - Calculate an array's RGB covariance matrix + Calculate an array's RGB covariance matrix. + + Args: + array (array): An NCHW image array. + Returns: + *array*: An RGB covariance matrix for the specified array. """ array = array.reshape(-1, 3) @@ -36,6 +49,13 @@ def cov_matrix_to_klt_np( ) -> np.ndarray: """ Convert a cov matrix to a klt matrix. + + Args: + cov_mtx (array): A 3 by 3 covariance matrix generated from a dataset. + normalize (bool): Whether or not to normalize the resulting KLT matrix. + epsilon (float): + Returns: + *array*: A KLT matrix for the specified covariance matrix. """ U, S, V = np.linalg.svd(cov_mtx) diff --git a/tests/optim/helpers/numpy_common.py b/tests/optim/helpers/numpy_common.py index 6013600eb7..b432829694 100644 --- a/tests/optim/helpers/numpy_common.py +++ b/tests/optim/helpers/numpy_common.py @@ -12,6 +12,13 @@ def weights_to_heatmap_2d( By default red represents excitatory values, blue represents inhibitory values, and white represents no excitation or inhibition. + + Args: + weight (array): A 2d array to create the heatmap from. + colors (List of strings): A list of strings containing color + hex values to use for coloring the heatmap. + Returns: + *array*: A weight heatmap. """ assert array.ndim == 2 diff --git a/tests/optim/helpers/numpy_transforms.py b/tests/optim/helpers/numpy_transforms.py index 386478f236..eec0afebac 100644 --- a/tests/optim/helpers/numpy_transforms.py +++ b/tests/optim/helpers/numpy_transforms.py @@ -8,7 +8,11 @@ class BlendAlpha: """ - NumPy version of the BlendAlpha transform + NumPy version of the BlendAlpha transform. + + Args: + background (array, optional): An NCHW image array to be used as the + Alpha channel's background. """ def __init__(self, background: Optional[np.ndarray] = None) -> None: @@ -16,6 +20,14 @@ def __init__(self, background: Optional[np.ndarray] = None) -> None: self.background = background def blend_alpha(self, x: np.ndarray) -> np.ndarray: + """ + Blend the Alpha channel into the RGB channels. + + Args: + x (array): RGBA image array to blend into an RGB image array. + Returns: + blended (array): RGB image array. + """ assert x.shape[1] == 4 assert x.ndim == 4 rgb, alpha = x[:, :3, ...], x[:, 3:4, ...] @@ -30,7 +42,11 @@ def blend_alpha(self, x: np.ndarray) -> np.ndarray: class RandomSpatialJitter: """ - NumPy version of the RandomSpatialJitter transform + NumPy version of the RandomSpatialJitter transform. + + Args: + translate (int): The amount to translate the H and W dimensions + of an CHW or NCHW array. """ def __init__(self, translate: int) -> None: @@ -59,6 +75,7 @@ def jitter(self, x: np.ndarray) -> np.ndarray: class CenterCrop: """ Center crop a specified amount from a tensor. + Arguments: size (int or sequence of int): Number of pixels to center crop away. 
pixels_from_edges (bool, optional): Whether to treat crop size values @@ -181,6 +198,15 @@ def __init__(self, transform: Union[str, np.ndarray] = "klt") -> None: ) def to_rgb(self, x: np.ndarray, inverse: bool = False) -> np.ndarray: + """ + Args: + x (array): A CHW or NCHW RGB or RGBA image array. + inverse (bool, optional): Whether to recorrelate or decorrelate colors. + Default: False + Returns: + *array*: An array with its colors recorrelated or decorrelated. + """ + assert x.ndim == 3 or x.ndim == 4 # alpha channel is taken off... diff --git a/tests/optim/models/test_models_common.py b/tests/optim/models/test_models_common.py index f6418b8d6c..176b10fff2 100644 --- a/tests/optim/models/test_models_common.py +++ b/tests/optim/models/test_models_common.py @@ -290,3 +290,48 @@ def test_skip_layers(self) -> None: model_utils.skip_layers(model, torch.nn.ReLU) output_tensor = model(x) assertTensorAlmostEqual(self, x, output_tensor, 0) + + +class TestMaxPool2dRelaxed(BaseTest): + def test_maxpool2d_relaxed_forward_data(self) -> None: + maxpool_relaxed = model_utils.MaxPool2dRelaxed( + kernel_size=3, stride=2, padding=0, ceil_mode=True + ) + maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) + + test_input = torch.arange(0, 1 * 3 * 8 * 8).view(1, 3, 8, 8).float() + + test_output_relaxed = maxpool_relaxed(test_input.clone()) + test_output_max = maxpool(test_input.clone()) + + assertTensorAlmostEqual(self, test_output_relaxed, test_output_max) + + def test_maxpool2d_relaxed_gradient(self) -> None: + maxpool_relaxed = model_utils.MaxPool2dRelaxed( + kernel_size=3, stride=2, padding=0, ceil_mode=True + ) + test_input = torch.nn.Parameter( + torch.arange(0, 1 * 1 * 4 * 4).view(1, 1, 4, 4).float() + ) + + test_output = maxpool_relaxed(test_input) + + output_grad = torch.autograd.grad( + outputs=[test_output], + inputs=[test_input], + grad_outputs=[test_output], + )[0] + + expected_output = torch.tensor( + [ + [ + [ + [1.1111, 1.1111, 2.9444, 1.8333], + [1.1111, 1.1111, 2.9444, 1.8333], + [3.4444, 3.4444, 9.0278, 5.5833], + [2.3333, 2.3333, 6.0833, 3.7500], + ] + ] + ], + ) + assertTensorAlmostEqual(self, output_grad, expected_output, 0.0005) diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 525d6277aa..7c420aa579 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -337,7 +337,7 @@ def test_sharedimage_get_offset_single_number(self) -> None: shapes=shared_shapes, parameterization=test_param ) - offset = image_param.get_offset(4, 3) + offset = image_param._get_offset(4, 3) self.assertEqual(len(offset), 3) self.assertEqual(offset, [[4, 4, 4, 4]] * 3) @@ -354,7 +354,7 @@ def test_sharedimage_get_offset_exact(self) -> None: ) offset_vals = ((1, 2, 3, 4), (4, 3, 2, 1), (1, 2, 3, 4)) - offset = image_param.get_offset(offset_vals, 3) + offset = image_param._get_offset(offset_vals, 3) self.assertEqual(len(offset), 3) self.assertEqual(offset, [[int(o) for o in v] for v in offset_vals]) @@ -371,7 +371,7 @@ def test_sharedimage_get_offset_single_set_four_numbers(self) -> None: ) offset_vals = (1, 2, 3, 4) - offset = image_param.get_offset(offset_vals, 3) + offset = image_param._get_offset(offset_vals, 3) self.assertEqual(len(offset), 3) self.assertEqual(offset, [list(offset_vals)] * 3) @@ -388,7 +388,7 @@ def test_sharedimage_get_offset_single_set_three_numbers(self) -> None: ) offset_vals = (2, 3, 4) - offset = image_param.get_offset(offset_vals, 3) + offset = image_param._get_offset(offset_vals, 
3) self.assertEqual(len(offset), 3) self.assertEqual(offset, [[0] + list(offset_vals)] * 3) @@ -405,7 +405,7 @@ def test_sharedimage_get_offset_single_set_two_numbers(self) -> None: ) offset_vals = (3, 4) - offset = image_param.get_offset(offset_vals, 3) + offset = image_param._get_offset(offset_vals, 3) self.assertEqual(len(offset), 3) self.assertEqual(offset, [[0, 0] + list(offset_vals)] * 3) @@ -448,7 +448,7 @@ def test_apply_offset(self): ) test_x_list = [torch.ones(*size) for x in range(size[0])] - output_A = image_param.apply_offset(test_x_list) + output_A = image_param._apply_offset(test_x_list) x_list = [torch.ones(*size) for x in range(size[0])] self.assertEqual(image_param.offset, [list(offset_vals)]) @@ -475,7 +475,7 @@ def test_interpolate_tensor(self) -> None: batch = 1 test_tensor = torch.ones(6, 4, 128, 128) - output_tensor = image_param.interpolate_tensor( + output_tensor = image_param._interpolate_tensor( test_tensor, batch, channels, size[0], size[1] ) diff --git a/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb b/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb new file mode 100644 index 0000000000..1a47abb9e3 --- /dev/null +++ b/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb @@ -0,0 +1,592 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "ActivationAtlasSampleCollection_OptimViz.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "KP2PKna21WLK" + }, + "source": [ + "# Collecting Samples for Activation Atlases with captum.optim\n", + "\n", + "This notebook demonstrates how to collect the activation and corresponding attribution samples required for [Activation Atlases](https://distill.pub/2019/activation-atlas/) for the InceptionV1 model imported from Caffe." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "v6T6jxWb4cil" + }, + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from typing import List, Optional, Tuple, cast\n", + "\n", + "import os\n", + "import torch\n", + "import torchvision\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "from captum.optim.models import googlenet\n", + "\n", + "import captum.optim as opt\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dtE-t6ZG0-sJ" + }, + "source": [ + "### Dataset Download & Setup \n", + "\n", + "To begin, we'll need to download and set up the image dataset that our model was trained on. You can download ImageNet's ILSVRC2012 dataset from the [ImageNet website](http://www.image-net.org/challenges/LSVRC/2012/) or via BitTorrent from [Academic Torrents](https://academictorrents.com/details/a306397ccf9c2ead27155983c254227c0fd938e2)."
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lDt-6WMp0qh3" + }, + "source": [ + "collect_attributions = True # Set to False for no attributions\n", + "\n", + "# Setup basic transforms\n", + "# The model has the normalization step in its internal transform_input\n", + "# function, so we don't need to normalize our inputs here.\n", + "transform_list = [\n", + " torchvision.transforms.Resize((224, 224)),\n", + " torchvision.transforms.ToTensor(),\n", + "]\n", + "transform_list = torchvision.transforms.Compose(transform_list)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i85yBIhL7owj" + }, + "source": [ + "To make it easier to load the ImageNet dataset, we can use [Torchvision](https://pytorch.org/vision/stable/datasets.html#imagenet)'s `torchvision.datasets.ImageNet` instead of the default `ImageFolder`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3oRqxlMq7gJ4" + }, + "source": [ + "# Load the dataset\n", + "image_dataset = torchvision.datasets.ImageNet(\n", + " root=\"path/to/dataset\", split=\"train\", transform=transform_list\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "573290Fr8KN7" + }, + "source": [ + "Now we wrap our dataset in a `torch.utils.data.DataLoader` instance, and set the desired batch size." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "DUCfwsvR7iGC" + }, + "source": [ + "# Set desired batch size & load dataset with torch.utils.data.DataLoader\n", + "image_loader = torch.utils.data.DataLoader(\n", + " image_dataset,\n", + " batch_size=32,\n", + " shuffle=True,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4qfpBAPu18jv" + }, + "source": [ + "We load our model, then set the desired model target layers and corresponding file names." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qMViqsJ82Mcp" + }, + "source": [ + "# Model to collect samples from, what layers of the model to collect samples from,\n", + "# and the desired names to use for the target layers.\n", + "sample_model = (\n", + " googlenet(\n", + " pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True\n", + " )\n", + " .eval()\n", + " .to(device)\n", + ")\n", + "sample_targets = [sample_model.mixed4c_relu]\n", + "sample_target_names = [\"mixed4c_relu_samples\"]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jl719nyZEGSt" + }, + "source": [ + "By default the activation samples will not have the right class attributions, so we remedy this by loading a second instance of our model. We then replace all `nn.MaxPool2d` layers in the second model instance with Captum's `MaxPool2dRelaxed` layer. The relaxed max pooling layer lets us estimate the sample class attributions by determining the rate at which increasing the neuron affects the output classes."
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "A-VJyHRm1tqC" + }, + "source": [ + "# Optionally collect attributions from a copy of the first model that's\n", + "# been set up with relaxed pooling layers.\n", + "if collect_attributions:\n", + " sample_model_attr = (\n", + " googlenet(\n", + " pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True\n", + " )\n", + " .eval()\n", + " .to(device)\n", + " )\n", + " opt.models.replace_layers(\n", + " sample_model_attr,\n", + " torch.nn.MaxPool2d,\n", + " opt.models.MaxPool2dRelaxed,\n", + " transfer_vars=True,\n", + " )\n", + " sample_attr_targets = [sample_model_attr.mixed4c_relu]\n", + " sample_logit_target = sample_model_attr.fc\n", + "else:\n", + " sample_model_attr = None\n", + " sample_attr_targets = None\n", + " sample_logit_target = None" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "32zDGSR5-qDW" + }, + "source": [ + "With our dataset loaded and models ready to go, we can now start collecting our samples. To perform the sample collection, we define a function called `capture_activation_samples` that randomly samples an x and y position from every image for all specified target layers." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2YLBCYP0J4Gq" + }, + "source": [ + "def attribute_spatial_position(\n", + " target_activ: torch.Tensor,\n", + " logit_activ: torch.Tensor,\n", + " position_mask: torch.Tensor,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + "\n", + " target_activ (torch.Tensor): Captured activations from the target layer.\n", + " logit_activ (torch.Tensor): Captured activations from the FC / logit layer.\n", + " position_mask (torch.Tensor): A mask used to zero out all the non-target\n", + " positions when using a batch size greater than one.\n", + "\n", + " Returns:\n", + " logit_attr (torch.Tensor): A tensor of class attributions for the target\n", + " spatial positions.\n", + " \"\"\"\n", + "\n", + " assert target_activ.dim() == 2 or target_activ.dim() == 4\n", + " assert logit_activ.dim() == 2\n", + "\n", + " zeros = torch.nn.Parameter(torch.zeros_like(logit_activ))\n", + " target_zeros = target_activ * position_mask\n", + "\n", + " grad_one = torch.autograd.grad(\n", + " outputs=[logit_activ],\n", + " inputs=[target_activ],\n", + " grad_outputs=[zeros],\n", + " create_graph=True,\n", + " )\n", + " logit_attr = torch.autograd.grad(\n", + " outputs=grad_one,\n", + " inputs=[zeros],\n", + " grad_outputs=[target_zeros],\n", + " create_graph=True,\n", + " )[0]\n", + " return logit_attr\n", + "\n", + "\n", + "def capture_activation_samples(\n", + " loader: torch.utils.data.DataLoader,\n", + " model: torch.nn.Module,\n", + " targets: List[torch.nn.Module],\n", + " target_names: Optional[List[str]] = None,\n", + " sample_dir: str = \"\",\n", + " num_images: Optional[int] = None,\n", + " samples_per_image: int = 1,\n", + " input_device: torch.device = torch.device(\"cpu\"),\n", + " collect_attributions: bool = False,\n", + " attr_model: Optional[torch.nn.Module] = None,\n", + " attr_targets: Optional[List[torch.nn.Module]] = None,\n", + " logit_target: Optional[torch.nn.Module] = None,\n", + " show_progress: bool = False,\n", + "):\n", + " \"\"\"\n", + " Capture randomly sampled activations for an image dataset from one or multiple\n", + " target layers.\n", + "\n", + " Args:\n", + "\n", + " loader (torch.utils.data.DataLoader): A torch.utils.data.DataLoader\n", + " instance for an image dataset.\n", + " model (nn.Module): A
PyTorch model instance.\n", + " targets (list of nn.Module): A list of layers to collect activation samples\n", + " from.\n", + " target_names (list of str, optional): A list of names to use when saving sample\n", + " tensors as files. Names will automatically be chosen if set to None.\n", + " Default: None\n", + " sample_dir (str): Path to where activation samples should be saved.\n", + " Default: \"\"\n", + " num_images (int, optional): How many images to collect samples from. Set to\n", + " None to collect samples from every image in the dataset.\n", + " Default: None\n", + " samples_per_image (int): How many samples to collect per image.\n", + " Default: 1\n", + " input_device (torch.device, optional): The device to use for model\n", + " inputs.\n", + " Default: torch.device(\"cpu\")\n", + " collect_attributions (bool, optional): Whether or not to collect attributions\n", + " for samples.\n", + " Default: False\n", + " attr_model (nn.Module, optional): A PyTorch model instance to use for\n", + " calculating sample attributions.\n", + " Default: None\n", + " attr_targets (list of nn.Module, optional): A list of attribution model layers\n", + " to collect attributions from. This should be the exact same as the targets\n", + " parameter, except for the attribution model.\n", + " Default: None\n", + " logit_target (nn.Module, optional): The final layer in the attribution model\n", + " that determines the classes. This parameter is only used if\n", + " collect_attributions is set to True.\n", + " Default: None\n", + " show_progress (bool, optional): Whether or not to show progress.\n", + " Default: False\n", + " \"\"\"\n", + "\n", + " if target_names is None:\n", + " target_names = [\"target\" + str(i) + \"_\" for i in range(len(targets))]\n", + "\n", + " assert len(target_names) == len(targets)\n", + " assert os.path.isdir(sample_dir)\n", + "\n", + " def random_sample(\n", + " activations: torch.Tensor,\n", + " ) -> Tuple[List[torch.Tensor], List[List[List[int]]]]:\n", + " \"\"\"\n", + " Randomly sample H & W dimensions of activations with 4 dimensions.\n", + " \"\"\"\n", + " assert activations.dim() == 4 or activations.dim() == 2\n", + "\n", + " activation_samples: List = []\n", + " position_list: List = []\n", + "\n", + " with torch.no_grad():\n", + " for i in range(samples_per_image):\n", + " sample_position_list: List = []\n", + " for b in range(activations.size(0)):\n", + " if activations.dim() == 4:\n", + " h, w = activations.shape[2:]\n", + " y = torch.randint(low=1, high=h - 1, size=[1])\n", + " x = torch.randint(low=1, high=w - 1, size=[1])\n", + " activ = activations[b, :, y, x]\n", + " sample_position_list.append((b, y, x))\n", + " elif activations.dim() == 2:\n", + " activ = activations[b].unsqueeze(1)\n", + " sample_position_list.append(b)\n", + " activation_samples.append(activ)\n", + " position_list.append(sample_position_list)\n", + " return activation_samples, position_list\n", + "\n", + " def attribute_samples(\n", + " activations: torch.Tensor,\n", + " logit_activ: torch.Tensor,\n", + " position_list: List[List[List[int]]],\n", + " ) -> List[torch.Tensor]:\n", + " \"\"\"\n", + " Collect attributions for target sample positions.\n", + " \"\"\"\n", + " assert activations.dim() == 4 or activations.dim() == 2\n", + "\n", + " sample_attributions: List = []\n", + " with torch.set_grad_enabled(True):\n", + " zeros_mask = torch.zeros_like(activations)\n", + " for sample_pos_list in position_list:\n", + " 
for c in sample_pos_list:\n", + " if activations.dim() == 4:\n", + " zeros_mask[c[0], :, c[1], c[2]] = 1\n", + " elif activations.dim() == 2:\n", + " zeros_mask[c] = 1\n", + " attr = attribute_spatial_position(\n", + " activations, logit_activ, position_mask=zeros_mask\n", + " ).detach()\n", + " sample_attributions.append(attr)\n", + " return sample_attributions\n", + "\n", + " if collect_attributions:\n", + " logit_target = list(model.children())[len(list(model.children())) - 1 :][\n", + " 0\n", + " ] if logit_target is None else logit_target\n", + " attr_targets = cast(List[torch.nn.Module], attr_targets)\n", + " attr_targets += [cast(torch.nn.Module, logit_target)]\n", + "\n", + " if show_progress:\n", + " total = (\n", + " len(loader.dataset) if num_images is None else num_images # type: ignore\n", + " )\n", + " pbar = tqdm(total=total, unit=\" images\")\n", + "\n", + " image_count, batch_count = 0, 0\n", + " with torch.no_grad():\n", + " for inputs, _ in loader:\n", + " inputs = inputs.to(input_device)\n", + " image_count += inputs.size(0)\n", + " batch_count += 1\n", + "\n", + " target_activ_dict = opt.models.collect_activations(model, targets, inputs)\n", + " if collect_attributions:\n", + " with torch.set_grad_enabled(True):\n", + " target_activ_attr_dict = opt.models.collect_activations(\n", + " attr_model, attr_targets, inputs\n", + " )\n", + " logit_activ = target_activ_attr_dict[logit_target]\n", + " del target_activ_attr_dict[logit_target]\n", + "\n", + " sample_coords = []\n", + " for t, n in zip(target_activ_dict, target_names):\n", + " sample_tensors, p_list = random_sample(target_activ_dict[t])\n", + " torch.save(\n", + " sample_tensors,\n", + " os.path.join(\n", + " sample_dir, n + \"_activations_\" + str(batch_count) + \".pt\"\n", + " ),\n", + " )\n", + " sample_coords.append(p_list)\n", + "\n", + " if collect_attributions:\n", + " for t, n, s_coords in zip(\n", + " target_activ_attr_dict, target_names, sample_coords\n", + " ):\n", + " sample_attrs = attribute_samples(\n", + " target_activ_attr_dict[t], logit_activ, s_coords\n", + " )\n", + " torch.save(\n", + " sample_attrs,\n", + " os.path.join(\n", + " sample_dir,\n", + " n + \"_attributions_\" + str(batch_count) + \".pt\",\n", + " ),\n", + " )\n", + "\n", + " if show_progress:\n", + " pbar.update(inputs.size(0))\n", + "\n", + " if num_images is not None:\n", + " if image_count > num_images:\n", + " break\n", + "\n", + " if show_progress:\n", + " pbar.close()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IWsmPssJJ09E" + }, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uODdkyjY1lap" + }, + "source": [ + "# Directory to save sample files to\n", + "sample_dir = \"inceptionv1_samples\"\n", + "try:\n", + " os.mkdir(sample_dir)\n", + "except FileExistsError:\n", + " pass\n", + "\n", + "# Collect samples & optionally attributions as well\n", + "capture_activation_samples(\n", + " loader=image_loader,\n", + " model=sample_model,\n", + " targets=sample_targets,\n", + " target_names=sample_target_names,\n", + " attr_model=sample_model_attr,\n", + " attr_targets=sample_attr_targets,\n", + " input_device=device,\n", + " sample_dir=sample_dir,\n", + " show_progress=True,\n", + " collect_attributions=collect_attributions,\n", + " logit_target=sample_logit_target,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eMrBUaPi97fF" + }, + "source": [ + "Now that we've collected our
samples, we need to combine them into a single tensor. Below we use the `consolidate_samples` function to load each list of tensor samples, and then concatenate them into a single tensor." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "LaFglPVYKbXj" + }, + "source": [ + "def consolidate_samples(\n", + " sample_dir: str,\n", + " sample_basename: str = \"\",\n", + " dim: int = 1,\n", + " num_files: Optional[int] = None,\n", + " show_progress: bool = False,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Combine samples collected from capture_activation_samples into a single tensor\n", + " with a shape of [n_channels, n_samples].\n", + "\n", + " Args:\n", + "\n", + " sample_dir (str): The directory where activation samples were saved.\n", + " sample_basename (str, optional): If samples from different layers are present\n", + " in sample_dir, then you can use samples from only a specific layer by\n", + " specifying the basename that samples of the same layer share.\n", + " Default: \"\"\n", + " dim (int, optional): The dimension to concatenate the samples together on.\n", + " Default: 1\n", + " num_files (int, optional): The expected number of sample files, used only for\n", + " the progress bar total. Set to None to use the number of files found in\n", + " sample_dir.\n", + " Default: None\n", + " show_progress (bool, optional): Whether or not to show progress.\n", + " Default: False\n", + "\n", + " Returns:\n", + " sample_tensor (torch.Tensor): A tensor containing all the specified sample\n", + " tensors with a shape of [n_channels, n_samples].\n", + " \"\"\"\n", + "\n", + " assert os.path.isdir(sample_dir)\n", + "\n", + " tensor_samples = [\n", + " os.path.join(sample_dir, name)\n", + " for name in os.listdir(sample_dir)\n", + " if sample_basename.lower() in name.lower()\n", + " and os.path.isfile(os.path.join(sample_dir, name))\n", + " ]\n", + " assert len(tensor_samples) > 0\n", + "\n", + " if show_progress:\n", + " total = len(tensor_samples) if num_files is None else num_files # type: ignore\n", + " pbar = tqdm(total=total, unit=\" sample batches collected\")\n", + "\n", + " samples: List[torch.Tensor] = []\n", + " for file in tensor_samples:\n", + " sample_batch = torch.load(file)\n", + " for s in sample_batch:\n", + " samples += [s.cpu()]\n", + " if show_progress:\n", + " pbar.update(1)\n", + "\n", + " if show_progress:\n", + " pbar.close()\n", + " return torch.cat(samples, dim)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "BKUPszVR1Ew-" + }, + "source": [ + "# Combine our newly collected samples into single tensors.\n", + "# We load the sample tensors from sample_dir and then\n", + "# concatenate them.\n", + "\n", + "for name in sample_target_names:\n", + " print(\"Combining \" + name + \" samples:\")\n", + " activation_samples = consolidate_samples(\n", + " sample_dir=sample_dir,\n", + " sample_basename=name + \"_activations\",\n", + " dim=1,\n", + " show_progress=True,\n", + " )\n", + " if collect_attributions:\n", + " sample_attributions = consolidate_samples(\n", + " sample_dir=sample_dir,\n", + " sample_basename=name + \"_attributions\",\n", + " dim=0,\n", + " show_progress=True,\n", + " )\n", + "\n", + " # Save the results\n", + " torch.save(activation_samples, name + \"activation_samples.pt\")\n", + " if collect_attributions:\n", + " torch.save(sample_attributions, name + \"attribution_samples.pt\")" + ], + "execution_count": null, + "outputs": [] + } + ] +}
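The MaxPool2dRelaxed forward pass added in this patch is a one-line relaxation: the detached max-pool term supplies the forward value, while the attached average-pool term supplies the gradient. Below is a minimal standalone sketch of that trick, not the captum implementation; the class name RelaxedMaxPoolSketch and the toy numbers are illustrative only.

import torch
import torch.nn as nn
from typing import Optional


class RelaxedMaxPoolSketch(nn.Module):
    """Illustrative stand-in for the relaxed pooling idea, not the captum API."""

    def __init__(self, kernel_size: int, stride: Optional[int] = None) -> None:
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size, stride=stride)
        self.avgpool = nn.AvgPool2d(kernel_size, stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The forward value comes from max pooling; the detached terms carry no
        # gradient, so the backward pass sees only the average pooling branch.
        return self.maxpool(x.detach()) + self.avgpool(x) - self.avgpool(x.detach())


x = torch.nn.Parameter(torch.arange(16.0).view(1, 1, 4, 4))
relaxed = RelaxedMaxPoolSketch(kernel_size=2, stride=2)

out = relaxed(x)
# The output matches an ordinary MaxPool2d forward pass.
assert torch.equal(out, nn.MaxPool2d(2, 2)(x.detach()))

out.sum().backward()
# The gradient matches AvgPool2d: each input position receives 1 / (2 * 2).
assert torch.allclose(x.grad, torch.full_like(x, 0.25))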
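The skip_layers helper added above linearizes a network for circuits-style analysis by swapping activation layers for layers that do nothing. The following standalone sketch mirrors that idea by replacing every nn.ReLU with nn.Identity directly, without calling the captum helpers; the function name replace_relus_inplace is illustrative only.

import torch
import torch.nn as nn


def replace_relus_inplace(model: nn.Module) -> None:
    # Recursively swap every nn.ReLU child for a layer that returns its input
    # unchanged, mirroring what skip_layers(model, nn.ReLU) accomplishes.
    for name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, name, nn.Identity())
        else:
            replace_relus_inplace(child)


model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
    nn.ReLU(),
)
replace_relus_inplace(model)

# No ReLU modules remain, so the model is now a purely linear map of its input.
assert not any(isinstance(m, nn.ReLU) for m in model.modules())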
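The attribute_spatial_position function in the notebook relies on a double torch.autograd.grad call: a vector-Jacobian product taken through a dummy zeros tensor is differentiated a second time to recover the Jacobian-vector product of the logits with respect to the masked activations. The sketch below replays that trick on a toy linear head where the result can be checked in closed form; all names, shapes, and values here are illustrative, not part of the tutorial.

import torch

torch.manual_seed(0)
W = torch.randn(3, 2)

target_activ = torch.randn(1, 3, requires_grad=True)
logit_activ = target_activ @ W  # toy stand-in for the model's logit layer

position_mask = torch.ones_like(target_activ)  # keep every position
zeros = torch.zeros_like(logit_activ, requires_grad=True)
target_zeros = target_activ * position_mask

# First grad: a vector-Jacobian product whose "vector" is the dummy zeros tensor.
grad_one = torch.autograd.grad(
    outputs=[logit_activ],
    inputs=[target_activ],
    grad_outputs=[zeros],
    create_graph=True,
)
# Second grad w.r.t. zeros: yields the Jacobian-vector product in the direction of
# the masked activations, i.e. how the logits respond to those spatial positions.
logit_attr = torch.autograd.grad(
    outputs=grad_one,
    inputs=[zeros],
    grad_outputs=[target_zeros],
)[0]

# For this linear head the attribution is exactly (masked activations) @ W.
assert torch.allclose(logit_attr, target_zeros @ W)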