From eb50036e626f112dd58a046119228f85e78edbb9 Mon Sep 17 00:00:00 2001 From: Fulton Wang Date: Fri, 9 Dec 2022 13:16:02 -0800 Subject: [PATCH] add test loss (#1073) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1073 - For all `TracInCPBase` implementations, this adds an additional `test_loss_fn` initialization argument, which is the loss function to apply to test examples when computing the influence of a training example on a test example. With this change,the influence score is a sum over terms for each checkpoint, where each term is the gradient of `loss_fn` for a given training example, multiplied with the gradient of `test_loss_fn` for a given test example. Before, `test_loss_fn` was assumed to be the same as `loss_fn`. - checks regarding the reduction type of both `loss_fn` and `test_loss_fn` are now handled by helper functions `_check_tracincp_loss_fn` and `_check_tracincp_fast_loss_fn`. - documentation is updated. one detail: for `TracInCP`, we assume that `sample_wise_grads_per_batch` is applied to both `loss_fn` and `test_loss_fn` (if provided), and this is mentioned in the documentation. - `test_tracin_regression.test_tracin_regression` is slightly modified - `DataInfluenceConstructor` now can explicitly pass in the same loss function for both `loss_fn` and `test_loss_fn` (done when `duplicate_loss_fn=True`). Doing so would have the same effect as not passing in `test_loss_fn`, so the original tests are also applied to the case when `duplicate_loss_fn=True`, as the expected behavior should be the same as before. - a new test, `test_tracin_regression.test_tracin_constant_test_loss_fn` is added. For all implementations of `TracInCPBase`, it checks that if `test_loss_fn` is a constant loss function, the influence scores are all 0's. This should be the case, because if `test_loss_fn` is constant, its gradients would all be 0's, so that training examples have 0 influence on test examples. Reviewed By: cyrjano Differential Revision: D41202866 fbshipit-source-id: 4258e3597b1f2e30ba5059bb7d440c8de7fd3ac1 --- captum/_utils/gradient.py | 11 +- captum/influence/_core/tracincp.py | 118 +++++++++++------- .../_core/tracincp_fast_rand_proj.py | 110 +++++++++++----- captum/influence/_utils/common.py | 92 +++++++++++++- .../influence/_core/test_tracin_regression.py | 94 ++++++++++++++ tests/influence/_utils/common.py | 25 +++- 6 files changed, 362 insertions(+), 88 deletions(-) diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index 4d885ff749..b6caea77cb 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -849,18 +849,21 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick( if labels is not None and loss_fn is not None: loss = loss_fn(out, labels) # TODO: allow loss_fn to be Callable - if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"): + if (isinstance(loss_fn, Module) or callable(loss_fn)) and hasattr( + loss_fn, "reduction" + ): + reduction = loss_fn.reduction # type: ignore msg0 = ( "Please ensure that loss_fn.reduction is set to `sum` or `mean`" ) - assert loss_fn.reduction != "none", msg0 + assert reduction != "none", msg0 msg1 = ( - f"loss_fn.reduction ({loss_fn.reduction}) does not match" + f"loss_fn.reduction ({reduction}) does not match" f"reduction type ({reduction_type}). Please ensure they are" " matching." ) - assert loss_fn.reduction == reduction_type, msg1 + assert reduction == reduction_type, msg1 msg2 = ( "Please ensure custom loss function is applying either a " "sum or mean reduction." diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index 453e99410a..653aaa6e18 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -26,6 +26,7 @@ from captum._utils.progress import NullProgress, progress from captum.influence._core.influence import DataInfluence from captum.influence._utils.common import ( + _check_loss_fn, _format_inputs_dataset, _get_k_most_influential_helper, _gradient_dot_product, @@ -102,6 +103,7 @@ def __init__( checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, ) -> None: r""" Args: @@ -152,6 +154,19 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs to satisfy the same constraints as `loss_fn`. + If not provided, the loss function for test examples is assumed to + be the same as the loss function for training examples, i.e. + `loss_fn`. + Default: None """ self.model = model @@ -167,6 +182,8 @@ def __init__( self.checkpoints_load_func = checkpoints_load_func self.loss_fn = loss_fn + # If test_loss_fn not provided, it's assumed to be same as loss_fn + self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn self.batch_size = batch_size if not isinstance(train_dataset, DataLoader): @@ -489,6 +506,7 @@ def __init__( layers: Optional[List[str]] = None, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, sample_wise_grads_per_batch: bool = False, ) -> None: r""" @@ -561,6 +579,24 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. Note that the + constraints on both `loss_fn` and `test_loss_fn` both depend on + `sample_wise_grads_per_batch`. This means `loss_fn` and + `test_loss_fn` must either both be "per-example" loss functions, + or both be "reduction" loss functions. If not provided, the loss + function for test examples is assumed to be the same as the loss + function for training examples, i.e. `loss_fn`. + Default: None sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient computations w.r.t. model parameters aggregates the results for a batch and does not allow to access sample-wise gradients w.r.t. @@ -590,51 +626,23 @@ def __init__( checkpoints_load_func, loss_fn, batch_size, + test_loss_fn, ) self.sample_wise_grads_per_batch = sample_wise_grads_per_batch - # If we are able to access the reduction used by `loss_fn`, we check whether - # the reduction is compatible with `sample_wise_grads_per_batch` - if isinstance(loss_fn, Module) and hasattr( - loss_fn, "reduction" - ): # TODO: allow loss_fn to be Callable - if self.sample_wise_grads_per_batch: - assert loss_fn.reduction in ["sum", "mean"], ( - 'reduction for `loss_fn` must be "sum" or "mean" when ' - "`sample_wise_grads_per_batch` is True" - ) - self.reduction_type = str(loss_fn.reduction) - else: - assert loss_fn.reduction == "none", ( - 'reduction for `loss_fn` must be "none" when ' - "`sample_wise_grads_per_batch` is False" - ) - else: - # if we are unable to access the reduction used by `loss_fn`, we warn - # the user about the assumptions we are making regarding the reduction - # used by `loss_fn` - if self.sample_wise_grads_per_batch: - warnings.warn( - 'Since `loss_fn` has no "reduction" attribute, and ' - "`sample_wise_grads_per_batch` is True, the implementation assumes " - 'that `loss_fn` is a "reduction" loss function that reduces the ' - "per-example losses by taking their *sum*. If `loss_fn` " - "instead reduces the per-example losses by taking their mean, " - 'please set the reduction attribute of `loss_fn` to "mean", i.e. ' - '`loss_fn.reduction = "mean"`. Note that if ' - "`sample_wise_grads_per_batch` is True, the implementation " - "assumes the reduction is either a sum or mean reduction." - ) - self.reduction_type = "sum" - else: - warnings.warn( - 'Since `loss_fn` has no "reduction" attribute, and ' - "`sample_wise_grads_per_batch` is False, the implementation " - 'assumes that `loss_fn` is a "per-example" loss function (see ' - "documentation for `loss_fn` for details). Please ensure that " - "this is the case." - ) + # check `loss_fn` + self.reduction_type = _check_loss_fn( + self, loss_fn, "loss_fn", sample_wise_grads_per_batch + ) + # check `test_loss_fn` if it was provided + self.test_reduction_type = ( + self.reduction_type + if test_loss_fn is None + else _check_loss_fn( + self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch + ) + ) r""" TODO: Either restore model state after done (would have to place functionality @@ -790,11 +798,15 @@ def get_checkpoint_contribution(checkpoint): input_jacobians = self._basic_computation_tracincp( inputs, targets, + self.test_loss_fn, + self.test_reduction_type, ) return ( _gradient_dot_product( input_jacobians, - self._basic_computation_tracincp(batch[0:-1], batch[-1]), + self._basic_computation_tracincp( + batch[0:-1], batch[-1], self.loss_fn, self.reduction_type + ), ) * learning_rate ) @@ -1055,7 +1067,10 @@ def get_checkpoint_contribution(checkpoint): for batch in _inputs_dataset: layer_jacobians = self._basic_computation_tracincp( - batch[0:-1], batch[-1] + batch[0:-1], + batch[-1], + self.loss_fn, + self.reduction_type, ) # Note that all variables in this function are for an entire batch. @@ -1196,11 +1211,14 @@ def _basic_computation_tracincp( self, inputs: Tuple[Any, ...], targets: Optional[Tensor] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, ) -> Tuple[Tensor, ...]: """ For instances of TracInCP, computation of influence scores or self influence scores repeatedly calls this function for different checkpoints - and batches. + and batches. In particular, this function computes the jacobian of a loss + function w.r.t. parameters in the `layers` initialization argument. Args: @@ -1210,20 +1228,26 @@ def _basic_computation_tracincp( that `model(*inputs)` produces the predictions for the batch. targets (tensor or None): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. + Default: none + loss_fn (Callable, optional): The loss function to use when computing the + jacobian. + reduction_type (str, optional): The reduction type of `loss_fn`. This + argument is only used if `sample_wise_grads_per_batch` was true in + initialization. """ if self.sample_wise_grads_per_batch: return _compute_jacobian_wrt_params_with_sample_wise_trick( self.model, inputs, targets, - self.loss_fn, - self.reduction_type, + loss_fn, + reduction_type, self.layer_modules, ) return _compute_jacobian_wrt_params( self.model, inputs, targets, - self.loss_fn, + loss_fn, self.layer_modules, ) diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index d0c8eeee7d..57b4591a81 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -16,6 +16,7 @@ TracInCPBase, ) from captum.influence._utils.common import ( + _check_loss_fn, _DatasetFromList, _format_inputs_dataset, _get_k_most_influential_helper, @@ -88,6 +89,7 @@ def __init__( checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, vectorize: bool = False, ) -> None: r""" @@ -153,6 +155,20 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. If not provided, the + loss function for test examples is assumed to be the same as the + loss function for training examples, i.e. `loss_fn`. + Default: None vectorize (bool, optional): Flag to use experimental vectorize functionality for `torch.autograd.functional.jacobian`. Default: False @@ -165,6 +181,7 @@ def __init__( checkpoints_load_func, loss_fn, batch_size, + test_loss_fn, ) self.vectorize = vectorize @@ -179,29 +196,14 @@ def __init__( assert loss_fn is not None, "loss function must not be none" - # If we are able to access the reduction used by `loss_fn`, we check whether - # the reduction is either 'sum' or 'mean', as required - if isinstance(loss_fn, Module) and hasattr( - loss_fn, "reduction" - ): # TODO: allow loss_fn to be Callable - assert loss_fn.reduction in [ - "sum", - "mean", - ], 'reduction for `loss_fn` must be "sum" or "mean"' - self.reduction_type = str(loss_fn.reduction) - else: - # if we are unable to access the reduction used by `loss_fn`, we warn - # the user about the assumptions we are making regarding the reduction - # used by `loss_fn` - warnings.warn( - 'Since `loss_fn` has no "reduction" attribute, the implementation ' - 'assumes that `loss_fn` is a "reduction" loss function that ' - "reduces the per-example losses by taking their *sum*. If " - "`loss_fn` instead reduces the per-example losses by taking their " - 'mean, please set the reduction attribute of `loss_fn` to "mean", ' - 'i.e. `loss_fn.reduction = "mean"`.' - ) - self.reduction_type = "sum" + # check `loss_fn` + self.reduction_type = _check_loss_fn(self, loss_fn, "loss_fn") + # check `test_loss_fn` if it was provided + self.test_reduction_type = ( + self.reduction_type + if test_loss_fn is None + else _check_loss_fn(self, test_loss_fn, "test_loss_fn") + ) @log_usage() def influence( # type: ignore[override] @@ -340,10 +342,16 @@ def get_checkpoint_contribution(checkpoint): self, inputs, targets, + self.test_loss_fn, + self.test_reduction_type, ) src_jacobian, src_layer_input = _basic_computation_tracincp_fast( - self, batch[0:-1], batch[-1] + self, + batch[0:-1], + batch[-1], + self.loss_fn, + self.reduction_type, ) return ( _tensor_batch_dot( @@ -603,7 +611,11 @@ def get_checkpoint_contribution(checkpoint): for batch in _inputs_dataset: batch_jacobian, batch_layer_input = _basic_computation_tracincp_fast( - self, batch[0:-1], batch[-1] + self, + batch[0:-1], + batch[-1], + self.loss_fn, + self.reduction_type, ) checkpoint_contribution.append( @@ -722,11 +734,18 @@ def _basic_computation_tracincp_fast( influence_instance: TracInCPFast, inputs: Tuple[Any, ...], targets: Tensor, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, ): """ For instances of TracInCPFast and children classes, computation of influence scores or self influence scores repeatedly calls this function for different checkpoints - and batches. + and batches. These computations involve a loss function. If `test` is True, the + loss function is `self.loss_fn`. If `test` is False, the loss function is + `self.test_loss_fn`. These two attributes were set in initialization, with + `self.loss_fn` equal to the `loss_fn` initialization argument, and + `self.test_loss_fn` equal to the `test_loss_fn` initialization argument if it was + provided, and `loss_fn` otherwise. Args: @@ -742,6 +761,11 @@ def _basic_computation_tracincp_fast( that `model(*inputs)` produces the predictions for the batch. targets (Tensor): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. + loss_fn (Callable, optional): The loss function to use when computing the + jacobian. + reduction_type (str, optional): The reduction type of `loss_fn`. This argument + is only used if `sample_wise_grads_per_batch` was true in + initialization of `influence_instance`. Returns: (input_jacobians, layer_inputs) (tuple): `input_jacobians` is a 2D tensor, @@ -773,17 +797,17 @@ def _capture_inputs(layer, input, output) -> None: ) out = influence_instance.model(*inputs) - assert influence_instance.loss_fn is not None, "loss function is required" - assert influence_instance.reduction_type in [ + assert loss_fn is not None, "loss function is required" + assert reduction_type in [ "sum", "mean", ], 'reduction_type must be either "mean" or "sum"' input_jacobians = _jacobian_loss_wrt_inputs( - influence_instance.loss_fn, + loss_fn, out, targets, influence_instance.vectorize, - influence_instance.reduction_type, + reduction_type, ) handle.remove() @@ -863,6 +887,7 @@ def __init__( checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, batch_size: Union[int, None] = 1, + test_loss_fn: Optional[Union[Module, Callable]] = None, vectorize: bool = False, nearest_neighbors: Optional[NearestNeighbors] = None, projection_dim: int = None, @@ -927,6 +952,19 @@ def __init__( `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. If not provided, the + loss function for test examples is assumed to be the same as the + loss function for training examples, i.e. `loss_fn`. vectorize (bool): Flag to use experimental vectorize functionality for `torch.autograd.functional.jacobian`. Default: False @@ -970,6 +1008,7 @@ def __init__( checkpoints_load_func, loss_fn, batch_size, + test_loss_fn, vectorize, ) @@ -1038,6 +1077,7 @@ def _influence( # type: ignore[override] _DatasetFromList([inputs_batch]), shuffle=False, batch_size=None ), self.projection_quantities, + test=True, ) src_projections = self.src_intermediate_quantities @@ -1088,6 +1128,7 @@ def _get_k_most_influential( # type: ignore[override] _DatasetFromList([inputs_batch]), shuffle=False, batch_size=None ), self.projection_quantities, + test=True, ) multiplier = 1 if proponents else -1 @@ -1326,6 +1367,8 @@ def _set_projections_tracincp_fast_rand_proj( self, batch[0:-1], batch[-1], + self.loss_fn, + self.reduction_type, ) jacobian_dim = batch_jacobians.shape[ @@ -1398,6 +1441,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( self, inputs_dataset: Union[Tuple[Any, ...], DataLoader], projection_quantities: Optional[Tuple[torch.Tensor, torch.Tensor]], + test: bool = False, ) -> torch.Tensor: r""" This method computes vectors that can be used to compute influence. (see @@ -1422,6 +1466,10 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( projection_quantities (tuple or None): Is either the two tensors defining the randomized projections to apply, or None, which means no projection is to be applied. + test (bool): If True, the intermediate quantities are computed using + `self.test_loss_fn`. Otherwise, they are computed using + `self.loss_fn`. + Default: False Returns: intermediate_quantities (Tensor): A tensor of dimension @@ -1490,6 +1538,8 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( self, batch[0:-1], batch[-1], + self.test_loss_fn, + self.test_reduction_type, ) # if doing projection, project those two quantities diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index b43e3aa553..e1a6e27f8f 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -1,13 +1,15 @@ #!/usr/bin/env python3 - import warnings -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union import torch import torch.nn as nn from captum._utils.common import _parse_version from captum._utils.progress import progress +if TYPE_CHECKING: + from captum.influence._core.tracincp import TracInCPBase + from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset @@ -419,3 +421,89 @@ def _self_influence_by_batches_helper( for batch in inputs_dataset ] ) + + +def _check_loss_fn( + influence_instance: "TracInCPBase", + loss_fn: Optional[Union[Module, Callable]], + loss_fn_name: str, + sample_wise_grads_per_batch: Optional[bool] = None, +) -> str: + """ + This checks whether `loss_fn` satisfies the requirements assumed of all + implementations of `TracInCPBase`. It works regardless of whether the + implementation has the `sample_wise_grads_per_batch` attribute. + It returns the reduction type of the loss_fn. If `sample_wise_grads_per_batch` + if not provided, we assume the implementation does not have that attribute. + """ + # if `loss_fn` is `None`, there is nothing to check. then, the reduction type is + # only used by `_compute_jacobian_wrt_params_with_sample_wise_trick`, where + # reduction type should be "sum" if `loss_fn` is `None`. + if loss_fn is None: + return "sum" + + # perhaps since `Module` is an implementation of `Callable`, this has redundancy + assert isinstance(loss_fn, Module) or callable(loss_fn) + + reduction_type = "none" + + # If we are able to access the reduction used by `loss_fn`, we check whether + # the reduction is compatible with `sample_wise_grads_per_batch`, if it has the + # attribute. + if hasattr(loss_fn, "reduction"): + reduction = loss_fn.reduction # type: ignore + if sample_wise_grads_per_batch is None: + assert reduction in [ + "sum", + "mean", + ], 'reduction for `loss_fn` must be "sum" or "mean"' + reduction_type = str(reduction) + elif sample_wise_grads_per_batch: + assert reduction in ["sum", "mean"], ( + 'reduction for `loss_fn` must be "sum" or "mean" when ' + "`sample_wise_grads_per_batch` is True" + ) + reduction_type = str(reduction) + else: + assert reduction == "none", ( + 'reduction for `loss_fn` must be "none" when ' + "`sample_wise_grads_per_batch` is False" + ) + else: + # if we are unable to access the reduction used by `loss_fn`, we warn + # the user about the assumptions we are making regarding the reduction + # used by `loss_fn` + if sample_wise_grads_per_batch is None: + warnings.warn( + f'Since `{loss_fn_name}` has no "reduction" attribute, the ' + f'implementation assumes that `{loss_fn_name}` is a "reduction" loss ' + "function that reduces the per-example losses by taking their *sum*. " + f"If `{loss_fn_name}` instead reduces the per-example losses by " + f"taking their mean, please set the reduction attribute of " + f'`{loss_fn_name}` to "mean", i.e. ' + f'`{loss_fn_name}.reduction = "mean"`.' + ) + reduction_type = "sum" + elif sample_wise_grads_per_batch: + warnings.warn( + f"Since `{loss_fn_name}`` has no 'reduction' attribute, and " + "`sample_wise_grads_per_batch` is True, the implementation assumes " + f"that `{loss_fn_name}` is a 'reduction' loss function that reduces " + f"the per-example losses by taking their *sum*. If `{loss_fn_name}` " + "instead reduces the per-example losses by taking their mean, " + f'please set the reduction attribute of `{loss_fn_name}` to "mean", ' + f'i.e. `{loss_fn_name}.reduction = "mean"`. Note that if ' + "`sample_wise_grads_per_batch` is True, the implementation " + "assumes the reduction is either a sum or mean reduction." + ) + reduction_type = "sum" + else: + warnings.warn( + f'Since `{loss_fn_name}` has no "reduction" attribute, and ' + "`sample_wise_grads_per_batch` is False, the implementation " + f'assumes that `{loss_fn_name}` is a "per-example" loss function (see ' + f"documentation for `{loss_fn_name}` for details). Please ensure " + "that this is the case." + ) + + return reduction_type diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py index 21de267332..147566bc22 100644 --- a/tests/influence/_core/test_tracin_regression.py +++ b/tests/influence/_core/test_tracin_regression.py @@ -110,6 +110,24 @@ def _test_tracin_regression_setup( projection_dim=1, ), ), + ( + "check_idx", + "mean", + DataInfluenceConstructor( + TracInCPFast, + name="TracInCPFastDuplicateLossFn", + duplicate_loss_fn=True, + ), + ), # add a test where `duplicate_loss_fn` is True + ( + "check_idx", + "mean", + DataInfluenceConstructor( + TracInCPFastRandProj, + name="TracInCPFastRandProjDuplicateLossFn", + duplicate_loss_fn=True, + ), # add a test where `duplicate_loss_fn` is True + ), ]: if not (mode == "sample_wise_trick" and use_gpu): param_list.append((reduction, constructor, mode, dim, use_gpu)) @@ -404,3 +422,79 @@ def test_tracin_identity_regression( assertTensorAlmostEqual( self, train_scores, train_scores_tracin_sample_wise_trick ) + + @parameterized.expand( + [ + ("none", "none", DataInfluenceConstructor(TracInCP)), + ( + "mean", + "mean", + DataInfluenceConstructor(TracInCP, sample_wise_grads_per_batch=True), + ), + ("sum", "sum", DataInfluenceConstructor(TracInCPFast)), + ("mean", "mean", DataInfluenceConstructor(TracInCPFast)), + ("sum", "sum", DataInfluenceConstructor(TracInCPFastRandProj)), + ("mean", "mean", DataInfluenceConstructor(TracInCPFastRandProj)), + ], + name_func=build_test_name_func(), + ) + def test_tracin_constant_test_loss_fn( + self, + reduction: Optional[str], + test_reduction: Optional[str], + tracin_constructor: Callable, + ) -> None: + """ + All implementations of `TracInCPBase` can accept `test_loss_fn` in + initialization, which sets the loss function applied to test examples, which + can thus be different from the loss function applied to training examples. + This test passes `test_loss_fn` to be a constant function. Then, the influence + scores should all be 0, because gradients w.r.t. `test_loss_fn` will all be 0. + It re-uses the dataset and model from `test_tracin_identity_regression`. + + The reduction for `loss_fn` and `test_loss_fn` initialization arguments is + the same for all parameterized tests, for simplicity, and also because for + `TracInCP`, both loss functions must both be reduction loss functions (i.e. + reduction is "mean" or "sum"), or both be per-example loss functions (i.e. + reduction is "none"). Recall that for `TracInCP`, the + `sample_wise_grads_per_batch` initialization argument determines which of + those cases holds. + """ + with tempfile.TemporaryDirectory() as tmpdir: + + batch_size = 4 + + dataset, net = self._test_tracin_identity_regression_setup(tmpdir) + + train_inputs = dataset.samples + train_labels = dataset.labels + + self.assertTrue(callable(tracin_constructor)) + + self.assertTrue(isinstance(reduction, str)) + criterion = nn.MSELoss(reduction=cast(str, reduction)) + + # the output of `net`, i.e. `input` for the loss functions below, is a + # batch_size x 1 2D tensor + if test_reduction == "none": + # loss function returns 1D tensor of all 0's, so is constant + def test_loss_fn(input, target): + return input.squeeze() * 0.0 + + elif test_reduction in ["sum", "mean"]: + # loss function returns scalar tensor of all 0's, so is constant + def test_loss_fn(input, target): + return input.mean() * 0.0 + + tracin = tracin_constructor( + net, + dataset, + tmpdir, + batch_size, + criterion, + test_loss_fn=test_loss_fn, + ) + + # check influence scores of training data. they should all be 0 + train_scores = tracin.influence(train_inputs, train_labels, k=None) + assertTensorAlmostEqual(self, train_scores, torch.zeros(train_scores.shape)) diff --git a/tests/influence/_utils/common.py b/tests/influence/_utils/common.py index 999dc6404f..dbfc0de550 100644 --- a/tests/influence/_utils/common.py +++ b/tests/influence/_utils/common.py @@ -248,10 +248,19 @@ class DataInfluenceConstructor: data_influence_class: type def __init__( - self, data_influence_class: type, name: Optional[str] = None, **kwargs + self, + data_influence_class: type, + name: Optional[str] = None, + duplicate_loss_fn: bool = False, + **kwargs, ) -> None: + """ + if `duplicate_loss_fn` is True, will explicitly pass the provided `loss_fn` as + the `test_loss_fn` when constructing the TracInCPBase instance + """ self.data_influence_class = data_influence_class self.name = name if name else data_influence_class.__name__ + self.duplicate_loss_fn = duplicate_loss_fn self.kwargs = kwargs def __repr__(self) -> str: @@ -266,8 +275,14 @@ def __call__( loss_fn: Optional[Union[Module, Callable]], **kwargs, ) -> DataInfluence: - constuctor_kwargs = self.kwargs.copy() - constuctor_kwargs.update(kwargs) + constructor_kwargs = self.kwargs.copy() + constructor_kwargs.update(kwargs) + # if `self.duplicate_loss_fn`, explicitly pass in `loss_fn` as `test_loss_fn` + # when constructing the instance. Doing so should not affect the behavior of + # the returned tracincp instance, since if `test_loss_fn` is not passed in, + # the constructor sets `test_loss_fn` to be the same as `loss_fn` + if self.duplicate_loss_fn: + constructor_kwargs["test_loss_fn"] = loss_fn if self.data_influence_class is TracInCPFastRandProj: self.check_annoy() if self.data_influence_class in [TracInCPFast, TracInCPFastRandProj]: @@ -278,7 +293,7 @@ def __call__( tmpdir, loss_fn=loss_fn, batch_size=batch_size, - **constuctor_kwargs, + **constructor_kwargs, ) else: return self.data_influence_class( @@ -287,7 +302,7 @@ def __call__( tmpdir, batch_size=batch_size, loss_fn=loss_fn, - **constuctor_kwargs, + **constructor_kwargs, ) def check_annoy(self) -> None: