diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index 66bb4c40c2..ffd7c8e43d 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -1,7 +1,6 @@
-import functools
 import operator
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -9,12 +8,6 @@
 from captum.optim._utils.typing import ModuleOutputMapping
 
 
-def _make_arg_str(arg: Any) -> str:
-    arg = str(arg)
-    too_big = len(arg) > 15 or "\n" in arg
-    return arg[:15] + "..." if too_big else arg
-
-
 class Loss(ABC):
     """
     Abstract Class to describe loss.
@@ -23,7 +16,8 @@ class Loss(ABC):
     """
 
     def __init__(self) -> None:
-        super(Loss, self).__init__()
+        super().__init__()
+        self.__name__ = self.__class__.__name__
 
     @abstractproperty
     def target(self) -> Union[nn.Module, List[nn.Module]]:
@@ -105,10 +99,35 @@
 ) -> "CompositeLoss":
     """
     This is a general function for applying math operations to Losses
+
+    Args:
+
+        self (Loss): A Loss objective instance.
+        other (int, float, Loss, or None): The Loss objective instance or number to
+            use on the self Loss objective as part of a math operation. If math_op
+            is a unary operation, then other should be set to None.
+        math_op (Callable): A math operator to use on the Loss instance.
+
+    Returns:
+        loss (CompositeLoss): A CompositeLoss instance with the math operations
+            created by the specified arguments.
     """
     if other is None and math_op == operator.neg:
 
         def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
+            """
+            Pass collected activations through the loss objective, and then apply a
+            unary math op.
+
+            Args:
+
+                module (ModuleOutputMapping): A dict of captured activations with
+                    nn.Modules as keys.
+
+            Returns:
+                loss (torch.Tensor): The target activations after being run
+                    through the loss objective, and then the unary math_op.
+            """
             return math_op(self(module))
 
         name = self.__name__
@@ -116,6 +135,19 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
     elif isinstance(other, (int, float)):
 
         def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
+            """
+            Pass collected activations through the loss objective, and then apply the
+            math operation with a number.
+
+            Args:
+
+                module (ModuleOutputMapping): A dict of captured activations with
+                    nn.Modules as keys.
+
+            Returns:
+                loss (torch.Tensor): The target activations after being run
+                    through the loss objective, and then the math_op with a number.
+            """
             return math_op(self(module), other)
 
         name = self.__name__
@@ -123,6 +155,19 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
     elif isinstance(other, Loss):
         # We take the mean of the output tensor to resolve shape mismatches
 
         def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
+            """
+            Pass collected activations through the loss objectives, and then combine
+            the outputs with a math operation.
+
+            Args:
+
+                module (ModuleOutputMapping): A dict of captured activations with
+                    nn.Modules as keys.
+
+            Returns:
+                loss (torch.Tensor): The target activations after being run
+                    through the loss objectives, and then merged with the math_op.
+            """
             return math_op(torch.mean(self(module)), torch.mean(other(module)))
 
         name = f"Compose({', '.join([self.__name__, other.__name__])})"
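The operator overloads that `module_op` backs make losses composable directly with Python math syntax. A minimal sketch of how this behaves in practice (assuming the in-development `captum.optim` module is importable as `opt`; the layer and tensor shapes are illustrative):

.. code-block:: python

    import torch
    import torch.nn as nn
    import captum.optim as opt

    layer = nn.ReLU()  # stand-in for a real model layer
    loss1 = opt.loss.LayerActivation(layer)
    loss2 = opt.loss.ChannelActivation(layer, channel_index=0)

    # Binary ops with numbers or other losses return CompositeLoss instances,
    # and unary negation wraps a loss the same way.
    combined = loss1 * 0.5 + loss2
    negated = -combined

    # Losses are evaluated on a dict of captured activations keyed by module.
    value = combined({layer: torch.randn(1, 4, 8, 8)})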
+ """ + def __init__( self, target: Union[nn.Module, List[nn.Module]] = [], - batch_index: Optional[int] = None, + batch_index: Optional[Union[int, List[int]]] = None, ) -> None: - super(BaseLoss, self).__init__() + """ + Args: + + target (nn.Module or list of nn.Module): A target nn.Module or list of + nn.Module. + batch_index (int or list of int, optional): The index or index range of + activations to optimize if optimizing a batch of activations. If set to + ``None``, defaults to all activations in the batch. Index ranges should + be in the format of: [start, end]. + Default: ``None`` + """ + super().__init__() self._target = target if batch_index is None: self._batch_index = (None, None) + elif isinstance(batch_index, (list, tuple)): + self._batch_index = tuple(batch_index) else: self._batch_index = (batch_index, batch_index + 1) + assert all([isinstance(b, (int, type(None))) for b in self._batch_index]) + assert len(self._batch_index) == 2 @property def target(self) -> Union[nn.Module, List[nn.Module]]: + """ + Returns: + target (nn.Module or list of nn.Module): A target nn.Module or list of + nn.Module. + """ return self._target @property def batch_index(self) -> Tuple: + """ + Returns: + batch_index (tuple of int): A tuple of batch indices with a format + of: (start, end). + """ return self._batch_index class CompositeLoss(BaseLoss): + """ + When math operations are performed using one or more loss objectives, this class + is used to store and run those operations. Below we show examples of common + CompositeLoss use cases. + + + Using CompositeLoss with a unary op or with a binary op involving a Loss instance + and a float or integer: + + .. code-block:: python + + def compose_single_loss(loss: opt.loss.Loss) -> opt.loss.CompositeLoss: + def loss_fn( + module: Dict[nn.Module, Optional[torch.Tensor]] + ) -> torch.Tensor: + return loss(module) + + # Name of new composable loss instance + name = loss.__name__ + # All targets being used in the composable loss instance + target = loss.target + return opt.loss.CompositeLoss(loss_fn, name=name, target=target) + + Using CompositeLoss with a binary op using two Loss instances: + + .. code-block:: python + + def compose_binary_loss( + loss1: opt.loss.Loss, loss2: opt.loss.Loss + ) -> opt.loss.CompositeLoss: + def loss_fn( + module: Dict[nn.Module, Optional[torch.Tensor]] + ) -> torch.Tensor: + # Operation using 2 loss instances + return loss1(module) + loss2(module) + + # Name of new composable loss instance + name = "Compose(" + ", ".join([loss1.__name__, loss2.__name__]) + ")" + + # All targets being used in the composable loss instance + target1 = loss1.target if type(loss1.target) is list else [loss1.target] + target2 = loss2.target if type(loss2.target) is list else [loss2.target] + target = target1 + target2 + + # Remove duplicate targets + target = list(dict.fromkeys(target)) + return opt.loss.CompositeLoss(loss_fn, name=name, target=target) + + Using CompositeLoss with a list of Loss instances: + + .. 
 
 
 class CompositeLoss(BaseLoss):
+    """
+    When math operations are performed using one or more loss objectives, this class
+    is used to store and run those operations. Below we show examples of common
+    CompositeLoss use cases.
+
+
+    Using CompositeLoss with a unary op or with a binary op involving a Loss instance
+    and a float or integer:
+
+    .. code-block:: python
+
+        def compose_single_loss(loss: opt.loss.Loss) -> opt.loss.CompositeLoss:
+            def loss_fn(
+                module: Dict[nn.Module, Optional[torch.Tensor]]
+            ) -> torch.Tensor:
+                return loss(module)
+
+            # Name of new composable loss instance
+            name = loss.__name__
+            # All targets being used in the composable loss instance
+            target = loss.target
+            return opt.loss.CompositeLoss(loss_fn, name=name, target=target)
+
+    Using CompositeLoss with a binary op using two Loss instances:
+
+    .. code-block:: python
+
+        def compose_binary_loss(
+            loss1: opt.loss.Loss, loss2: opt.loss.Loss
+        ) -> opt.loss.CompositeLoss:
+            def loss_fn(
+                module: Dict[nn.Module, Optional[torch.Tensor]]
+            ) -> torch.Tensor:
+                # Operation using 2 loss instances
+                return loss1(module) + loss2(module)
+
+            # Name of new composable loss instance
+            name = "Compose(" + ", ".join([loss1.__name__, loss2.__name__]) + ")"
+
+            # All targets being used in the composable loss instance
+            target1 = loss1.target if type(loss1.target) is list else [loss1.target]
+            target2 = loss2.target if type(loss2.target) is list else [loss2.target]
+            target = target1 + target2
+
+            # Remove duplicate targets
+            target = list(dict.fromkeys(target))
+            return opt.loss.CompositeLoss(loss_fn, name=name, target=target)
+
+    Using CompositeLoss with a list of Loss instances:
+
+    .. code-block:: python
+
+        def compose_multiple_loss(loss: List[opt.loss.Loss]) -> opt.loss.CompositeLoss:
+            def loss_fn(
+                module: Dict[nn.Module, Optional[torch.Tensor]]
+            ) -> torch.Tensor:
+                loss_tensors = [loss_obj(module) for loss_obj in loss]
+                # We can use any operation that combines the list of tensors into a
+                # single tensor
+                return sum(loss_tensors)
+
+            # Name of new composable loss instance
+            name = "Compose(" + ", ".join([obj.__name__ for obj in loss]) + ")"
+
+            # All targets being used in the composable loss instance
+            # targets will either be List[nn.Module] or nn.Module
+            targets = [loss_obj.target for loss_obj in loss]
+            # Flatten list of targets
+            target = [
+                o for l in [t if type(t) is list else [t] for t in targets] for o in l
+            ]
+            # Remove duplicate targets
+            target = list(dict.fromkeys(target))
+            return opt.loss.CompositeLoss(loss_fn, name=name, target=target)
+    """
+
     def __init__(
         self,
         loss_fn: Callable,
         name: str = "",
         target: Union[nn.Module, List[nn.Module]] = [],
     ) -> None:
-        super(CompositeLoss, self).__init__(target)
+        """
+        Args:
+
+            loss_fn (Callable): A function that takes a dict of captured activations
+                with nn.Modules as keys, and then passes those activations through loss
+                objective(s) & math operations.
+            name (str, optional): The name of all composable operations in the
+                instance.
+                Default: ``""``
+            target (nn.Module or list of nn.Module): A target nn.Module or list of
+                nn.Module.
+        """
+        super().__init__(target)
         self.__name__ = name
         self.loss_fn = loss_fn
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
-        return self.loss_fn(targets_to_values)
-
+        """
+        Pass collected activations through the loss function.
 
-def loss_wrapper(cls: Any) -> Callable:
-    """
-    Primarily for naming purposes.
-    """
+        Args:
 
-    @functools.wraps(cls)
-    def wrapper(*args, **kwargs) -> object:
-        obj = cls(*args, **kwargs)
-        args_str = " [" + ", ".join([_make_arg_str(arg) for arg in args]) + "]"
-        obj.__name__ = cls.__name__ + args_str
-        return obj
+            module (ModuleOutputMapping): A dict of captured activations with
+                nn.Modules as keys.
 
-    return wrapper
+        Returns:
+            loss (torch.Tensor): The target activations after being run through the
+                loss function.
+        """
+        return self.loss_fn(targets_to_values)
 
 
-@loss_wrapper
 class LayerActivation(BaseLoss):
     """
     Maximize activations at the target layer. This is the most basic loss available
     and it simply returns the activations in their original form.
-
-    Args:
-        target (nn.Module): The layer to optimize for.
-        batch_index (int, optional): The index of the image to optimize if we
-            optimizing a batch of images. If unspecified, defaults to all images
-            in the batch.
     """
 
+    def __init__(
+        self,
+        target: nn.Module,
+        batch_index: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set
+                to ``None``, defaults to all activations in the batch. Index ranges
+                should be in the format of: [start, end].
+                Default: ``None``
+        """
+        BaseLoss.__init__(self, target, batch_index)
+
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target]
         activations = activations[self.batch_index[0] : self.batch_index[1]]
         return activations
 
 
-@loss_wrapper
 class ChannelActivation(BaseLoss):
     """
     Maximize activations at the target layer and target channel. This loss maximizes
     the activations of a target channel in a specified target layer, and can be
     useful to determine what features the channel is excited by.
-
-    Args:
-        target (nn.Module): The layer to containing the channel to optimize for.
-        channel_index (int): The index of the channel to optimize for.
-        batch_index (int, optional): The index of the image to optimize if we
-            optimizing a batch of images. If unspecified, defaults to all images
-            in the batch.
     """
 
     def __init__(
-        self, target: nn.Module, channel_index: int, batch_index: Optional[int] = None
+        self,
+        target: nn.Module,
+        channel_index: int,
+        batch_index: Optional[Union[int, List[int]]] = None,
     ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            channel_index (int): The index of the channel to optimize for.
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set to
+                ``None``, defaults to all activations in the batch. Index ranges should
+                be in the format of: [start, end].
+                Default: ``None``
+        """
         BaseLoss.__init__(self, target, batch_index)
         self.channel_index = channel_index
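A usage sketch for the channel objective (the layer and shapes are illustrative):

.. code-block:: python

    import torch
    import torch.nn as nn
    import captum.optim as opt

    conv = nn.Conv2d(3, 8, 3)  # illustrative target layer
    loss = opt.loss.ChannelActivation(conv, channel_index=2, batch_index=0)

    # Losses consume a mapping of module -> captured activation tensor;
    # here the result is channel 2 of batch element 0.
    value = loss({conv: torch.randn(4, 8, 16, 16)})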
@@ -243,26 +420,12 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         ]
 
 
-@loss_wrapper
 class NeuronActivation(BaseLoss):
     """
     This loss maximizes the activations of a target neuron in the specified channel
     from the specified layer. This loss is useful for determining the type of
     features that excite a neuron, and thus is often used for circuits and neuron
     related research.
-
-    Args:
-        target (nn.Module): The layer to containing the channel to optimize for.
-        channel_index (int): The index of the channel to optimize for.
-        x (int, optional): The x coordinate of the neuron to optimize for. If
-            unspecified, defaults to center, or one unit left of center for even
-            lengths.
-        y (int, optional): The y coordinate of the neuron to optimize for. If
-            unspecified, defaults to center, or one unit up of center for even
-            heights.
-        batch_index (int, optional): The index of the image to optimize if we
-            optimizing a batch of images. If unspecified, defaults to all images
-            in the batch.
     """
 
     def __init__(
@@ -271,8 +434,28 @@ def __init__(
         channel_index: int,
         x: Optional[int] = None,
         y: Optional[int] = None,
-        batch_index: Optional[int] = None,
+        batch_index: Optional[Union[int, List[int]]] = None,
     ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            channel_index (int): The index of the channel to optimize for.
+            x (int, optional): The x coordinate of the neuron to optimize for. If
+                unspecified, defaults to center, or one unit left of center for even
+                lengths.
+                Default: ``None``
+            y (int, optional): The y coordinate of the neuron to optimize for. If
+                unspecified, defaults to center, or one unit up of center for even
+                heights.
+                Default: ``None``
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set to
+                ``None``, defaults to all activations in the batch. Index ranges should
+                be in the format of: [start, end].
+                Default: ``None``
+        """
         BaseLoss.__init__(self, target, batch_index)
         self.channel_index = channel_index
         self.x = x
@@ -294,30 +477,46 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         ]
 
 
-@loss_wrapper
 class DeepDream(BaseLoss):
     """
     Maximize 'interestingness' at the target layer.
     Mordvintsev et al., 2015.
     https://github.com/google/deepdream
+
     This loss returns the squared layer activations. When combined with a negative
     mean loss summarization, this loss will create hallucinogenic visuals commonly
     referred to as 'Deep Dream'.
 
-    Args:
-        target (nn.Module): The layer to optimize for.
-        batch_index (int, optional): The index of the image to optimize if we
-            optimizing a batch of images. If unspecified, defaults to all images
-            in the batch.
+    DeepDream tries to increase the values of neurons in proportion to the amount
+    they are presently active. This is equivalent to maximizing the sum of the
+    squares. If you remove the square, you'd be visualizing a direction of:
+    ``[1,1,1,...]`` (which is the same as :class:`.LayerActivation`).
     """
 
+    def __init__(
+        self,
+        target: nn.Module,
+        batch_index: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set
+                to ``None``, defaults to all activations in the batch. Index ranges
+                should be in the format of: [start, end].
+                Default: ``None``
+        """
+        BaseLoss.__init__(self, target, batch_index)
+
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target]
         activations = activations[self.batch_index[0] : self.batch_index[1]]
         return activations**2
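To make the sum-of-squares relationship concrete, a small sketch (names illustrative):

.. code-block:: python

    import torch
    import torch.nn as nn
    import captum.optim as opt

    layer = nn.ReLU()  # illustrative target
    acts = torch.randn(1, 4, 8, 8)

    # DeepDream returns the element-wise square of the activations, so a
    # negative mean summarization maximizes the sum of squares.
    out = opt.loss.DeepDream(layer)({layer: acts})
    assert torch.equal(out, acts**2)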
 
 
-@loss_wrapper
 class TotalVariation(BaseLoss):
     """
     Total variation denoising penalty for activations.
@@ -326,14 +525,26 @@ class TotalVariation(BaseLoss):
     See: https://en.wikipedia.org/wiki/Total_variation_denoising
     This loss attempts to smooth / denoise the target by performing total variance
     denoising. The target is most often the image that’s being optimized. This loss
     is often used to remove unwanted visual artifacts.
-
-    Args:
-        target (nn.Module): The layer to optimize for.
-        batch_index (int, optional): The index of the image to optimize if we
-            optimizing a batch of images. If unspecified, defaults to all images
-            in the batch.
     """
 
+    def __init__(
+        self,
+        target: nn.Module,
+        batch_index: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set
+                to ``None``, defaults to all activations in the batch. Index ranges
+                should be in the format of: [start, end].
+                Default: ``None``
+        """
+        BaseLoss.__init__(self, target, batch_index)
+
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target]
         activations = activations[self.batch_index[0] : self.batch_index[1]]
@@ -342,26 +553,29 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return torch.sum(torch.abs(x_diff)) + torch.sum(torch.abs(y_diff))
 
 
-@loss_wrapper
 class L1(BaseLoss):
     """
     L1 norm of the target layer, generally used as a penalty.
-
-    Args:
-        target (nn.Module): The layer to optimize for.
-        constant (float): Constant threshold to deduct from the activations.
-            Defaults to 0.
-        batch_index (int, optional): The index of the image to optimize if we
-            optimizing a batch of images. If unspecified, defaults to all images
-            in the batch.
     """
 
     def __init__(
         self,
         target: nn.Module,
         constant: float = 0.0,
-        batch_index: Optional[int] = None,
+        batch_index: Optional[Union[int, List[int]]] = None,
     ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            constant (float): Constant threshold to deduct from the activations.
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set to
+                ``None``, defaults to all activations in the batch. Index ranges should
+                be in the format of: [start, end].
+                Default: ``None``
+        """
         BaseLoss.__init__(self, target, batch_index)
         self.constant = constant
 
@@ -371,41 +585,45 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return torch.abs(activations - self.constant).sum()
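These penalty-style objectives are typically scaled and combined with a main objective via the math ops above; a sketch (the weights are arbitrary placeholders, not recommended values, and TotalVariation more often targets the image parameterization itself):

.. code-block:: python

    import torch.nn as nn
    import captum.optim as opt

    layer = nn.Conv2d(3, 8, 3)  # illustrative target layer

    main_loss = opt.loss.ChannelActivation(layer, channel_index=0)
    penalty = opt.loss.TotalVariation(layer) + opt.loss.L1(layer) * 0.05
    loss = main_loss - penalty * 0.25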
 
 
-@loss_wrapper
 class L2(BaseLoss):
     """
     L2 norm of the target layer, generally used as a penalty.
-
-    Args:
-        target (nn.Module): The layer to optimize for.
-        constant (float): Constant threshold to deduct from the activations.
-            Defaults to 0.
-        epsilon (float): Small value to add to L2 prior to sqrt. Defaults to 1e-6.
-        batch_index (int, optional): The index of the image to optimize if we
-            optimizing a batch of images. If unspecified, defaults to all images
-            in the batch.
     """
 
     def __init__(
         self,
         target: nn.Module,
         constant: float = 0.0,
-        epsilon: float = 1e-6,
-        batch_index: Optional[int] = None,
+        eps: float = 1e-6,
+        batch_index: Optional[Union[int, List[int]]] = None,
     ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            constant (float): Constant threshold to deduct from the activations.
+                Default: ``0.0``
+            eps (float): Small value to add to L2 prior to sqrt.
+                Default: ``1e-6``
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set to
+                ``None``, defaults to all activations in the batch. Index ranges should
+                be in the format of: [start, end].
+                Default: ``None``
+        """
         BaseLoss.__init__(self, target, batch_index)
         self.constant = constant
-        self.epsilon = epsilon
+        self.eps = eps
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target][
             self.batch_index[0] : self.batch_index[1]
         ]
         activations = ((activations - self.constant) ** 2).sum()
-        return torch.sqrt(self.epsilon + activations)
+        return torch.sqrt(self.eps + activations)
 
 
-@loss_wrapper
 class Diversity(BaseLoss):
     """
     Use a cosine similarity penalty to extract features from a polysemantic
     neuron.
@@ -414,15 +632,31 @@ class Diversity(BaseLoss):
     This loss helps break up polysemantic layers, channels, and neurons by
     encouraging diversity across the different batches. This loss is to be used
     along with a main loss.
-
-    Args:
-        target (nn.Module): The layer to optimize for.
-        batch_index (int, optional): Unused here since we are optimizing for diversity
-            across the batch.
     """
 
+    def __init__(
+        self,
+        target: nn.Module,
+        batch_index: Optional[List[int]] = None,
+    ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            batch_index (list of int, optional): The index range of activations to
+                optimize. If set to ``None``, defaults to all activations in the batch.
+                Index ranges should be in the format of: [start, end].
+                Default: ``None``
+        """
+        if batch_index:
+            assert isinstance(batch_index, (list, tuple))
+            assert len(batch_index) == 2
+        BaseLoss.__init__(self, target, batch_index)
+
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target]
+        activations = activations[self.batch_index[0] : self.batch_index[1]]
         batch, channels = activations.shape[:2]
         flattened = activations.view(batch, channels, -1)
         grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2))
@@ -438,7 +672,6 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
     )
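Following the docstring's note that Diversity is meant to accompany a main loss, a composition sketch (the weight is an arbitrary placeholder):

.. code-block:: python

    import torch.nn as nn
    import captum.optim as opt

    layer = nn.Conv2d(3, 8, 3)  # illustrative target layer

    # Reward channel 0 while penalizing batch elements for looking alike.
    main_loss = opt.loss.ChannelActivation(layer, channel_index=0)
    loss = main_loss - opt.loss.Diversity(layer) * 1e2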
 
 
-@loss_wrapper
 class ActivationInterpolation(BaseLoss):
     """
     Interpolate between two different layers & channels.
@@ -446,23 +679,29 @@ class ActivationInterpolation(BaseLoss):
     https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
     This loss helps to interpolate or mix visualizations from two activations (layer
     or channel) by interpolating a linear sum between the two activations.
-
-    Args:
-        target1 (nn.Module): The first layer to optimize for.
-        channel_index1 (int): Index of channel in first layer to optimize. Defaults to
-            all channels.
-        target2 (nn.Module): The first layer to optimize for.
-        channel_index2 (int): Index of channel in first layer to optimize. Defaults to
-            all channels.
     """
 
     def __init__(
         self,
         target1: nn.Module = None,
-        channel_index1: int = -1,
+        channel_index1: Optional[int] = None,
         target2: nn.Module = None,
-        channel_index2: int = -1,
+        channel_index2: Optional[int] = None,
     ) -> None:
+        """
+        Args:
+
+            target1 (nn.Module): The first layer, transform, or image parameterization
+                instance to optimize the output for.
+            channel_index1 (int, optional): Index of channel in the first target to
+                optimize. If set to ``None``, all channels will be used.
+                Default: ``None``
+            target2 (nn.Module): The second layer, transform, or image parameterization
+                instance to optimize the output for.
+            channel_index2 (int, optional): Index of channel in the second target to
+                optimize. If set to ``None``, all channels will be used.
+                Default: ``None``
+        """
         self.target_one = target1
         self.channel_index_one = channel_index1
         self.target_two = target2
@@ -476,15 +715,16 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         assert activations_one is not None and activations_two is not None
 
         # ensure channel indices are valid
-        assert (
-            self.channel_index_one < activations_one.shape[1]
-            and self.channel_index_two < activations_two.shape[1]
-        )
+        if self.channel_index_one is not None:
+            assert self.channel_index_one < activations_one.shape[1]
+        if self.channel_index_two is not None:
+            assert self.channel_index_two < activations_two.shape[1]
+
         assert activations_one.size(0) == activations_two.size(0)
 
-        if self.channel_index_one > -1:
+        if self.channel_index_one is not None:
             activations_one = activations_one[:, self.channel_index_one]
-        if self.channel_index_two > -1:
+        if self.channel_index_two is not None:
             activations_two = activations_two[:, self.channel_index_two]
         B = activations_one.size(0)
@@ -498,7 +738,6 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return sum_tensor
 
 
-@loss_wrapper
 class Alignment(BaseLoss):
     """
     Penalize the L2 distance between tensors in the batch to encourage visual
@@ -508,19 +747,36 @@ class Alignment(BaseLoss):
     similarity between them.
     Carter, et al., "Activation Atlas", Distill, 2019.
     https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
     When interpolating between activations, it may be desirable to keep image
     landmarks in the same position for visual comparison. This loss helps to
     minimize L2 distance between neighbouring images.
-
-    Args:
-        target (nn.Module): The layer to optimize for.
-        decay_ratio (float): How much to decay penalty as images move apart in batch.
-            Defaults to 2.
     """
 
-    def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
-        BaseLoss.__init__(self, target)
+    def __init__(
+        self,
+        target: nn.Module,
+        decay_ratio: float = 2.0,
+        batch_index: Optional[List[int]] = None,
+    ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
+            decay_ratio (float): How much to decay penalty as images move apart in
+                the batch.
+                Default: ``2.0``
+            batch_index (list of int, optional): The index range of activations to
+                optimize. If set to ``None``, defaults to all activations in the batch.
+                Index ranges should be in the format of: [start, end].
+                Default: ``None``
+        """
+        if batch_index:
+            assert isinstance(batch_index, (list, tuple))
+            assert len(batch_index) == 2
+        BaseLoss.__init__(self, target, batch_index)
         self.decay_ratio = decay_ratio
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target]
+        activations = activations[self.batch_index[0] : self.batch_index[1]]
         B = activations.size(0)
 
         sum_tensor = torch.zeros(1, device=activations.device)
@@ -535,7 +791,6 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return -sum_tensor
""" def __init__( self, target: nn.Module, vec: torch.Tensor, - cossim_pow: Optional[float] = 0.0, + cossim_pow: float = 0.0, batch_index: Optional[int] = None, ) -> None: + """ + Args: + + target (nn.Module): A target layer, transform, or image parameterization + instance to optimize the output of. + vec (torch.Tensor): Vector representing direction to align to. + cossim_pow (float, optional): The desired cosine similarity power to use. + Default: ``0.0`` + batch_index (int, optional): The index of activations to optimize if + optimizing a batch of activations. If set to ``None``, defaults to + all activations in the batch. + Default: ``None`` + """ BaseLoss.__init__(self, target, batch_index) self.vec = vec.reshape((1, -1, 1, 1)) self.cossim_pow = cossim_pow @@ -573,7 +833,6 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) -@loss_wrapper class NeuronDirection(BaseLoss): """ Visualize a single (x, y) position for a direction vector. @@ -581,21 +840,6 @@ class NeuronDirection(BaseLoss): https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images Extends Direction loss by focusing on visualizing a single neuron within the kernel. - - Args: - target (nn.Module): The layer to optimize for. - vec (torch.Tensor): Vector representing direction to align to. - x (int, optional): The x coordinate of the neuron to optimize for. If - unspecified, defaults to center, or one unit left of center for even - lengths. - y (int, optional): The y coordinate of the neuron to optimize for. If - unspecified, defaults to center, or one unit up of center for even - heights. - channel_index (int): The index of the channel to optimize for. - cossim_pow (float, optional): The desired cosine similarity power to use. - batch_index (int, optional): The index of the image to optimize if we - optimizing a batch of images. If unspecified, defaults to all images - in the batch. """ def __init__( @@ -605,9 +849,33 @@ def __init__( x: Optional[int] = None, y: Optional[int] = None, channel_index: Optional[int] = None, - cossim_pow: Optional[float] = 0.0, + cossim_pow: float = 0.0, batch_index: Optional[int] = None, ) -> None: + """ + Args: + + target (nn.Module): A target layer, transform, or image parameterization + instance to optimize the output of. + vec (torch.Tensor): Vector representing direction to align to. + x (int, optional): The x coordinate of the neuron to optimize for. If + set to ``None``, defaults to center, or one unit left of center for + even lengths. + Default: ``None`` + y (int, optional): The y coordinate of the neuron to optimize for. If + set to ``None``, defaults to center, or one unit up of center for + even heights. + Default: ``None`` + channel_index (int): The index of the channel to optimize for. If set to + ``None``, then all channels will be used. + Default: ``None`` + cossim_pow (float, optional): The desired cosine similarity power to use. + Default: ``0.0`` + batch_index (int, optional): The index of activations to optimize if + optimizing a batch of activations. If set to ``None``, defaults to all + activations in the batch. 
 
 
-@loss_wrapper
 class AngledNeuronDirection(BaseLoss):
     """
     Visualize a direction vector with an optional whitened activation vector to
@@ -647,11 +914,9 @@ class AngledNeuronDirection(BaseLoss):
     More information on the algorithm this objective uses can be found here:
     https://github.com/tensorflow/lucid/issues/116
 
-    This Lucid equivalents of this loss function can be found here:
-    https://github.com/tensorflow/lucid/blob/master/notebooks/
-    activation-atlas/activation-atlas-simple.ipynb
-    https://github.com/tensorflow/lucid/blob/master/notebooks/
-    activation-atlas/class-activation-atlas.ipynb
+    The Lucid equivalents of this loss objective can be found here:
+    https://github.com/tensorflow/lucid/blob/master/notebooks/activation-atlas/activation-atlas-simple.ipynb
+    https://github.com/tensorflow/lucid/blob/master/notebooks/activation-atlas/class-activation-atlas.ipynb
 
     Like the Lucid equivalents, our implementation differs slightly from the
     associated research paper.
@@ -673,16 +938,29 @@ def __init__(
     ) -> None:
         """
         Args:
-            target (nn.Module): A target layer instance.
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance to optimize the output of.
             vec (torch.Tensor): A neuron direction vector to use.
             vec_whitened (torch.Tensor, optional): A whitened neuron direction vector.
+                If set to ``None``, then no whitened vec will be used.
+                Default: ``None``
             cossim_pow (float, optional): The desired cosine similarity power to use.
-            x (int, optional): Optionally provide a specific x position for the target
-                neuron.
-            y (int, optional): Optionally provide a specific y position for the target
-                neuron.
+            x (int, optional): The x coordinate of the neuron to optimize for. If
+                set to ``None``, defaults to center, or one unit left of center for
+                even lengths.
+                Default: ``None``
+            y (int, optional): The y coordinate of the neuron to optimize for. If
+                set to ``None``, defaults to center, or one unit up of center for
+                even heights.
+                Default: ``None``
             eps (float, optional): If cossim_pow is greater than zero, the desired
                 epsilon value to use for cosine similarity calculations.
+                Default: ``1.0e-4``
+            batch_index (int, optional): The index of activations to optimize if
+                optimizing a batch of activations. If set to ``None``, defaults to all
+                activations in the batch.
+                Default: ``None``
         """
         BaseLoss.__init__(self, target, batch_index)
         self.vec = vec.unsqueeze(0) if vec.dim() == 1 else vec
@@ -719,30 +997,34 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return dot * torch.clamp(cossims, min=0.1) ** self.cossim_pow
""" def __init__( self, target: nn.Module, vec: torch.Tensor, - cossim_pow: Optional[float] = 0.0, + cossim_pow: float = 0.0, batch_index: Optional[int] = None, ) -> None: + """ + Args: + + target (nn.Module): A target layer, transform, or image parameterization + instance to optimize the output of. + vec (torch.Tensor): Vector representing direction to align to. + cossim_pow (float, optional): The desired cosine similarity power to use. + Default: ``0.0`` + batch_index (int, optional): The index of activations to optimize if + optimizing a batch of activations. If set to ``None``, defaults to all + activations in the batch. + Default: ``None`` + """ BaseLoss.__init__(self, target, batch_index) assert vec.dim() == 4 self.vec = vec @@ -768,27 +1050,11 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) -@loss_wrapper class ActivationWeights(BaseLoss): """ Apply weights to channels, neurons, or spots in the target. This loss weighs specific channels or neurons in a given layer, via a weight vector. - - Args: - target (nn.Module): The layer to optimize for. - weights (torch.Tensor): Weights to apply to targets. - neuron (bool): Whether target is a neuron. Defaults to False. - x (int, optional): The x coordinate of the neuron to optimize for. If - unspecified, defaults to center, or one unit left of center for even - lengths. - y (int, optional): The y coordinate of the neuron to optimize for. If - unspecified, defaults to center, or one unit up of center for even - heights. - wx (int, optional): Length of neurons to apply the weights to, along the - x-axis. - wy (int, optional): Length of neurons to apply the weights to, along the - y-axis. """ def __init__( @@ -801,6 +1067,29 @@ def __init__( wx: Optional[int] = None, wy: Optional[int] = None, ) -> None: + """ + Args: + + target (nn.Module): A target layer, transform, or image parameterization + instance to optimize the output of. + weights (torch.Tensor): Weights to apply to targets. + neuron (bool): Whether target is a neuron. + Default: ``False`` + x (int, optional): The x coordinate of the neuron to optimize for. If + set to ``None``, defaults to center, or one unit left of center for + even lengths. + Default: ``None`` + y (int, optional): The y coordinate of the neuron to optimize for. If + set to ``None``, defaults to center, or one unit up of center for + even heights. + Default: ``None`` + wx (int, optional): Length of neurons to apply the weights to, along the + x-axis. Set to ``None`` for the full length. + Default: ``None`` + wy (int, optional): Length of neurons to apply the weights to, along the + y-axis. Set to ``None`` for the full length. + Default: ``None`` + """ BaseLoss.__init__(self, target) self.x = x self.y = y @@ -843,26 +1132,41 @@ def sum_loss_list( ) -> CompositeLoss: """ Summarize a large number of losses without recursion errors. By default using 300+ - loss functions for a single optimization task will result in exceeding Python's + loss objectives for a single optimization task will result in exceeding Python's default maximum recursion depth limit. This function can be used to avoid the - recursion depth limit for tasks such as summarizing a large list of loss functions + recursion depth limit for tasks such as summarizing a large list of loss objectives with the built-in sum() function. This function works similar to Lucid's optvis.objectives.Objective.sum() function. 
@@ -880,19 +1184,26 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
 
 def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
     """
-    Helper function to summarize tensor outputs from loss functions.
+    Helper function to summarize tensor outputs from loss objectives.
 
-    default_loss_summarize applies `mean` to the loss tensor
+    default_loss_summarize applies :func:`torch.mean` to the loss tensor
     and negates it so that optimizing it maximizes the activations we
     are interested in.
+
+    Args:
+
+        loss_value (torch.Tensor): A tensor containing the loss values.
+
+    Returns:
+        loss_value (torch.Tensor): The mean of ``loss_value`` multiplied by -1.
     """
     return -1 * loss_value.mean()
 
 
 __all__ = [
     "Loss",
-    "loss_wrapper",
     "BaseLoss",
+    "CompositeLoss",
     "LayerActivation",
     "ChannelActivation",
     "NeuronActivation",