diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index 4d885ff749..b6caea77cb 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -849,18 +849,21 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick( if labels is not None and loss_fn is not None: loss = loss_fn(out, labels) # TODO: allow loss_fn to be Callable - if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"): + if (isinstance(loss_fn, Module) or callable(loss_fn)) and hasattr( + loss_fn, "reduction" + ): + reduction = loss_fn.reduction # type: ignore msg0 = ( "Please ensure that loss_fn.reduction is set to `sum` or `mean`" ) - assert loss_fn.reduction != "none", msg0 + assert reduction != "none", msg0 msg1 = ( - f"loss_fn.reduction ({loss_fn.reduction}) does not match" + f"loss_fn.reduction ({reduction}) does not match" f"reduction type ({reduction_type}). Please ensure they are" " matching." ) - assert loss_fn.reduction == reduction_type, msg1 + assert reduction == reduction_type, msg1 msg2 = ( "Please ensure custom loss function is applying either a " "sum or mean reduction." diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index 453e99410a..02c388378c 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -26,6 +26,7 @@ from captum._utils.progress import NullProgress, progress from captum.influence._core.influence import DataInfluence from captum.influence._utils.common import ( + _check_loss_fn, _format_inputs_dataset, _get_k_most_influential_helper, _gradient_dot_product, @@ -102,6 +103,7 @@ def __init__( checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, ) -> None: r""" Args: @@ -152,6 +154,19 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs to satisfy the same constraints as `loss_fn`. + If not provided, the loss function for test examples is assumed to + be the same as the loss function for training examples, i.e. + `loss_fn`. + Default: None """ self.model = model @@ -167,6 +182,8 @@ def __init__( self.checkpoints_load_func = checkpoints_load_func self.loss_fn = loss_fn + # If test_loss_fn not provided, it's assumed to be same as loss_fn + self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn self.batch_size = batch_size if not isinstance(train_dataset, DataLoader): @@ -489,6 +506,7 @@ def __init__( layers: Optional[List[str]] = None, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, sample_wise_grads_per_batch: bool = False, ) -> None: r""" @@ -561,6 +579,24 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. Note that the + constraints on both `loss_fn` and `test_loss_fn` both depend on + `sample_wise_grads_per_batch`. This means `loss_fn` and + `test_loss_fn` must either both be "per-example" loss functions, + or both be "reduction" loss functions. If not provided, the loss + function for test examples is assumed to be the same as the loss + function for training examples, i.e. `loss_fn`. + Default: None sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient computations w.r.t. model parameters aggregates the results for a batch and does not allow to access sample-wise gradients w.r.t. @@ -590,51 +626,23 @@ def __init__( checkpoints_load_func, loss_fn, batch_size, + test_loss_fn, ) self.sample_wise_grads_per_batch = sample_wise_grads_per_batch - # If we are able to access the reduction used by `loss_fn`, we check whether - # the reduction is compatible with `sample_wise_grads_per_batch` - if isinstance(loss_fn, Module) and hasattr( - loss_fn, "reduction" - ): # TODO: allow loss_fn to be Callable - if self.sample_wise_grads_per_batch: - assert loss_fn.reduction in ["sum", "mean"], ( - 'reduction for `loss_fn` must be "sum" or "mean" when ' - "`sample_wise_grads_per_batch` is True" - ) - self.reduction_type = str(loss_fn.reduction) - else: - assert loss_fn.reduction == "none", ( - 'reduction for `loss_fn` must be "none" when ' - "`sample_wise_grads_per_batch` is False" - ) - else: - # if we are unable to access the reduction used by `loss_fn`, we warn - # the user about the assumptions we are making regarding the reduction - # used by `loss_fn` - if self.sample_wise_grads_per_batch: - warnings.warn( - 'Since `loss_fn` has no "reduction" attribute, and ' - "`sample_wise_grads_per_batch` is True, the implementation assumes " - 'that `loss_fn` is a "reduction" loss function that reduces the ' - "per-example losses by taking their *sum*. If `loss_fn` " - "instead reduces the per-example losses by taking their mean, " - 'please set the reduction attribute of `loss_fn` to "mean", i.e. ' - '`loss_fn.reduction = "mean"`. Note that if ' - "`sample_wise_grads_per_batch` is True, the implementation " - "assumes the reduction is either a sum or mean reduction." - ) - self.reduction_type = "sum" - else: - warnings.warn( - 'Since `loss_fn` has no "reduction" attribute, and ' - "`sample_wise_grads_per_batch` is False, the implementation " - 'assumes that `loss_fn` is a "per-example" loss function (see ' - "documentation for `loss_fn` for details). Please ensure that " - "this is the case." - ) + # check `loss_fn` + self.reduction_type = _check_loss_fn( + self, loss_fn, "loss_fn", sample_wise_grads_per_batch + ) + # check `test_loss_fn` if it was provided + self.test_reduction_type = ( + self.reduction_type + if test_loss_fn is None + else _check_loss_fn( + self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch + ) + ) r""" TODO: Either restore model state after done (would have to place functionality @@ -769,6 +777,162 @@ def influence( # type: ignore[override] show_progress, ) + def _sum_jacobians( + self, + inputs_dataset: DataLoader, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, + ): + """ + sums the jacobians of all examples in `inputs_dataset`. result is of the + same format as layer_jacobians, but the batch dimension has size 1 + """ + inputs_dataset_iter = iter(inputs_dataset) + + inputs_batch = next(inputs_dataset_iter) + + def get_batch_contribution(inputs_batch): + _input_jacobians = self._basic_computation_tracincp( + inputs_batch[0:-1], + inputs_batch[-1], + loss_fn, + reduction_type, + ) + + return tuple( + torch.sum(jacobian, dim=0).unsqueeze(0) for jacobian in _input_jacobians + ) + + inputs_jacobians = get_batch_contribution(inputs_batch) + + for inputs_batch in inputs_dataset_iter: + inputs_batch_jacobians = get_batch_contribution(inputs_batch) + inputs_jacobians = tuple( + [ + inputs_jacobian + inputs_batch_jacobian + for (inputs_jacobian, inputs_batch_jacobian) in zip( + inputs_jacobians, inputs_batch_jacobians + ) + ] + ) + + return inputs_jacobians + + def _concat_jacobians( + self, + inputs_dataset: DataLoader, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, + ): + all_inputs_batch_jacobians = [ + self._basic_computation_tracincp( + inputs_batch[0:-1], + inputs_batch[-1], + loss_fn, + reduction_type, + ) + for inputs_batch in inputs_dataset + ] + + return tuple( + torch.cat(all_inputs_batch_jacobian, dim=0) + for all_inputs_batch_jacobian in zip(*all_inputs_batch_jacobians) + ) + + @log_usage() + def compute_intermediate_quantities( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + aggregate: bool = False, + ) -> Tensor: + """ + Computes "embedding" vectors for all examples in a single batch, or a + `Dataloader` that yields batches. These embedding vectors are constructed so + that the influence score of a training example on a test example is simply the + dot-product of their corresponding vectors. Allowing a `DataLoader` + yielding batches to be passed in (as opposed to a single batch) gives the + potential to improve efficiency, because we load each checkpoint only once in + this method call. Thus if a `DataLoader` yielding batches is passed in, this + reduces the total number of times each checkpoint is loaded for a dataset, + compared to if a single batch is passed in. The reason we do not just increase + the batch size is that for large models, large batches do not fit in memory. + + If `aggregate` is True, the *sum* of the vectors for all examples is returned, + instead of the vectors for each example. This can be useful for computing the + influence of a given training example on the total loss over a validation + dataset, because due to properties of the dot-product, this influence is the + dot-product of the training example's vector with the sum of the vectors in the + validation dataset. Also, by doing the sum aggregation within this method as + opposed to outside of it (by computing all vectors for the validation dataset, + then taking the sum) allows memory usage to be reduced. + + Args: + inputs_dataset (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, and + and `batch[-1]` are the labels, if any. Here, `model` is model + provided in initialization. This is the same assumption made for + each batch yielded by training dataset `train_dataset`. + aggregate (bool): Whether to return the sum of the vectors for all + examples, as opposed to vectors for each example. + + Returns: + intermediate_quantities (Tensor): A tensor of dimension + (N, D * C). Here, N is the total number of examples in + `inputs_dataset` if `aggregate` is False, and 1, otherwise (so that + a 2D tensor is always returned). C is the number of checkpoints + passed as the `checkpoints` argument of `TracInCP.__init__`, and + each row represents the vector for an example. Regarding D: Let I + be the dimension of the output of the last fully-connected layer + times the dimension of the input of the last fully-connected layer. + If `self.projection_dim` is specified in initialization, + D = min(I * C, `self.projection_dim` * C). Otherwise, D = I * C. + In summary, if `self.projection_dim` is None, the dimension of each + vector will be determined by the size of the input and output of + the last fully-connected layer of `model`. Otherwise, + `self.projection_dim` must be an int, and random projection will be + performed to ensure that the vector is of dimension no more than + `self.projection_dim` * C. `self.projection_dim` corresponds to + the variable d in the top of page 15 of the TracIn paper: + https://arxiv.org/pdf/2002.08484.pdf. + """ + # If `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) + + def get_checkpoint_contribution(checkpoint): + assert ( + checkpoint is not None + ), "None returned from `checkpoints`, cannot load." + + learning_rate = self.checkpoints_load_func(self.model, checkpoint) + # get jacobians as tuple of tensors + if aggregate: + inputs_jacobians = self._sum_jacobians( + inputs_dataset, self.loss_fn, self.reduction_type + ) + else: + inputs_jacobians = self._concat_jacobians( + inputs_dataset, self.loss_fn, self.reduction_type + ) + # flatten into single tensor + return learning_rate * torch.cat( + [ + input_jacobian.flatten(start_dim=1) + for input_jacobian in inputs_jacobians + ], + dim=1, + ) + + return torch.cat( + [ + get_checkpoint_contribution(checkpoint) + for checkpoint in self.checkpoints + ], + dim=1, + ) + def _influence_batch_tracincp( self, inputs: Tuple[Any, ...], @@ -790,11 +954,15 @@ def get_checkpoint_contribution(checkpoint): input_jacobians = self._basic_computation_tracincp( inputs, targets, + self.test_loss_fn, + self.test_reduction_type, ) return ( _gradient_dot_product( input_jacobians, - self._basic_computation_tracincp(batch[0:-1], batch[-1]), + self._basic_computation_tracincp( + batch[0:-1], batch[-1], self.loss_fn, self.reduction_type + ), ) * learning_rate ) @@ -1055,7 +1223,10 @@ def get_checkpoint_contribution(checkpoint): for batch in _inputs_dataset: layer_jacobians = self._basic_computation_tracincp( - batch[0:-1], batch[-1] + batch[0:-1], + batch[-1], + self.loss_fn, + self.reduction_type, ) # Note that all variables in this function are for an entire batch. @@ -1115,6 +1286,7 @@ def get_checkpoint_contribution(checkpoint): return batches_self_tracin_scores + @log_usage() def self_influence( self, inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, @@ -1196,11 +1368,14 @@ def _basic_computation_tracincp( self, inputs: Tuple[Any, ...], targets: Optional[Tensor] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, ) -> Tuple[Tensor, ...]: """ For instances of TracInCP, computation of influence scores or self influence scores repeatedly calls this function for different checkpoints - and batches. + and batches. In particular, this function computes the jacobian of a loss + function w.r.t. parameters in the `layers` initialization argument. Args: @@ -1210,20 +1385,26 @@ def _basic_computation_tracincp( that `model(*inputs)` produces the predictions for the batch. targets (tensor or None): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. + Default: none + loss_fn (Callable, optional): The loss function to use when computing the + jacobian. + reduction_type (str, optional): The reduction type of `loss_fn`. This + argument is only used if `sample_wise_grads_per_batch` was true in + initialization. """ if self.sample_wise_grads_per_batch: return _compute_jacobian_wrt_params_with_sample_wise_trick( self.model, inputs, targets, - self.loss_fn, - self.reduction_type, + loss_fn, + reduction_type, self.layer_modules, ) return _compute_jacobian_wrt_params( self.model, inputs, targets, - self.loss_fn, + loss_fn, self.layer_modules, ) diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index d0c8eeee7d..dec58914f3 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -16,6 +16,7 @@ TracInCPBase, ) from captum.influence._utils.common import ( + _check_loss_fn, _DatasetFromList, _format_inputs_dataset, _get_k_most_influential_helper, @@ -88,6 +89,7 @@ def __init__( checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, vectorize: bool = False, ) -> None: r""" @@ -153,6 +155,20 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. If not provided, the + loss function for test examples is assumed to be the same as the + loss function for training examples, i.e. `loss_fn`. + Default: None vectorize (bool, optional): Flag to use experimental vectorize functionality for `torch.autograd.functional.jacobian`. Default: False @@ -165,6 +181,7 @@ def __init__( checkpoints_load_func, loss_fn, batch_size, + test_loss_fn, ) self.vectorize = vectorize @@ -179,29 +196,14 @@ def __init__( assert loss_fn is not None, "loss function must not be none" - # If we are able to access the reduction used by `loss_fn`, we check whether - # the reduction is either 'sum' or 'mean', as required - if isinstance(loss_fn, Module) and hasattr( - loss_fn, "reduction" - ): # TODO: allow loss_fn to be Callable - assert loss_fn.reduction in [ - "sum", - "mean", - ], 'reduction for `loss_fn` must be "sum" or "mean"' - self.reduction_type = str(loss_fn.reduction) - else: - # if we are unable to access the reduction used by `loss_fn`, we warn - # the user about the assumptions we are making regarding the reduction - # used by `loss_fn` - warnings.warn( - 'Since `loss_fn` has no "reduction" attribute, the implementation ' - 'assumes that `loss_fn` is a "reduction" loss function that ' - "reduces the per-example losses by taking their *sum*. If " - "`loss_fn` instead reduces the per-example losses by taking their " - 'mean, please set the reduction attribute of `loss_fn` to "mean", ' - 'i.e. `loss_fn.reduction = "mean"`.' - ) - self.reduction_type = "sum" + # check `loss_fn` + self.reduction_type = _check_loss_fn(self, loss_fn, "loss_fn") + # check `test_loss_fn` if it was provided + self.test_reduction_type = ( + self.reduction_type + if test_loss_fn is None + else _check_loss_fn(self, test_loss_fn, "test_loss_fn") + ) @log_usage() def influence( # type: ignore[override] @@ -340,10 +342,16 @@ def get_checkpoint_contribution(checkpoint): self, inputs, targets, + self.test_loss_fn, + self.test_reduction_type, ) src_jacobian, src_layer_input = _basic_computation_tracincp_fast( - self, batch[0:-1], batch[-1] + self, + batch[0:-1], + batch[-1], + self.loss_fn, + self.reduction_type, ) return ( _tensor_batch_dot( @@ -603,7 +611,11 @@ def get_checkpoint_contribution(checkpoint): for batch in _inputs_dataset: batch_jacobian, batch_layer_input = _basic_computation_tracincp_fast( - self, batch[0:-1], batch[-1] + self, + batch[0:-1], + batch[-1], + self.loss_fn, + self.reduction_type, ) checkpoint_contribution.append( @@ -640,6 +652,7 @@ def get_checkpoint_contribution(checkpoint): checkpoints_progress.update() return batches_self_tracin_scores + @log_usage() def self_influence( self, inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, @@ -722,11 +735,18 @@ def _basic_computation_tracincp_fast( influence_instance: TracInCPFast, inputs: Tuple[Any, ...], targets: Tensor, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, ): """ For instances of TracInCPFast and children classes, computation of influence scores or self influence scores repeatedly calls this function for different checkpoints - and batches. + and batches. These computations involve a loss function. If `test` is True, the + loss function is `self.loss_fn`. If `test` is False, the loss function is + `self.test_loss_fn`. These two attributes were set in initialization, with + `self.loss_fn` equal to the `loss_fn` initialization argument, and + `self.test_loss_fn` equal to the `test_loss_fn` initialization argument if it was + provided, and `loss_fn` otherwise. Args: @@ -742,6 +762,11 @@ def _basic_computation_tracincp_fast( that `model(*inputs)` produces the predictions for the batch. targets (Tensor): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. + loss_fn (Callable, optional): The loss function to use when computing the + jacobian. + reduction_type (str, optional): The reduction type of `loss_fn`. This argument + is only used if `sample_wise_grads_per_batch` was true in + initialization of `influence_instance`. Returns: (input_jacobians, layer_inputs) (tuple): `input_jacobians` is a 2D tensor, @@ -773,17 +798,17 @@ def _capture_inputs(layer, input, output) -> None: ) out = influence_instance.model(*inputs) - assert influence_instance.loss_fn is not None, "loss function is required" - assert influence_instance.reduction_type in [ + assert loss_fn is not None, "loss function is required" + assert reduction_type in [ "sum", "mean", ], 'reduction_type must be either "mean" or "sum"' input_jacobians = _jacobian_loss_wrt_inputs( - influence_instance.loss_fn, + loss_fn, out, targets, influence_instance.vectorize, - influence_instance.reduction_type, + reduction_type, ) handle.remove() @@ -863,6 +888,7 @@ def __init__( checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, vectorize: bool = False, nearest_neighbors: Optional[NearestNeighbors] = None, projection_dim: int = None, @@ -927,6 +953,19 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. If not provided, the + loss function for test examples is assumed to be the same as the + loss function for training examples, i.e. `loss_fn`. vectorize (bool): Flag to use experimental vectorize functionality for `torch.autograd.functional.jacobian`. Default: False @@ -970,6 +1009,7 @@ def __init__( checkpoints_load_func, loss_fn, batch_size, + test_loss_fn, vectorize, ) @@ -1038,6 +1078,7 @@ def _influence( # type: ignore[override] _DatasetFromList([inputs_batch]), shuffle=False, batch_size=None ), self.projection_quantities, + test=True, ) src_projections = self.src_intermediate_quantities @@ -1088,6 +1129,7 @@ def _get_k_most_influential( # type: ignore[override] _DatasetFromList([inputs_batch]), shuffle=False, batch_size=None ), self.projection_quantities, + test=True, ) multiplier = 1 if proponents else -1 @@ -1101,6 +1143,7 @@ def _get_k_most_influential( # type: ignore[override] return KMostInfluentialResults(indices, distances) + @log_usage() def self_influence( self, inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, @@ -1326,6 +1369,8 @@ def _set_projections_tracincp_fast_rand_proj( self, batch[0:-1], batch[-1], + self.loss_fn, + self.reduction_type, ) jacobian_dim = batch_jacobians.shape[ @@ -1398,6 +1443,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( self, inputs_dataset: Union[Tuple[Any, ...], DataLoader], projection_quantities: Optional[Tuple[torch.Tensor, torch.Tensor]], + test: bool = False, ) -> torch.Tensor: r""" This method computes vectors that can be used to compute influence. (see @@ -1422,6 +1468,10 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( projection_quantities (tuple or None): Is either the two tensors defining the randomized projections to apply, or None, which means no projection is to be applied. + test (bool): If True, the intermediate quantities are computed using + `self.test_loss_fn`. Otherwise, they are computed using + `self.loss_fn`. + Default: False Returns: intermediate_quantities (Tensor): A tensor of dimension @@ -1490,6 +1540,8 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( self, batch[0:-1], batch[-1], + self.test_loss_fn, + self.test_reduction_type, ) # if doing projection, project those two quantities @@ -1539,6 +1591,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( # each row in this result is the "embedding" vector for an example in `batch` return torch.cat(checkpoint_contributions, dim=1) # type: ignore + @log_usage() def compute_intermediate_quantities( self, inputs_dataset: Union[Tuple[Any, ...], DataLoader], diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index b43e3aa553..e1a6e27f8f 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -1,13 +1,15 @@ #!/usr/bin/env python3 - import warnings -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union import torch import torch.nn as nn from captum._utils.common import _parse_version from captum._utils.progress import progress +if TYPE_CHECKING: + from captum.influence._core.tracincp import TracInCPBase + from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset @@ -419,3 +421,89 @@ def _self_influence_by_batches_helper( for batch in inputs_dataset ] ) + + +def _check_loss_fn( + influence_instance: "TracInCPBase", + loss_fn: Optional[Union[Module, Callable]], + loss_fn_name: str, + sample_wise_grads_per_batch: Optional[bool] = None, +) -> str: + """ + This checks whether `loss_fn` satisfies the requirements assumed of all + implementations of `TracInCPBase`. It works regardless of whether the + implementation has the `sample_wise_grads_per_batch` attribute. + It returns the reduction type of the loss_fn. If `sample_wise_grads_per_batch` + if not provided, we assume the implementation does not have that attribute. + """ + # if `loss_fn` is `None`, there is nothing to check. then, the reduction type is + # only used by `_compute_jacobian_wrt_params_with_sample_wise_trick`, where + # reduction type should be "sum" if `loss_fn` is `None`. + if loss_fn is None: + return "sum" + + # perhaps since `Module` is an implementation of `Callable`, this has redundancy + assert isinstance(loss_fn, Module) or callable(loss_fn) + + reduction_type = "none" + + # If we are able to access the reduction used by `loss_fn`, we check whether + # the reduction is compatible with `sample_wise_grads_per_batch`, if it has the + # attribute. + if hasattr(loss_fn, "reduction"): + reduction = loss_fn.reduction # type: ignore + if sample_wise_grads_per_batch is None: + assert reduction in [ + "sum", + "mean", + ], 'reduction for `loss_fn` must be "sum" or "mean"' + reduction_type = str(reduction) + elif sample_wise_grads_per_batch: + assert reduction in ["sum", "mean"], ( + 'reduction for `loss_fn` must be "sum" or "mean" when ' + "`sample_wise_grads_per_batch` is True" + ) + reduction_type = str(reduction) + else: + assert reduction == "none", ( + 'reduction for `loss_fn` must be "none" when ' + "`sample_wise_grads_per_batch` is False" + ) + else: + # if we are unable to access the reduction used by `loss_fn`, we warn + # the user about the assumptions we are making regarding the reduction + # used by `loss_fn` + if sample_wise_grads_per_batch is None: + warnings.warn( + f'Since `{loss_fn_name}` has no "reduction" attribute, the ' + f'implementation assumes that `{loss_fn_name}` is a "reduction" loss ' + "function that reduces the per-example losses by taking their *sum*. " + f"If `{loss_fn_name}` instead reduces the per-example losses by " + f"taking their mean, please set the reduction attribute of " + f'`{loss_fn_name}` to "mean", i.e. ' + f'`{loss_fn_name}.reduction = "mean"`.' + ) + reduction_type = "sum" + elif sample_wise_grads_per_batch: + warnings.warn( + f"Since `{loss_fn_name}`` has no 'reduction' attribute, and " + "`sample_wise_grads_per_batch` is True, the implementation assumes " + f"that `{loss_fn_name}` is a 'reduction' loss function that reduces " + f"the per-example losses by taking their *sum*. If `{loss_fn_name}` " + "instead reduces the per-example losses by taking their mean, " + f'please set the reduction attribute of `{loss_fn_name}` to "mean", ' + f'i.e. `{loss_fn_name}.reduction = "mean"`. Note that if ' + "`sample_wise_grads_per_batch` is True, the implementation " + "assumes the reduction is either a sum or mean reduction." + ) + reduction_type = "sum" + else: + warnings.warn( + f'Since `{loss_fn_name}` has no "reduction" attribute, and ' + "`sample_wise_grads_per_batch` is False, the implementation " + f'assumes that `{loss_fn_name}` is a "per-example" loss function (see ' + f"documentation for `{loss_fn_name}` for details). Please ensure " + "that this is the case." + ) + + return reduction_type diff --git a/tests/influence/_core/test_tracin_intermediate_quantities.py b/tests/influence/_core/test_tracin_intermediate_quantities.py index 9f0daebad3..5d1dde3ab3 100644 --- a/tests/influence/_core/test_tracin_intermediate_quantities.py +++ b/tests/influence/_core/test_tracin_intermediate_quantities.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import ( TracInCPFast, TracInCPFastRandProj, @@ -19,12 +20,68 @@ class TestTracInIntermediateQuantities(BaseTest): + @parameterized.expand( + [ + (reduction, constructor, unpack_inputs) + for unpack_inputs in [True, False] + for (reduction, constructor) in [ + ("none", DataInfluenceConstructor(TracInCP)), + ] + ], + name_func=build_test_name_func(), + ) + def test_tracin_intermediate_quantities_aggregate( + self, reduction: str, tracin_constructor: Callable, unpack_inputs: bool + ) -> None: + """ + tests that calling `compute_intermediate_quantities` with `aggregate=True` + does give the same result as calling it with `aggregate=False`, and then + summing + """ + with tempfile.TemporaryDirectory() as tmpdir: + (net, train_dataset,) = get_random_model_and_data( + tmpdir, + unpack_inputs, + return_test_data=False, + ) + + # create a dataloader that yields batches from the dataset + train_dataset = DataLoader(train_dataset, batch_size=5) + + # create tracin instance + criterion = nn.MSELoss(reduction=reduction) + batch_size = 5 + + tracin = tracin_constructor( + net, + train_dataset, + tmpdir, + batch_size, + criterion, + ) + + intermediate_quantities = tracin.compute_intermediate_quantities( + train_dataset, aggregate=False + ) + aggregated_intermediate_quantities = tracin.compute_intermediate_quantities( + train_dataset, aggregate=True + ) + + assertTensorAlmostEqual( + self, + torch.sum(intermediate_quantities, dim=0, keepdim=True), + aggregated_intermediate_quantities, + delta=1e-4, # due to numerical issues, we can't set this to 0.0 + mode="max", + ) + @parameterized.expand( [ (reduction, constructor, unpack_inputs) for unpack_inputs in [True, False] for (reduction, constructor) in [ ("sum", DataInfluenceConstructor(TracInCPFastRandProj)), + ("none", DataInfluenceConstructor(TracInCP)), ] ], name_func=build_test_name_func(), @@ -103,6 +160,11 @@ def test_tracin_intermediate_quantities_api( DataInfluenceConstructor(TracInCPFast), DataInfluenceConstructor(TracInCPFastRandProj), ), + ( + "none", + DataInfluenceConstructor(TracInCP), + DataInfluenceConstructor(TracInCP), + ), ] ], name_func=build_test_name_func(), diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py index 21de267332..147566bc22 100644 --- a/tests/influence/_core/test_tracin_regression.py +++ b/tests/influence/_core/test_tracin_regression.py @@ -110,6 +110,24 @@ def _test_tracin_regression_setup( projection_dim=1, ), ), + ( + "check_idx", + "mean", + DataInfluenceConstructor( + TracInCPFast, + name="TracInCPFastDuplicateLossFn", + duplicate_loss_fn=True, + ), + ), # add a test where `duplicate_loss_fn` is True + ( + "check_idx", + "mean", + DataInfluenceConstructor( + TracInCPFastRandProj, + name="TracInCPFastRandProjDuplicateLossFn", + duplicate_loss_fn=True, + ), # add a test where `duplicate_loss_fn` is True + ), ]: if not (mode == "sample_wise_trick" and use_gpu): param_list.append((reduction, constructor, mode, dim, use_gpu)) @@ -404,3 +422,79 @@ def test_tracin_identity_regression( assertTensorAlmostEqual( self, train_scores, train_scores_tracin_sample_wise_trick ) + + @parameterized.expand( + [ + ("none", "none", DataInfluenceConstructor(TracInCP)), + ( + "mean", + "mean", + DataInfluenceConstructor(TracInCP, sample_wise_grads_per_batch=True), + ), + ("sum", "sum", DataInfluenceConstructor(TracInCPFast)), + ("mean", "mean", DataInfluenceConstructor(TracInCPFast)), + ("sum", "sum", DataInfluenceConstructor(TracInCPFastRandProj)), + ("mean", "mean", DataInfluenceConstructor(TracInCPFastRandProj)), + ], + name_func=build_test_name_func(), + ) + def test_tracin_constant_test_loss_fn( + self, + reduction: Optional[str], + test_reduction: Optional[str], + tracin_constructor: Callable, + ) -> None: + """ + All implementations of `TracInCPBase` can accept `test_loss_fn` in + initialization, which sets the loss function applied to test examples, which + can thus be different from the loss function applied to training examples. + This test passes `test_loss_fn` to be a constant function. Then, the influence + scores should all be 0, because gradients w.r.t. `test_loss_fn` will all be 0. + It re-uses the dataset and model from `test_tracin_identity_regression`. + + The reduction for `loss_fn` and `test_loss_fn` initialization arguments is + the same for all parameterized tests, for simplicity, and also because for + `TracInCP`, both loss functions must both be reduction loss functions (i.e. + reduction is "mean" or "sum"), or both be per-example loss functions (i.e. + reduction is "none"). Recall that for `TracInCP`, the + `sample_wise_grads_per_batch` initialization argument determines which of + those cases holds. + """ + with tempfile.TemporaryDirectory() as tmpdir: + + batch_size = 4 + + dataset, net = self._test_tracin_identity_regression_setup(tmpdir) + + train_inputs = dataset.samples + train_labels = dataset.labels + + self.assertTrue(callable(tracin_constructor)) + + self.assertTrue(isinstance(reduction, str)) + criterion = nn.MSELoss(reduction=cast(str, reduction)) + + # the output of `net`, i.e. `input` for the loss functions below, is a + # batch_size x 1 2D tensor + if test_reduction == "none": + # loss function returns 1D tensor of all 0's, so is constant + def test_loss_fn(input, target): + return input.squeeze() * 0.0 + + elif test_reduction in ["sum", "mean"]: + # loss function returns scalar tensor of all 0's, so is constant + def test_loss_fn(input, target): + return input.mean() * 0.0 + + tracin = tracin_constructor( + net, + dataset, + tmpdir, + batch_size, + criterion, + test_loss_fn=test_loss_fn, + ) + + # check influence scores of training data. they should all be 0 + train_scores = tracin.influence(train_inputs, train_labels, k=None) + assertTensorAlmostEqual(self, train_scores, torch.zeros(train_scores.shape)) diff --git a/tests/influence/_utils/common.py b/tests/influence/_utils/common.py index 999dc6404f..dbfc0de550 100644 --- a/tests/influence/_utils/common.py +++ b/tests/influence/_utils/common.py @@ -248,10 +248,19 @@ class DataInfluenceConstructor: data_influence_class: type def __init__( - self, data_influence_class: type, name: Optional[str] = None, **kwargs + self, + data_influence_class: type, + name: Optional[str] = None, + duplicate_loss_fn: bool = False, + **kwargs, ) -> None: + """ + if `duplicate_loss_fn` is True, will explicitly pass the provided `loss_fn` as + the `test_loss_fn` when constructing the TracInCPBase instance + """ self.data_influence_class = data_influence_class self.name = name if name else data_influence_class.__name__ + self.duplicate_loss_fn = duplicate_loss_fn self.kwargs = kwargs def __repr__(self) -> str: @@ -266,8 +275,14 @@ def __call__( loss_fn: Optional[Union[Module, Callable]], **kwargs, ) -> DataInfluence: - constuctor_kwargs = self.kwargs.copy() - constuctor_kwargs.update(kwargs) + constructor_kwargs = self.kwargs.copy() + constructor_kwargs.update(kwargs) + # if `self.duplicate_loss_fn`, explicitly pass in `loss_fn` as `test_loss_fn` + # when constructing the instance. Doing so should not affect the behavior of + # the returned tracincp instance, since if `test_loss_fn` is not passed in, + # the constructor sets `test_loss_fn` to be the same as `loss_fn` + if self.duplicate_loss_fn: + constructor_kwargs["test_loss_fn"] = loss_fn if self.data_influence_class is TracInCPFastRandProj: self.check_annoy() if self.data_influence_class in [TracInCPFast, TracInCPFastRandProj]: @@ -278,7 +293,7 @@ def __call__( tmpdir, loss_fn=loss_fn, batch_size=batch_size, - **constuctor_kwargs, + **constructor_kwargs, ) else: return self.data_influence_class( @@ -287,7 +302,7 @@ def __call__( tmpdir, batch_size=batch_size, loss_fn=loss_fn, - **constuctor_kwargs, + **constructor_kwargs, ) def check_annoy(self) -> None: