From b4d1c43ac0d4690e7623bdb4b5ffd96fb5596976 Mon Sep 17 00:00:00 2001
From: Fulton Wang
Date: Mon, 4 Dec 2023 06:38:32 -0800
Subject: [PATCH] add `NaiveInfluenceFunction` (#1214)

Summary:
# Overview
This diff, along with D42006733, adds 2 different implementations that both calculate the "infinitesimal" influence score as defined in the paper ["Understanding Black-box Predictions via Influence Functions"](https://arxiv.org/pdf/1703.04730.pdf).
- `NaiveInfluenceFunction`: a computationally slow but exact implementation that is useful for obtaining "ground-truth" (though note that influence scores themselves are only an approximation of the effect of removing a training example and then retraining). Several papers actually use this approach, e.g. ["Learning Augmentation Network via Influence Functions"](https://openaccess.thecvf.com/content_CVPR_2020/papers/Lee_Learning_Augmentation_Network_via_Influence_Functions_CVPR_2020_paper.pdf), ["Quantifying and Mitigating the Impact of Label Errors on Model Disparity Metrics"](https://openreview.net/forum?id=RUzSobdYy0V), ["Achieving Fairness at No Utility Cost via Data Reweighting with Influence"](https://proceedings.mlr.press/v162/li22p/li22p.pdf)
- `ArnoldiInfluenceFunction`: a computationally efficient implementation described in the paper ["Scaling Up Influence Functions"](https://arxiv.org/pdf/2112.03052.pdf) by Schioppa et al. These [slides](https://docs.google.com/presentation/d/1yJ86FkJO1IZn7YzFYpkJUJUBqaLynDJCbCWlKKglv-w/edit#slide=id.p) give a brief summary of it.

This diff is rebased on top of D41324297, which implements the new API. Again, note that the 2 above implementations are split across 2 diffs for easier review, though they are jointly described here.

# What is the "infinitesimal" influence score
More details on the "infinitesimal" influence score: this score approximately answers the question, if a given training example were infinitesimally down-weighted and the model re-trained to optimality, how much would the loss on a given test example change? Mathematically, the aforementioned influence score is given by `\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)`, where `\nabla_\theta L(x)` is the gradient of the loss, considering only training example `x`, with respect to (a subset of) model parameters `\theta`, `\nabla_\theta L(z)` is the analogous quantity for a test example `z`, and `H` is the Hessian of the (subset of) model parameters at a given model checkpoint.

# What the two implementations have in common
Both implementations compute a low-rank approximation of the inverse Hessian, i.e. a tall and skinny (with width k) matrix `R` such that `H^{-1} \approx RR'`, where k is small. In particular, let `L` be the matrix of width k whose columns contain the top-k eigenvectors of `H`, and let `V` be the k by k diagonal matrix containing the corresponding eigenvalues. Both implementations let `R = LV^{-1/2}`, so that `RR' = LV^{-1}L' \approx H^{-1}`. Thus, the core computational step is computing the top-k eigenvalues / eigenvectors.

This approximation is useful for several reasons:
- It avoids numerical issues associated with inverting small eigenvalues.
- Since the influence score is given by `\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)`, which is approximated by `(\nabla_\theta L(x)' R) (\nabla_\theta L(z)' R)'`, we can compute an "influence embedding" for a given example `x`, `\nabla_\theta L(x)' R`, such that the influence score of one example on another is approximately the dot-product of their respective embeddings. Because k is small, e.g. 50, these influence embeddings are low-dimensional.
- Even for large models, we can store `R` in memory, provided k is small. This means influence embeddings (and thus influence scores) can be efficiently computed by doing a backwards pass to compute `\nabla_\theta L(x)` and then multiplying by `R'`. This is orders of magnitude faster than the previous LISSA approach of Koh et al., which, to compute the influence score involving a given example, needs to compute Hessian-vector products involving on the order of 10^4 examples.

The implementations differ in how they compute the top-k eigenvalues / eigenvectors.
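To make the role of `R` and of influence embeddings concrete, here is a minimal, hypothetical sketch (the names `ls`, `vs`, `grads_train`, and `grads_test` are illustrative placeholders, not Captum APIs) of how the top-k eigenpairs of `H` yield `R`, and how influence scores reduce to dot-products of embeddings:

```python
import torch

# Hypothetical inputs for illustration: `ls` (shape (k,)) holds the top-k eigenvalues
# of the Hessian H (assumed positive; small / negative ones are filtered out),
# `vs` (shape (d, k)) the corresponding eigenvectors, and `grads_train` / `grads_test`
# (shapes (n_train, d), (n_test, d)) hold per-example loss gradients with respect to
# the d (flattened) parameters.
def influence_scores_via_embeddings(ls, vs, grads_train, grads_test):
    # R = L V^{-1/2}, so that R R' = L V^{-1} L' approximates H^{-1}
    R = vs * ls.pow(-0.5).unsqueeze(0)  # shape (d, k)

    # the influence embedding of an example x is \nabla_\theta L(x)' R
    train_embeddings = grads_train @ R  # shape (n_train, k)
    test_embeddings = grads_test @ R    # shape (n_test, k)

    # the influence of training example j on test example i is approximately
    # the dot-product of their embeddings
    return test_embeddings @ train_embeddings.T  # shape (n_test, n_train)
```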
# How `NaiveInfluenceFunction` computes the top-k eigenvalues / eigenvectors
It is "naive" in that it computes the top-k eigenvalues / eigenvectors by explicitly forming the Hessian, converting it to a 2D tensor, computing its eigenvectors / eigenvalues, and then sorting. See the documentation of the `_retrieve_projections_naive_influence_function` method for more details.
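As a rough, self-contained illustration of this naive route (not the actual `_compute_dataset_func` / `_top_eigen` helpers; it assumes a recent PyTorch with `torch.func.functional_call`, a single batch `X, y`, and a `loss_fn` that returns a scalar), one could form the 2D Hessian of the loss with respect to the flattened parameters and eigendecompose it. For simplicity this sketch uses `torch.linalg.eigh` on the (symmetric) Hessian, whereas the Captum helper wraps `torch.linalg.eig` and sorts its output:

```python
import torch

def naive_top_k_eigenpairs(model, loss_fn, X, y, k):
    # treat the total loss as a function of a single flattened 1D parameter vector,
    # so that torch.autograd.functional.hessian returns a 2D tensor
    params = tuple(model.parameters())
    names = [n for n, _ in model.named_parameters()]
    shapes = [p.shape for p in params]
    numels = [p.numel() for p in params]

    def loss_of_flat_params(flat_params):
        # unflatten the 1D vector back into per-parameter tensors
        pieces = flat_params.split(numels)
        param_dict = {n: piece.view(s) for n, piece, s in zip(names, pieces, shapes)}
        out = torch.func.functional_call(model, param_dict, (X,))
        return loss_fn(out, y)  # assumed to return a scalar (e.g. reduction="sum")

    flat_params = torch.cat([p.detach().reshape(-1) for p in params])
    H = torch.autograd.functional.hessian(loss_of_flat_params, flat_params)  # (d, d)

    # torch.linalg.eigh returns eigenvalues in ascending order, so the top-k
    # eigenpairs are the last k entries / columns, reversed into descending order
    ls, vs = torch.linalg.eigh(H)
    return torch.flip(ls[-k:], dims=[0]), torch.flip(vs[:, -k:], dims=[1])
```

The eigenpairs obtained this way play the role of `ls` / `vs` in the earlier sketch for forming `R` and influence embeddings.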
# How `ArnoldiInfluenceFunction` computes the top-k eigenvalues / eigenvectors
The key novelty of the approach by Schioppa et al. is that it uses the Arnoldi iteration to find the top-k eigenvalues / eigenvectors of the Hessian without explicitly forming the Hessian. In more detail, the approach first runs the Arnoldi iteration, which only requires the ability to compute Hessian-vector products, to find a Krylov subspace of moderate dimension, e.g. 200. It then finds the top-k eigenvalues / eigenvectors of the restriction of the Hessian to the subspace, where k is small, e.g. 50. Finally, it expresses the eigenvectors in the original basis. This approach for finding the top-k eigenvalues / eigenvectors is justified by a property of the Arnoldi iteration: the Krylov subspace it returns tends to contain the top eigenvectors.

This implementation does incur some one-time overhead in `__init__`, where it runs the Arnoldi iteration to calculate `R`. After that overhead, calculation of influence scores is quick, only requiring a backwards pass and a multiplication per example. Unlike `NaiveInfluenceFunction`, this implementation does not flatten any parameters, as the 2D Hessian is never formed, and Pytorch's Hessian-vector implementation (`torch.autograd.functional.hvp`) allows the input and output vector to be a tuple of tensors. Avoiding flattening / unflattening parameters brings scalability gains.

# High-level organization of the two implementations
Because of the common logic of the two implementations, they share the same high-level organization.
- Both implementations accept a `hessian_dataset` initialization argument. This is because "infinitesimal" influence scores depend on the Hessian, which in practice is computed not over the entire training data, but over a subset of it, specified by `hessian_dataset`.
- In `__init__`, `NaiveInfluenceFunction` and `ArnoldiInfluenceFunction` both compute `R` using private helper methods `_retrieve_projections_naive_influence_function` and `_retrieve_projections_arnoldi_influence_function`, respectively.
- `R` is used by their respective `compute_intermediate_quantities` methods to compute influence embeddings.
- Because influence scores (and self-influence scores) are computed by first computing influence embeddings, the `_influence` and `self_influence` methods for both implementations call the `_influence_helper_intermediate_quantities_influence_function` and `_self_influence_helper_intermediate_quantities_influence_function` helper functions, which both assume the implementation implements the `compute_intermediate_quantities` method.

# Reason for inheritance structure
`InfluenceFunctionBase` refers to any implementation that computes the "infinitesimal" influence score (as opposed to `TracInCPBase`, which computes the checkpoint-based definition of influence score). Thus, the different "base" implementations implement differently-defined influence scores, and children of a base implementation compute the same influence score in different ways. `IntermediateQuantitiesInfluenceFunction` refers to implementations of `InfluenceFunctionBase` that implement the `compute_intermediate_quantities` method. The reason we don't let `NaiveInfluenceFunction` and `ArnoldiInfluenceFunction` directly inherit from `InfluenceFunctionBase` is that their implementations of `influence` and `self_influence` are actually identical (though for logging reasons, we cannot just move those methods into `IntermediateQuantitiesInfluenceFunction`). In the future, there may be implementations of `InfluenceFunctionBase` that do *not* inherit from `IntermediateQuantitiesInfluenceFunction`, e.g. the LISSA approach of Koh et al.

# Key helper methods
- `captum._utils._stateless.functional_call` is copied from the [Pytorch 1.13 implementation](https://github.com/pytorch/pytorch/blob/17202b363780a06ae07e5cecceffaae6418ad6f8/torch/nn/utils/stateless.py) so that the user does not need to use the latest Pytorch version. It turns a Pytorch `module` into a function whose inputs are the parameters of the `module` (represented as a dictionary). This function is used to compute the Hessian in `NaiveInfluenceFunction`, and Hessian-vector products in `ArnoldiInfluenceFunction`.
- `_compute_dataset_func` is used by `NaiveInfluenceFunction` to compute the Hessian over `hessian_dataset`. This is done by calculating the Hessian over individual batches, and then summing them up. One complication is that `torch.autograd.functional.hessian`, which we use to compute Hessians, does not return the Hessian as a 2D tensor unless the function we seek the Hessian of accepts a 1D tensor. Therefore, we need to define a function of the model's parameters whose input is the parameters, *flattened* into a 1D tensor (and a batch). This function is given by the factory returned by `influence_function._flatten_forward_factory`.
- `_parameter_arnoldi` performs the Arnoldi iteration and is used by `ArnoldiInfluenceFunction`. It differs from a "traditional" implementation in that the Hessian-vector function it accepts does not map from 1D tensor to 1D tensor. Instead, it maps from a tuple of tensors to a tuple of tensors, because the "vector" in this case represents a parameter setting, which Pytorch represents as a tuple of tensors. Therefore, all the operations work with tuples of tensors, which required defining various operations for tuples of tensors in `captum.influence._utils.common`. This method returns a basis for the Krylov subspace, and the restriction of the Hessian to it.
- `_parameter_distill` takes the output of `_parameter_arnoldi`, and returns the (approximate) top-k eigenvalues / eigenvectors of the Hessian. This is what is needed to compute `R`. It is used by `ArnoldiInfluenceFunction`.
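Several of these helpers revolve around flattening a tuple of parameter tensors into a single 1D tensor and unflattening it back, the round trip that `test_flatten_unflattener` (described below) checks. A rough, hypothetical sketch of that pair of operations, not the actual `_flatten_params` / `_unflatten_params_factory` signatures:

```python
import math
from typing import Sequence, Tuple

import torch

# Illustrative flatten / unflatten helpers for a tuple of parameter tensors.
def flatten_params(params: Tuple[torch.Tensor, ...]) -> torch.Tensor:
    return torch.cat([p.reshape(-1) for p in params])

def unflatten_params_factory(shapes: Sequence[torch.Size]):
    numels = [math.prod(s) for s in shapes]

    def unflatten(flat: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        pieces = flat.split(numels)
        return tuple(piece.reshape(shape) for piece, shape in zip(pieces, shapes))

    return unflatten

# round-trip check, in the spirit of `test_flatten_unflattener`
params = (torch.randn(3, 4), torch.randn(5), torch.randn(2, 2, 2))
unflatten = unflatten_params_factory([p.shape for p in params])
recovered = unflatten(flatten_params(params))
assert all(torch.equal(p, r) for p, r in zip(params, recovered))
```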
# Tests
We create a new test file `tests.influence._core.test_arnoldi_influence.py`, which defines the class `TestArnoldiInfluence` implementing the following tests:

#### Tests used only by `NaiveInfluenceFunction`, i.e. appear in this diff:
- `test_matches_linear_regression` compares the influence scores and self-influence scores produced by a given implementation with analytically-calculated counterparts for a model where the exact influence scores are known - linear regression. Different reductions for the loss function - 'mean', 'sum', 'none' - are tested. Here, we test the following implementation:
-- `NaiveInfluenceFunction` with `projection_dim=None`, i.e. we use the inverse Hessian, not a low-rank approximation of it. In this case, the influence scores should equal the analytically calculated ones, modulo numerical issues.
- `test_flatten_unflattener`: a common operation is flattening a tuple of tensors and unflattening it (the inverse operation). This test checks that flattening and then unflattening a tuple of tensors recovers the original tensors.
- `test_top_eigen`: a common operation is finding the top eigenvectors / eigenvalues of a possibly non-symmetric matrix. Since `torch.linalg.eig` doesn't sort the eigenvalues, we make a wrapper that does the sorting. This test checks that the wrapper works properly.

#### Tests used only by `ArnoldiInfluenceFunction`, i.e. appear in next diff:
- `test_parameter_arnoldi` checks that `_parameter_arnoldi` is correct. In particular, it checks that the top-`k` eigenvalues of the restriction of `A` to a Krylov subspace (the `H` returned by `_parameter_arnoldi`) agree with those of the original matrix. This is a property we expect of the Arnoldi iteration that `_parameter_arnoldi` implements.
- `test_parameter_distill` checks that `_parameter_distill` is correct. In particular, it checks that the eigenvectors corresponding to the top eigenvalues it returns agree with the top eigenvectors of `A`. This is the property we require of `distill`, because we use the top eigenvectors (and eigenvalues) of the (implicitly-defined) `A` to calculate a low-rank approximation of its inverse.
- `test_matches_linear_regression`, where the implementation tested is the following:
-- `ArnoldiInfluenceFunction` with `arnoldi_dim` and `projection_dim` set to a large value. The Krylov subspace should contain the top eigenvectors because `arnoldi_dim` is large, and `projection_dim` is not too large relative to `arnoldi_dim`, but still large on an absolute level.
- When `projection_dim` is small, `ArnoldiInfluenceFunction` and `NaiveInfluenceFunction` should produce the same influence scores, provided `arnoldi_dim` for `ArnoldiInfluenceFunction` is large, since in this case, the top-k eigenvalues / eigenvectors for the two implementations should agree. This agreement is tested in `test_compare_implementations_trained_NN_model_and_data` and `test_compare_implementations_random_model_and_data` for a trained and an untrained 2-layer NN, respectively.

# Minor changes / functionalities / tests
- `test_tracin_intermediate_quantities_aggregate`, `test_tracin_self_influence`, `test_tracin_identity_regression` are applied to both implementations
- `_set_active_parameters` now extracts the layers to consider when computing gradients and sets their `requires_grad`. This refactoring is done since the same logic is used by `TracInCPBase` and `InfluenceFunctionBase`.
- some helpers are moved from `tracincp` to `captum.influence._utils.common` - a separate `test_loss_fn` initialization argument is supported, and both implementations are now tested in `TestTracinRegression.test_tracin_constant_test_loss_fn` - `compute_intermediate_quantities` for both implementations support the `aggregate` option. This means that both implementations can be used with D40386079, the validation influence FAIM workflow. - given the aforementioned tests, testing now generates multiple kinds of models / data. The ability to do so is added to `get_random_model_and_data`. The specific model (and its parameters) are specified by the `model_type` argument. Before, the method only supports the random 2-layer NN. Now, it also supports an optimally-trained linear regression, and a 2-layer NN trained with SGD. - `TracInCP` and implementations of `InfluenceFunctionBase` all accept a `sample_wise_grads_per_batch` option, and have the same requirements on the loss function. Thus, `_check_loss_fn_tracincp`, which previously performed those checks, is renamed `_check_loss_fn_sample_wise_grads_per_batch` and moved to `captum.influence._utils.common`. Similarly, those implementations all need to compute the jacobian, with the method depending on `sample_wise_grads_per_batch`. The jacobian computation is moved to helper function `_compute_jacobian_sample_wise_grads_per_batch`. Reviewed By: NarineK Differential Revision: D40541294 --- captum/influence/__init__.py | 2 + captum/influence/_core/influence.py | 14 +- captum/influence/_core/influence_function.py | 1314 +++++++++++++++++ captum/influence/_core/tracincp.py | 103 +- captum/influence/_utils/common.py | 531 ++++++- tests/influence/_core/test_naive_influence.py | 281 ++++ .../test_tracin_intermediate_quantities.py | 12 +- .../influence/_core/test_tracin_regression.py | 3 + .../_core/test_tracin_self_influence.py | 75 +- tests/influence/_utils/common.py | 297 +++- 10 files changed, 2464 insertions(+), 168 deletions(-) create mode 100644 captum/influence/_core/influence_function.py create mode 100644 tests/influence/_core/test_naive_influence.py diff --git a/captum/influence/__init__.py b/captum/influence/__init__.py index ac2c40a618..506851fe1b 100644 --- a/captum/influence/__init__.py +++ b/captum/influence/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from captum.influence._core.influence import DataInfluence # noqa +from captum.influence._core.influence_function import NaiveInfluenceFunction # noqa from captum.influence._core.similarity_influence import SimilarityInfluence # noqa from captum.influence._core.tracincp import TracInCP, TracInCPBase # noqa from captum.influence._core.tracincp_fast_rand_proj import ( @@ -15,4 +16,5 @@ "TracInCP", "TracInCPFast", "TracInCPFastRandProj", + "NaiveInfluenceFunction", ] diff --git a/captum/influence/_core/influence.py b/captum/influence/_core/influence.py index 553ab38abb..51b33d0a9c 100644 --- a/captum/influence/_core/influence.py +++ b/captum/influence/_core/influence.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Type from torch.nn import Module from torch.utils.data import Dataset @@ -42,3 +42,15 @@ def influence(self, inputs: Any = None, **kwargs: Any) -> Any: though this may change in the future. """ pass + + @classmethod + def get_name(cls: Type["DataInfluence"]) -> str: + r""" + Create readable class name. Due to the nature of the names of `TracInCPBase` + subclasses, simply returns the class name. 
For example, for a class called + TracInCP, we return the string TracInCP. + + Returns: + name (str): a readable class name + """ + return cls.__name__ diff --git a/captum/influence/_core/influence_function.py b/captum/influence/_core/influence_function.py new file mode 100644 index 0000000000..6e3540f1e4 --- /dev/null +++ b/captum/influence/_core/influence_function.py @@ -0,0 +1,1314 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import functools +from abc import abstractmethod +from operator import add +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from captum._utils.gradient import _extract_parameters_from_layers +from captum.influence._core.influence import DataInfluence + +from captum.influence._utils.common import ( + _check_loss_fn, + _compute_batch_loss_influence_function_base, + _compute_jacobian_sample_wise_grads_per_batch, + _dataset_fn, + _flatten_params, + _format_inputs_dataset, + _functional_call, + _get_k_most_influential_helper, + _influence_batch_intermediate_quantities_influence_function, + _influence_helper_intermediate_quantities_influence_function, + _influence_route_to_helpers, + _load_flexible_state_dict, + _params_to_names, + _progress_bar_constructor, + _self_influence_helper_intermediate_quantities_influence_function, + _set_active_parameters, + _top_eigen, + _unflatten_params_factory, + KMostInfluentialResults, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + + +class InfluenceFunctionBase(DataInfluence): + r""" + `InfluenceFunctionBase` is a base class for implementations which compute the + influence score as defined in the paper "Understanding Black-box Predictions via + Influence Functions" (https://arxiv.org/pdf/1703.04730.pdf). This "infinitesimal" + influence score approximately answers the question if a given training example + were infinitesimally down-weighted and the model re-trained to optimality, how much + would the loss on a given test example change. Mathematically, the aforementioned + influence score is given by :math`\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)`, + where :math`\nabla_\theta L(x)` is the gradient of the loss, considering only + training example :math`x` with respect to (a subset of) model parameters + :math`\theta`, :math`\nabla_\theta L(z)` is the analogous quantity for a test + example :math`z`, and :math`H` is the Hessian of the (subset of) model parameters + at a given model checkpoint. "Subset of model parameters" refers to the parameters + specified by the `layers` initialization argument; for computational purposes, + we may only consider the gradients / Hessian involving parameters in a subset of + the model's layers. This is a commonly-taken approach in the research literature. + + There can be multiple implementations of this class, because although the paper + defines a particular "infinitesimal" kind of influence score, there can be multiple + ways to compute it, each with different levels of accuracy / scalability. 
+ """ + + def __init__( + self, + model: Module, + train_dataset: Union[Dataset, DataLoader], + checkpoint: str, + checkpoints_load_func: Callable = _load_flexible_state_dict, + layers: Optional[List[str]] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + hessian_dataset: Optional[Union[Dataset, DataLoader]] = None, + test_loss_fn: Optional[Union[Module, Callable]] = None, + sample_wise_grads_per_batch: bool = False, + ) -> None: + """ + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. + checkpoint (str): The path to the checkpoint used to compute influence + scores. + checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + layers (list[str] or None, optional): A list of layer names for which + gradients should be computed. If `layers` is None, gradients will + be computed for all layers. Otherwise, they will only be computed + for the layers specified in `layers`. + Default: None + loss_fn (Callable, optional): The loss function applied to model. There + are two options for the return type of `loss_fn`. First, `loss_fn` + can be a "per-example" loss function - returns a 1D Tensor of + losses for each example in a batch. `nn.BCELoss(reduction="none")` + would be an "per-example" loss function. Second, `loss_fn` can be + a "reduction" loss function that reduces the per-example losses, + in a batch, and returns a single scalar Tensor. For this option, + the reduction must be the *sum* or the *mean* of the per-example + losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable. 
+ Note for the first option, the `sample_wise_grads_per_batch` + argument must be False, and for the second option, + `sample_wise_grads_per_batch` must be True. Also note that for + the second option, if `loss_fn` has no "reduction" attribute, + the implementation assumes that the reduction is the *sum* of the + per-example losses. If this is not the case, i.e. the reduction + is the *mean*, please set the "reduction" attribute of `loss_fn` + to "mean", i.e. `loss_fn.reduction = "mean"`. + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `train_dataset` and `hessian_dataset`, if they are + of type `Dataset`. `batch_size` should be chosen as large as + possible so that a backwards pass on a batch still fits in memory. + If `train_dataset` and `hessian_dataset`are both of type + `DataLoader`, then `batch_size` is ignored as an argument. + Default: 1 + hessian_dataset (Dataset or Dataloader, optional): The influence score and + self-influence scores this implementation calculates are defined in + terms of the Hessian, i.e. the second-derivative of the model + parameters. This argument provides the dataset used for calculating + the Hessian. It should be smaller than `train_dataset`, which + is the dataset whose examples we want the influence of. If not + provided or none, it will be assumed to be the same as + `train_dataset`. + Default: None + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. Note that the + constraints on both `loss_fn` and `test_loss_fn` both depend on + `sample_wise_grads_per_batch`. This means `loss_fn` and + `test_loss_fn` must either both be "per-example" loss functions, + or both be "reduction" loss functions. If not provided, the loss + function for test examples is assumed to be the same as the loss + function for training examples, i.e. `loss_fn`. + Default: None + sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient + computations w.r.t. model parameters aggregates the results for a + batch and does not allow to access sample-wise gradients w.r.t. + model parameters. This forces us to iterate over each sample in + the batch if we want sample-wise gradients which is computationally + inefficient. We offer an implementation of batch-wise gradient + computations w.r.t. to model parameters which is computationally + more efficient. This implementation can be enabled by setting the + `sample_wise_grad_per_batch` argument to `True`, and should be + enabled if and only if the `loss_fn` argument is a "reduction" loss + function. For example, `nn.BCELoss(reduction="sum")` would be a + valid `loss_fn` if this implementation is enabled (see + documentation for `loss_fn` for more details). Note that our + current implementation enables batch-wise gradient computations + only for a limited number of PyTorch nn.Modules: Conv2D and Linear. + This list will be expanded in the near future. 
Therefore, please + do not enable this implementation if gradients will be computed + for other kinds of layers. + Default: False + """ + + self.model = model + + self.checkpoint = checkpoint + + self.checkpoints_load_func = checkpoints_load_func + # actually load the checkpoint + checkpoints_load_func(model, checkpoint) + self.loss_fn = loss_fn + # If test_loss_fn not provided, it's assumed to be same as loss_fn + self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn + self.sample_wise_grads_per_batch = sample_wise_grads_per_batch + self.batch_size = batch_size + + if not isinstance(train_dataset, DataLoader): + assert isinstance(batch_size, int), ( + "since the `train_dataset` argument was a `Dataset`, " + "`batch_size` must be an int." + ) + self.train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False) + else: + self.train_dataloader = train_dataset + + if hessian_dataset is None: + self.hessian_dataloader = self.train_dataloader + elif not isinstance(hessian_dataset, DataLoader): + assert isinstance(batch_size, int), ( + "since the `shared_dataset` argument was a `Dataset`, " + "`batch_size` must be an int." + ) + self.hessian_dataloader = DataLoader( + hessian_dataset, batch_size, shuffle=False + ) + else: + self.hessian_dataloader = hessian_dataset + + # we check the loss functions in `InfluenceFunctionBase` rather than + # individually in its child classes because we assume all its implementations + # have the same requirements on loss functions, i.e. the type of reductions + # supported. furthermore, these checks are done using a helper function that + # handles all implementations with a `sample_wise_grads_per_batch` + # initialization argument. + + # we save the reduction type for both `loss_fn` and `test_loss_fn` because + # 1) if `sample_wise_grads_per_batch` is true, the reduction type is needed + # to compute per-example gradients, and 2) regardless, reduction type for + # `loss_fn` is needed to compute the Hessian. + + # check `loss_fn` + self.reduction_type = _check_loss_fn( + self, loss_fn, "loss_fn", sample_wise_grads_per_batch + ) + # check `test_loss_fn` if it was provided + if not (test_loss_fn is None): + self.test_reduction_type = _check_loss_fn( + self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch + ) + else: + self.test_reduction_type = self.reduction_type + + self.layer_modules = None + if not (layers is None): + self.layer_modules = _set_active_parameters(model, layers) + + @abstractmethod + def self_influence( + self, + inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, + show_progress: bool = False, + ) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Args: + inputs_dataset (tuple or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. 
In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress` is true, the progress of this computation will be + displayed. In more detail, this computation will iterate over all + checkpoints (provided as the `checkpoints` initialization argument) + in an outer loop, and iterate over all batches that + `inputs_dataset` represents in an inner loop. Therefore, the + total number of (checkpoint, batch) combinations that need to be + iterated over is + (# of checkpoints x # of batches that `inputs_dataset` represents). + If `show_progress` is True, the total progress of both the outer + iteration over checkpoints and the inner iteration over batches is + displayed. It will try to use tqdm if available for advanced + features (e.g. time estimation). Otherwise, it will fallback to a + simple output of progress. + Default: False + + Returns: + self_influence_scores (Tensor): This is a 1D tensor containing the self + influence scores of all examples in `inputs_dataset`, regardless of + whether it represents a single batch or a `DataLoader` that yields + batches. + """ + pass + + @abstractmethod + def _get_k_most_influential( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `train_dataset`, If `show_progress` is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `train_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. 
Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `train_dataset` + on example `i` in the test dataset represented by `inputs`. + """ + pass + + @abstractmethod + def _influence( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: + r""" + Args: + + inputs (tuple[Any, ...]): A batch of examples. Does not represent labels, + which are passed as `targets`. The assumption is that + `model(*inputs)` produces the predictions for the batch. + targets (Tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch + `inputs`. + Default: None + + Returns: + influence_scores (Tensor): Influence scores over the entire + training dataset `train_dataset`. Dimensionality is + (inputs_batch_size, src_dataset_size). For example: + influence_scores[i][j] = the influence score for the j-th training + example to the i-th input example. + """ + pass + + @abstractmethod + def influence( # type: ignore[override] + self, + inputs: Tuple, + k: Optional[int] = None, + proponents: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + r""" + This is the key method of this class, and can be run in 2 different modes, + where the mode that is run depends on the arguments passed to this method: + + - influence score mode: This mode is used if `k` is None. This mode computes + the influence score of every example in training dataset `train_dataset` + on every example in the test dataset represented by `inputs`. + - k-most influential mode: This mode is used if `k` is not None, and an int. + This mode computes the proponents or opponents of every example in the + test dataset represented by `inputs`. In particular, for each test example in + the test dataset, this mode computes its proponents (resp. opponents), + which are the indices in the training dataset `train_dataset` of the + training examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. Otherwise, + opponents are computed. For each test example, this method also returns the + actual influence score of each proponent (resp. opponent) on the test + example. + + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test dataset. + Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `train_dataset`, which may + take a long time. 
If `show_progress` is true, the progress of + "training dataset computations" will be displayed. In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. time estimation). Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - influence score mode: if this mode is run (`k` is None), returns a 2D + tensor `influence_scores` of shape `(input_size, train_dataset_size)`, + where `input_size` is the number of examples in the test dataset, and + `train_dataset_size` is the number of examples in training dataset + `train_dataset`. In other words, `influence_scores[i][j]` is the + influence score of the `j`-th example in `train_dataset` on the `i`-th + example in the test dataset. + - k-most influential mode: if this mode is run (`k` is an int), returns + a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of + shape `(input_size, k)`, where `input_size` is the number of examples in + the test dataset. If computing proponents (resp. opponents), + `indices[i][j]` is the index in training dataset `train_dataset` of the + example with the `j`-th highest (resp. lowest) influence score (out of + the examples in `train_dataset`) on the `i`-th example in the test + dataset. `influence_scores` contains the corresponding influence scores. + In particular, `influence_scores[i][j]` is the influence score of example + `indices[i][j]` in `train_dataset` on example `i` in the test dataset + represented by `inputs`. + """ + pass + + +class IntermediateQuantitiesInfluenceFunction(InfluenceFunctionBase): + """ + Implementations of this class all implement the `compute_intermediate_quantities` + method, which computes the "embedding" vectors for all examples in a test dataset. + These embedding vectors are assumed to have the following properties: + - the influence score of one example on another example, as calculated by the + implementation, is the dot-product of their respective embeddings. + - the self influence score of an example is the squared norm of its embedding. + """ + + @abstractmethod + def compute_intermediate_quantities( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + aggregate: bool = False, + show_progress: bool = False, + return_on_cpu: bool = True, + test: bool = False, + ): + pass + + +def _flatten_forward_factory( + model: nn.Module, + loss_fn: Optional[Union[Module, Callable]], + reduction_type: str, + unflatten_fn: Callable, + param_names: List[str], +): + """ + Given a model, loss function, reduction type of the loss, function that unflattens + 1D tensor input into a tuple of tensors, the name of each tensor in that tuple, + each of which represents a parameter of `model`, and returns a factory. The factory + accepts a batch, and returns a function whose input is the parameters represented + by `param_names`, and output is the total loss of the model with those parameters, + calculated on the batch. The parameter input to the returned function is assumed to + be *flattened* via the inverse of `unflatten_fn`, which takes a tuple of tensors to + a 1D tensor. This returned function, accepting a single flattened 1D parameter, is + useful for computing the parameter gradient involving the batch as a 1D tensor, and + the Hessian involving the batch as a 2D tensor. 
Both quantities are needed to + calculate the kind of influence scores returned by implementations of + `InfluenceFunctionBase`. + """ + # this is the factory that accepts a batch + def flatten_forward_factory_given_batch(batch): + + # this is the function that factory returns, which is a function of flattened + # parameters + def flattened_forward(flattened_params): + # as everywhere else, the all but the last elements of a batch are + # assumed to correspond to the features, i.e. input to forward function + features, labels = tuple(batch[0:-1]), batch[-1] + + _output = _functional_call( + model, dict(zip(param_names, unflatten_fn(flattened_params))), features + ) + + # compute the total loss for the batch, adjusting the output of + # `loss_fn` based on `reduction_type` + return _compute_batch_loss_influence_function_base( + loss_fn, _output, labels, reduction_type + ) + + return flattened_forward + + return flatten_forward_factory_given_batch + + +def _compute_dataset_func( + inputs_dataset: Union[Tuple[Tensor, ...], DataLoader], + model: Module, + loss_fn: Optional[Union[Module, Callable]], + reduction_type: str, + layer_modules: Optional[List[Module]], + f: Callable, + show_progress: bool, + **f_kwargs, +): + """ + This function is used to compute higher-order functions of a given model's loss + over a given dataset, using the model's current parameters. For example, that + higher-order function `f` could be the Hessian, or a Hessian-vector product. + This function uses the factory returned by `_flatten_forward_factory`, which given + a batch, returns the loss for the batch as a function of flattened parameters. + In particular, for each batch in `inputs_dataset`, this function uses the factory + to obtain `flattened_forward`, which returns the loss for `model`, using the batch. + `flattened_forward`, as well as the flattened parameters for `model`, are used by + argument `f`, a higher-order function, to compute a batch-specific quantity. + For example, `f` could compute the Hessian via `torch.autograd.functional.hessian`, + or compute a Hessian-vector product via `torch.autograd.functional.hvp`. Additional + arguments besides `flattened_forward` and the flattened parameters, i.e. the vector + in Hessian-vector products, can be passed via named arguments. + """ + # extract the parameters in a tuple + params = tuple( + model.parameters() + if layer_modules is None + else _extract_parameters_from_layers(layer_modules) + ) + + # construct functions that can flatten / unflatten tensors, and get + # names of each param in `params`. 
+ # Both are needed for calling `_flatten_forward_factory` + _unflatten_params = _unflatten_params_factory( + tuple([param.shape for param in params]) + ) + param_names = _params_to_names(params, model) + + # prepare factory + factory_given_batch = _flatten_forward_factory( + model, + loss_fn, + reduction_type, + _unflatten_params, + param_names, + ) + + # the function returned by the factor is evaluated at a *flattened* version of + # params, so need to create that + flattened_params = _flatten_params(params) + + # define function of a single batch + def batch_f(batch): + flattened_forward = factory_given_batch(batch) # accepts flattened params + return f(flattened_forward, flattened_params, **f_kwargs) + + # sum up results of `batch_f` + if show_progress: + inputs_dataset = tqdm(inputs_dataset, desc="processing `hessian_dataset` batch") + + return _dataset_fn(inputs_dataset, batch_f, add) + + +def _get_dataset_embeddings_intermediate_quantities_influence_function( + batch_embeddings_fn: Callable, + inputs_dataset: DataLoader, + aggregate: bool, +): + """ + given `batch_embeddings_fn`, which produces the embeddings for a given batch, + returns either the embeddings for an entire dataset (if `aggregate` is false), + or the sum of the embeddings for an entire dataset (if `aggregate` is true). + """ + # if aggregate is false, we concatenate the embeddings for all batches + if not aggregate: + return torch.cat( + [batch_embeddings_fn(batch) for batch in inputs_dataset], dim=0 + ) + else: + # if aggregate is True, we return the sum of all embeddings for all + # batches. we do this by summing over each batch, and then summing over all + # batches. + inputs_dataset_iter = iter(inputs_dataset) + + batch = next(inputs_dataset_iter) + total_embedding = torch.sum(batch_embeddings_fn(batch), dim=0) + + for batch in inputs_dataset_iter: + total_embedding += torch.sum(batch_embeddings_fn(batch), dim=0) + + # we unsqueeze because regardless of aggregate, the returned tensor should + # be 2D. + return total_embedding.unsqueeze(0) + + +class NaiveInfluenceFunction(IntermediateQuantitiesInfluenceFunction): + r""" + This is a computationally-inefficient implementation that computes the type of + "infinitesimal" influence scores defined in the paper "Understanding Black-box + Predictions via Influence Functions" by Koh et al + (https://arxiv.org/pdf/1703.04730.pdf). The computational bottleneck in computing + infinitesimal influence scores is computing inverse Hessian-vector products, as can + be seen from its definition in `InfluenceFunctionBase`. This implementation is + inefficient / naive in that it explicitly forms the Hessian :math`H`, unlike other + implementations which compute inverse Hessian-vector products without explicitly + forming the Hessian. The purpose of this implementation is to have a way to + generate the "ground-truth" influence scores, to which other implementations, + which are more efficient but return only approximations of the influence score, can + be compared. + + This implementation computes a low-rank approximation of the inverse Hessian, i.e. + a tall and skinny (with width k) matrix :math`R` such that + :math`H^{-1} \approx RR'`, where k is small. In particular, let :math`L` be the + matrix of width k whose columns contain the top-k eigenvectors of :math`H`, and let + :math`V` be the k by k matrix whose diagonals contain the corresponding eigenvalues. + This implementation lets :math`R=LV^{-1}L'`. 
Thus, the core computational step is + computing the top-k eigenvalues / eigenvectors. + + This low-rank approximation is useful for several reasons: + - It avoids numerical issues associated with inverting small eigenvalues. + - Since the influence score is given by + :math`\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)`, which is approximated by + :math`(\nabla_\theta L(x)' R) (\nabla_\theta L(z)' R)`, we can compute an + "influence embedding" for a given example :math`x`, :math`\nabla_\theta L(x)' R`, + such that the influence score of one example on another is approximately the + dot-product of their respective embeddings. + + This implementation is "naive" in that it computes the top-k eigenvalues / + eigenvectors by explicitly forming the Hessian, converting it to a 2D tensor, + computing its eigenvectors / eigenvalues, and then sorting. See documentation of the + `_retrieve_projections_naive_influence_function` method for more details. + """ + + def __init__( + self, + model: Module, + train_dataset: Union[Dataset, DataLoader], + checkpoint: str, + checkpoints_load_func: Callable = _load_flexible_state_dict, + layers: Optional[List[str]] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + hessian_dataset: Optional[Union[Dataset, DataLoader]] = None, + test_loss_fn: Optional[Union[Module, Callable]] = None, + sample_wise_grads_per_batch: bool = False, + projection_dim: int = 50, + seed: int = 42, + hessian_reg: float = 1e-6, + hessian_inverse_tol: float = 1e-5, + projection_on_cpu: bool = True, + show_progress: bool = False, + ) -> None: + """ + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. + checkpoint (str): The path to the checkpoint used to compute influence + scores. 
+ checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + layers (list[str] or None, optional): A list of layer names for which + gradients should be computed. If `layers` is None, gradients will + be computed for all layers. Otherwise, they will only be computed + for the layers specified in `layers`. + Default: None + loss_fn (Callable, optional): The loss function applied to model. For now, + we require it to be a "reduction='none'" loss function. For + example, `BCELoss(reduction='none')` would be acceptable, but + `BCELoss(reduction='sum')` would not. + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `train_dataset` and `hessian_dataset`, if they are + of type `Dataset`. `batch_size` should be chosen as large as + possible so that a backwards pass on a batch still fits in memory. + If `train_dataset` and `hessian_dataset`are both of type + `DataLoader`, then `batch_size` is ignored as an argument. + Default: 1 + hessian_dataset (Dataset or Dataloader, optional): The influence score and + self-influence scores this implementation calculates are defined in + terms of the Hessian, i.e. the second-derivative of the model + parameters. This argument provides the dataset used for calculating + the Hessian. It should be smaller than `train_dataset`, which + is the dataset whose examples we want the influence of. If not + provided or none, it will be assumed to be the same as + `train_dataset`. + Default: None + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. Note that the + constraints on both `loss_fn` and `test_loss_fn` both depend on + `sample_wise_grads_per_batch`. This means `loss_fn` and + `test_loss_fn` must either both be "per-example" loss functions, + or both be "reduction" loss functions. If not provided, the loss + function for test examples is assumed to be the same as the loss + function for training examples, i.e. `loss_fn`. + Default: None + sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient + computations w.r.t. model parameters aggregates the results for a + batch and does not allow to access sample-wise gradients w.r.t. + model parameters. This forces us to iterate over each sample in + the batch if we want sample-wise gradients which is computationally + inefficient. We offer an implementation of batch-wise gradient + computations w.r.t. to model parameters which is computationally + more efficient. This implementation can be enabled by setting the + `sample_wise_grad_per_batch` argument to `True`, and should be + enabled if and only if the `loss_fn` argument is a "reduction" loss + function. 
For example, `nn.BCELoss(reduction="sum")` would be a + valid `loss_fn` if this implementation is enabled (see + documentation for `loss_fn` for more details). Note that our + current implementation enables batch-wise gradient computations + only for a limited number of PyTorch nn.Modules: Conv2D and Linear. + This list will be expanded in the near future. Therefore, please + do not enable this implementation if gradients will be computed + for other kinds of layers. + Default: False + projection_dim (int, optional): This implementation produces a low-rank + approximation of the (inverse) Hessian. This is the rank of that + approximation, and also corresponds to the dimension of the + "influence embeddings" produced by the + `compute_intermediate_quantities` method. + Default: 50 + seed (int, optional): This implementation has a source of randomness - the + initialization basis to the Arnoldi iteration. This seed is used + to make that randomness reproducible. + Default: 42 + hessian_reg (float, optional): We add an entry to the hessian's diagonal + entries before computing its eigenvalues / eigenvectors. + This is that entry. + Default: 1e-6 + hessian_inverse_tol: (float) The tolerance to use when computing the + pseudo-inverse of the (square root of) hessian. + Default: 1e-6 + projection_on_cpu (bool, optional): Whether to move the projection, + i.e. low-rank approximation of the inverse Hessian, to cpu, to save + gpu memory. + Default: True + show_progress (bool, optional): This implementation explicitly computes the + Hessian over batches in `hessian_dataloader` (and sums them) which + can take a long time. If `show_progress` is true, the number of + batches for which the Hessian has been computed will be displayed. + It will try to use tqdm if available for advanced features (e.g. + time estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + InfluenceFunctionBase.__init__( + self, + model, + train_dataset, + checkpoint, + checkpoints_load_func, + layers, + loss_fn, + batch_size, + hessian_dataset, + test_loss_fn, + sample_wise_grads_per_batch, + ) + + self.projection_dim = projection_dim + torch.manual_seed(seed) # for reproducibility + + self.hessian_reg = hessian_reg + self.hessian_inverse_tol = hessian_inverse_tol + + # infer the device the model is on. all parameters are assumed to be on the + # same device + self.model_device = next(model.parameters()).device + + self.R = self._retrieve_projections_naive_influence_function( + self.hessian_dataloader, + projection_on_cpu, + show_progress, + ) + + def _retrieve_projections_naive_influence_function( + self, + dataloader: DataLoader, + projection_on_cpu: bool, + show_progress: bool, + ) -> Tensor: + r""" + Returns the matrix `R` described in the documentation for + `NaiveInfluenceFunction`. In short, `R` has the property that + :math`H^{-1} \approx RR'`, where `H` is the Hessian. Since this is a "naive" + implementation, it does so by explicitly forming the Hessian, converting + it to a 2D tensor, and computing its eigenvectors / eigenvalues, before + filtering out some eigenvalues and then inverting them. The returned matrix + `R` represents a set of parameters in parameter space. Since the Hessian + is obtained by first flattening the parameters, each column of `R` corresponds + to a *flattened* parameter in parameter space. + + Args: + dataloader (DataLoader): The returned matrix `R` gives a low-rank + approximation of the Hessian `H`. 
This dataloader defines the + dataset used to compute the Hessian that is being approximated. + projection_on_cpu (bool, optional): Whether to move the projection, + i.e. low-rank approximation of the inverse Hessian, to cpu, to save + gpu memory. + show_progress (bool): Computing the Hessian that is being approximated + requires summing up the Hessians computed using different batches, + which may take a long time. If `show_progress` is true, the number + of batches that have been processed will be displayed. It will try + to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + + Returns: + R (Tensor): Tall and skinny tensor with width `projection_dim` + (initialization argument). Each column corresponds to a flattened + parameter in parameter-space. `R` has the property that + :math`H^{-1} \approx RR'`. + """ + # compute the hessian using the dataloader. hessian is always computed using + # the training loss function. H is 2D, with each column / row corresponding to + # a different parameter. we cannot directly use + # `torch.autograd.functional.hessian`, because it does not return a 2D tensor. + # instead, to compute H, we first create a function that accepts *flattened* + # model parameters (i.e. a 1D tensor), and outputs the loss of `self.model`, + # using those parameters, aggregated over `dataloader`. this function is then + # passed to `torch.autograd.functional.hessian`. because its input is 1D, the + # resulting hessian is 2D, as desired. all this functionality is handled by + # `_compute_dataset_func`. + H = _compute_dataset_func( + dataloader, + self.model, + self.loss_fn, + self.reduction_type, + self.layer_modules, + torch.autograd.functional.hessian, + show_progress, + ) + + # H is approximately `vs @ torch.diag(ls) @ vs.T``, using eigendecomposition + ls, vs = _top_eigen( + H, self.projection_dim, self.hessian_reg, self.hessian_inverse_tol + ) + + # if no positive eigenvalues exist, we cannot compute a low-rank + # approximation of the square root of the hessian H, so raise exception + if len(ls) == 0: + raise Exception( + "Hessian has no positive " + "eigenvalues, so cannot take its square root." + ) + + # `R` is `vs @ torch.diag(ls ** -0.5)`, since H^{-1} is approximately + # `vs @ torch.diag(ls ** -1) @ vs.T` + # see https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix#Matrix_inverse_via_eigendecomposition # noqa: E501 + # for details, which mentions that discarding small eigenvalues (as done in + # `_top_eigen`) reduces noisiness of the inverse. + ls = (1.0 / ls) ** 0.5 + return (ls.unsqueeze(0) * vs).to( + device=torch.device("cpu") if projection_on_cpu else self.model_device + ) + + def compute_intermediate_quantities( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + aggregate: bool = False, + show_progress: bool = False, + return_on_cpu: bool = True, + test: bool = False, + ) -> Tensor: + r""" + Computes "embedding" vectors for all examples in a single batch, or a + `Dataloader` that yields batches. These embedding vectors are constructed so + that the influence score of a training example on a test example is simply the + dot-product of their corresponding vectors. In both cases, a batch should be + small enough so that a backwards pass for a batch does not lead to + out-of-memory errors. + + In more detail, the embedding vector for an example `x` is + :math`\nabla_\theta L(x)' R`, where :math`R` is as defined in this class' + description. 
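
The identity behind `_retrieve_projections_naive_influence_function` above, namely that the `R` built from the top-k eigenpairs satisfies `H^{-1} \approx RR'`, can be checked in isolation. The sketch below uses `torch.linalg.eigh` directly on a random positive definite stand-in for the Hessian, rather than the private `_top_eigen` helper:

```python
import torch

torch.manual_seed(0)

# a random symmetric positive definite stand-in for the Hessian
A = torch.randn(6, 6)
H = A @ A.T + torch.eye(6)

# eigendecomposition: H = vs @ diag(ls) @ vs.T, with eigenvalues in ascending order
ls, vs = torch.linalg.eigh(H)

# keep the top-k eigenpairs and form R = vs_k @ diag(ls_k ** -0.5)
k = 6  # with k equal to the full dimension, R R' recovers H^{-1} exactly
ls_k, vs_k = ls[-k:], vs[:, -k:]
R = vs_k * (1.0 / ls_k) ** 0.5

# H^{-1} is approximated (here, recovered) by R @ R.T
assert torch.allclose(R @ R.T, torch.linalg.inv(H), atol=1e-4)
```
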
The embeddings for a batch of examples are computed by assembling + :math`\nabla_\theta L(x)` for all examples `x` in the batch as rows in a 2D + tensor, and right-multiplying by `R`. + + If `aggregate` is True, the *sum* of the vectors for all examples is returned, + instead of the vectors for each example. This can be useful for computing the + influence of a given training example on the total loss over a validation + dataset, because due to properties of the dot-product, this influence is the + dot-product of the training example's vector with the sum of the vectors in the + validation dataset. Also, by doing the sum aggregation within this method as + opposed to outside of it (by computing all vectors for the validation dataset, + then taking the sum) allows memory usage to be reduced. + + Args: + inputs_dataset (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, and + and `batch[-1]` are the labels, if any. Here, `model` is model + provided in initialization. This is the same assumption made for + each batch yielded by training dataset `train_dataset`. + aggregate (bool): Whether to return the sum of the vectors for all + examples, as opposed to vectors for each example. + show_progress (bool, optional): Computation of vectors can take a long + time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which + vectors have been computed will be displayed. It will try to + use tqdm if available for advanced features (e.g. time estimation). + Otherwise, it will fallback to a simple output of progress. + Default: False + return_on_cpu (bool, optional): Whether to return the vectors on the cpu. + If None or False, is set to the device that the model is on. + Default: None + test (bool, optional): Whether to compute the vectors using the loss + function `test_loss_fn` provided in initialization (instead of + `loss_fn`). This argument does not matter if `test_loss_fn` was + not provided, as in this case, `test_loss_fn` and `loss_fn` are the + same. + + Returns: + intermediate_quantities (Tensor): This is a 2D tensor with shape + `(N, projection_dim)`, where `N` is the total number of examples in + `inputs_dataset`, and `projection_dim` was provided in + initialization. Each row contains the vector for a different + example. + """ + # if `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) + + if show_progress: + inputs_dataset = _progress_bar_constructor( + self, inputs_dataset, "inputs_dataset", "intermediate quantities" + ) + # infer model / data device through model + if return_on_cpu is None or (not return_on_cpu): + return_device = self.model_device + else: + return_device = torch.device("cpu") + + # as described in the description for `NaiveInfluenceFunction`, the embedding + # for an example `x` is :math`\nabla_\theta L(x)' R`. + # `_basic_computation_naive_influence_function` returns a 2D tensor where + # each row is :math`\nabla_\theta L(x)'` for a different example `x` in a + # batch. 
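
The relationship between influence embeddings and influence scores described above can be checked end to end. The sketch below reuses the placeholder `naive_influence` and `train_dataset` from the earlier example; `test_batch` is an arbitrary batch whose last element holds the labels:

```python
import torch
from torch.utils.data import DataLoader

test_batch = (torch.randn(8, 10), torch.randn(8, 1))
train_dataloader = DataLoader(train_dataset, batch_size=5)

test_embeddings = naive_influence.compute_intermediate_quantities(test_batch, test=True)
train_embeddings = torch.cat(
    [naive_influence.compute_intermediate_quantities(batch) for batch in train_dataloader],
    dim=0,
)

# the influence of training example j on test example i is the dot product of
# their embeddings, which is what `influence` computes when `k` is None
scores_from_embeddings = test_embeddings @ train_embeddings.T
scores_from_influence = naive_influence.influence(test_batch)
assert torch.allclose(scores_from_embeddings, scores_from_influence, atol=1e-4)
```
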
therefore, we right-multiply its output with `R` to get the embeddings + # for a batch, and then concatenate the per-batch embeddings to get embeddings + # for the entire dataset. + + # choose the correct loss function and reduction type based on `test` + loss_fn = self.test_loss_fn if test else self.loss_fn + reduction_type = self.test_reduction_type if test else self.reduction_type + + # define a helper function that returns the embeddings for a batch + def get_batch_embeddings(batch): + # if `self.R` is on cpu, and `self.model_device` was not cpu, this implies + # `self.R` was too large to fit in gpu memory, and we should do the matrix + # multiplication of the batch jacobians with `self.R` separately for each + # column of `self.R`, to avoid moving the entire `self.R` to gpu all at + # once and running out of gpu memory + batch_jacobians = _basic_computation_naive_influence_function( + self, batch[0:-1], batch[-1], loss_fn, reduction_type + ) + if self.R.device == torch.device( + "cpu" + ) and self.model_device != torch.device("cpu"): + return torch.stack( + [ + torch.matmul(batch_jacobians, R_col.to(batch_jacobians.device)) + for R_col in self.R.T + ], + dim=1, + ).to(return_device) + else: + return torch.matmul(batch_jacobians, self.R).to(device=return_device) + + # using `get_batch_embeddings` and a helper, return all the vectors or their + # sum, depending on `aggregate` + return _get_dataset_embeddings_intermediate_quantities_influence_function( + get_batch_embeddings, + inputs_dataset, + aggregate, + ) + + @log_usage(skip_self_logging=True) + def influence( # type: ignore[override] + self, + inputs: Tuple, + k: Optional[int] = None, + proponents: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + """ + This is the key method of this class, and can be run in 2 different modes, + where the mode that is run depends on the arguments passed to this method: + + - influence score mode: This mode is used if `k` is None. This mode computes + the influence score of every example in training dataset `train_dataset` + on every example in the test batch represented by `inputs`. + - k-most influential mode: This mode is used if `k` is not None, and an int. + This mode computes the proponents or opponents of every example in the + test batch represented by `inputs`. In particular, for each test example in + the test batch, this mode computes its proponents (resp. opponents), + which are the indices in the training dataset `train_dataset` of the + training examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. Otherwise, + opponents are computed. For each test example, this method also returns the + actual influence score of each proponent (resp. opponent) on the test + example. + + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test batch. 
+ Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `train_dataset`, which may + take a long time. If `show_progress` is true, the progress of + "training dataset computations" will be displayed. In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. time estimation). Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - influence score mode: if this mode is run (`k` is None), returns a 2D + tensor `influence_scores` of shape `(input_size, train_dataset_size)`, + where `input_size` is the number of examples in the test dataset, and + `train_dataset_size` is the number of examples in training dataset + `train_dataset`. In other words, `influence_scores[i][j]` is the + influence score of the `j`-th example in `train_dataset` on the `i`-th + example in the test batch. + - k-most influential mode: if this mode is run (`k` is an int), returns + a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of + shape `(input_size, k)`, where `input_size` is the number of examples in + the test batch. If computing proponents (resp. opponents), + `indices[i][j]` is the index in training dataset `train_dataset` of the + example with the `j`-th highest (resp. lowest) influence score (out of + the examples in `train_dataset`) on the `i`-th example in the test + batch. `influence_scores` contains the corresponding influence scores. + In particular, `influence_scores[i][j]` is the influence score of example + `indices[i][j]` in `train_dataset` on example `i` in the test batch + represented by `inputs`. + """ + + return _influence_route_to_helpers( + self, + inputs, + k, + proponents, + show_progress=show_progress, + ) + + def _get_k_most_influential( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `train_dataset`, If `show_progress` is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). 
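
Continuing the placeholder sketch above, the k-most influential mode documented for `influence` looks as follows; `naive_influence` and `test_batch` are the hypothetical objects from the earlier examples:

```python
import torch

# for each of the 8 test examples, find the 3 training examples with the
# highest influence scores (proponents)
indices, influence_scores = naive_influence.influence(test_batch, k=3, proponents=True)
assert indices.shape == (8, 3) and influence_scores.shape == (8, 3)

# consistency with influence score mode: the top proponent score per test example
# should equal the row-wise maximum of the full influence score matrix
all_scores = naive_influence.influence(test_batch)  # k=None
assert torch.allclose(influence_scores[:, 0], all_scores.max(dim=1).values, atol=1e-4)
```
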
Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `train_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `train_dataset` + on example `i` in the test dataset represented by `inputs`. + """ + desc = ( + None + if not show_progress + else ( + ( + f"Using {self.get_name()} to perform computation for " + f'getting {"proponents" if proponents else "opponents"}. ' + "Processing training batches" + ) + ) + ) + return KMostInfluentialResults( + *_get_k_most_influential_helper( + self.train_dataloader, + functools.partial( + _influence_batch_intermediate_quantities_influence_function, self + ), + inputs, + k, + proponents, + show_progress, + desc, + ) + ) + + def _influence( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): To compute the influence of examples in + training dataset `train_dataset`, we compute the influence + of each batch. If `show_progress` is true, the progress of this + computation will be displayed. In particular, the number of batches + for which influence has been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + + Returns: + influence_scores (Tensor): Influence scores over the entire + training dataset `train_dataset`. Dimensionality is + (inputs_batch_size, src_dataset_size). For example: + influence_scores[i][j] = the influence score for the j-th training + example to the i-th example in the test dataset. + """ + # turn inputs and targets into a dataset. inputs has already been processed + # so that it should always be unpacked + inputs_dataset = _format_inputs_dataset(inputs) + return _influence_helper_intermediate_quantities_influence_function( + self, inputs_dataset, show_progress + ) + + @log_usage(skip_self_logging=True) + def self_influence( + self, + inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, + show_progress: bool = False, + ) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. 
Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Implementation-wise, the self-influence score for an example is simply the + squared norm of the example's "embedding" vector. Therefore, the implementation + leverages `compute_intermediate_quantities`. + + Args: + inputs_dataset (tuple or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. + Default: None + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which + self influence scores have been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + return _self_influence_helper_intermediate_quantities_influence_function( + self, inputs_dataset, show_progress + ) + + +def _basic_computation_naive_influence_function( + influence_inst: InfluenceFunctionBase, + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, +) -> Tensor: + """ + This computes the per-example parameter gradients for a batch, flattened into a + 2D tensor where the first dimension is batch dimension. This is used by + `NaiveInfluenceFunction` which computes embedding vectors for each example by + projecting their parameter gradients. + """ + # `jacobians` contains one tensor for each parameter we compute jacobians for. 
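
As the `self_influence` docstring above states, the self-influence score of an example is the squared norm of its embedding vector. A quick check, again using the placeholder objects from the earlier sketches:

```python
import torch
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=5)
self_scores = naive_influence.self_influence(train_dataloader)

embeddings = torch.cat(
    [naive_influence.compute_intermediate_quantities(batch) for batch in train_dataloader],
    dim=0,
)
# the squared norm of each embedding row is that example's self-influence score
assert torch.allclose(self_scores, (embeddings ** 2).sum(dim=1), atol=1e-4)
```
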
+ # the first dimension of each tensor is the batch dimension, and the remaining + # dimensions correspond to the parameter, so that for the tensor corresponding + # to parameter `p`, its shape is `(batch_size, *p.shape)` + jacobians = _compute_jacobian_sample_wise_grads_per_batch( + influence_inst, inputs, targets, loss_fn, reduction_type + ) + + return torch.stack( + [ + _flatten_params(tuple(jacobian[i] for jacobian in jacobians)) + for i in range(len(next(iter(jacobians)))) + ], + dim=0, + ) diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index c66e8ca7b4..5ef223d43c 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -4,34 +4,24 @@ import warnings from abc import abstractmethod from os.path import join -from typing import ( - Any, - Callable, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union import torch from captum._utils.av import AV -from captum._utils.common import _get_module_from_name, _parse_version -from captum._utils.gradient import ( - _compute_jacobian_wrt_params, - _compute_jacobian_wrt_params_with_sample_wise_trick, -) +from captum._utils.common import _parse_version from captum._utils.progress import NullProgress, progress from captum.influence._core.influence import DataInfluence from captum.influence._utils.common import ( _check_loss_fn, + _compute_jacobian_sample_wise_grads_per_batch, _format_inputs_dataset, _get_k_most_influential_helper, _gradient_dot_product, + _influence_route_to_helpers, _load_flexible_state_dict, _self_influence_by_batches_helper, + _set_active_parameters, + KMostInfluentialResults, ) from captum.log import log_usage from torch import Tensor @@ -69,24 +59,6 @@ """ -class KMostInfluentialResults(NamedTuple): - """ - This namedtuple stores the results of using the `influence` method. This method - is implemented by all subclasses of `TracInCPBase` to calculate - proponents / opponents. The `indices` field stores the indices of the - proponents / opponents for each example in the test dataset. For example, if - finding opponents, `indices[i][j]` stores the index in the training data of the - example with the `j`-th highest influence score on the `i`-th example in the test - dataset. Similarly, the `influence_scores` field stores the actual influence - scores, so that `influence_scores[i][j]` is the influence score of example - `indices[i][j]` in the training data on example `i` of the test dataset. - Please see `TracInCPBase.influence` for more details. - """ - - indices: Tensor - influence_scores: Tensor - - class TracInCPBase(DataInfluence): """ To implement the `influence` method, classes inheriting from `TracInCPBase` will @@ -448,34 +420,6 @@ def get_name(cls: Type["TracInCPBase"]) -> str: return cls.__name__ -def _influence_route_to_helpers( - influence_instance: TracInCPBase, - inputs: Union[Tuple[Any, ...], DataLoader], - k: Optional[int] = None, - proponents: bool = True, - **kwargs, -) -> Union[Tensor, KMostInfluentialResults]: - """ - This is a helper function called by `TracInCP.influence` and - `TracInCPFast.influence`. Those methods share a common logic in that they assume - an instance of their respective classes implement 2 private methods - (``_influence`, `_get_k_most_influential`), and the logic of - which private method to call is common, as described in the documentation of the - `influence` method. 
The arguments and return values of this function are the exact - same as the `influence` method. Note that `influence_instance` refers to the - instance for which the `influence` method was called. - """ - if k is None: - return influence_instance._influence(inputs, **kwargs) - else: - return influence_instance._get_k_most_influential( - inputs, - k, - proponents, - **kwargs, - ) - - class TracInCP(TracInCPBase): def __init__( self, @@ -630,23 +574,7 @@ def __init__( """ self.layer_modules = None if layers is not None: - assert isinstance(layers, List), "`layers` should be a list!" - assert len(layers) > 0, "`layers` cannot be empty!" - assert isinstance( - layers[0], str - ), "`layers` should contain str layer names." - self.layer_modules = [ - _get_module_from_name(self.model, layer) for layer in layers - ] - for layer, layer_module in zip(layers, self.layer_modules): - for name, param in layer_module.named_parameters(): - if not param.requires_grad: - warnings.warn( - "Setting required grads for layer: {}, name: {}".format( - ".".join(layer), name - ) - ) - param.requires_grad = True + self.layer_modules = _set_active_parameters(model, layers) @log_usage() def influence( # type: ignore[override] @@ -1463,19 +1391,6 @@ def _basic_computation_tracincp( argument is only used if `sample_wise_grads_per_batch` was true in initialization. """ - if self.sample_wise_grads_per_batch: - return _compute_jacobian_wrt_params_with_sample_wise_trick( - self.model, - inputs, - targets, - loss_fn, - reduction_type, - self.layer_modules, - ) - return _compute_jacobian_wrt_params( - self.model, - inputs, - targets, - loss_fn, - self.layer_modules, + return _compute_jacobian_sample_wise_grads_per_batch( + self, inputs, targets, loss_fn, reduction_type ) diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index 4acfabcd42..ce3a4a23b7 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -1,19 +1,38 @@ #!/usr/bin/env python3 import warnings -from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union +from functools import reduce +from typing import ( + Any, + Callable, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) import torch import torch.nn as nn -from captum._utils.common import _parse_version +from captum._utils.common import _get_module_from_name, _parse_version +from captum._utils.gradient import ( + _compute_jacobian_wrt_params, + _compute_jacobian_wrt_params_with_sample_wise_trick, +) from captum._utils.progress import progress -if TYPE_CHECKING: - from captum.influence._core.tracincp import TracInCPBase - from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset +if TYPE_CHECKING: + from captum.influence._core.influence_function import ( + InfluenceFunctionBase, + IntermediateQuantitiesInfluenceFunction, + ) + from captum.influence._core.tracincp import TracInCP, TracInCPBase + def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor: r""" @@ -422,7 +441,7 @@ def _self_influence_by_batches_helper( def _check_loss_fn( - influence_instance: "TracInCPBase", + influence_instance: Union["TracInCPBase", "InfluenceFunctionBase"], loss_fn: Optional[Union[Module, Callable]], loss_fn_name: str, sample_wise_grads_per_batch: Optional[bool] = None, @@ -505,3 +524,503 @@ def _check_loss_fn( ) return reduction_type + + +def _set_active_parameters(model: Module, layers: List[str]) -> List[Module]: + """ + sets relevant parameters, as indicated by 
`layers`, to have `requires_grad=True`, + and returns relevant modules. + """ + assert isinstance(layers, List), "`layers` should be a list!" + assert len(layers) > 0, "`layers` cannot be empty!" + assert isinstance(layers[0], str), "`layers` should contain str layer names." + layer_modules = [_get_module_from_name(model, layer) for layer in layers] + for layer, layer_module in zip(layers, layer_modules): + for name, param in layer_module.named_parameters(): + if not param.requires_grad: + warnings.warn( + "Setting required grads for layer: {}, name: {}".format( + ".".join(layer), name + ) + ) + param.requires_grad = True + return layer_modules + + +def _progress_bar_constructor( + influence_inst: "InfluenceFunctionBase", + inputs_dataset: DataLoader, + quantities_name: str, + dataset_name: str = "inputs_dataset", +): + # Try to determine length of progress bar if possible, with a default + # of `None`. + inputs_dataset_len = None + try: + inputs_dataset_len = len(inputs_dataset) + except TypeError: + warnings.warn( + f"Unable to determine the number of batches in " + f"`{dataset_name}`. Therefore, if showing the progress " + f"of the computation of {quantities_name}, " + "only the number of batches processed can be " + "displayed, and not the percentage completion of the computation, " + "nor any time estimates." + ) + + return progress( + inputs_dataset, + desc=( + f"Using {influence_inst.get_name()} to compute {quantities_name}. " + "Processing batch" + ), + total=inputs_dataset_len, + ) + + +def _params_to_names(params: Iterable[nn.Parameter], model: nn.Module) -> List[str]: + """ + Given an iterable of parameters, `params` of a model, `model`, returns the names of + the parameters from the perspective of `model`. This is useful if, given + parameters for which we do not know the name, want to pass them as a dict + to a function of those parameters, i.e. `torch.nn.utils._stateless`. + """ + param_id_to_name = { + id(param): param_name for (param_name, param) in model.named_parameters() + } + return [param_id_to_name[id(param)] for param in params] + + +def _flatten_params(_params: Tuple[Tensor, ...]) -> Tensor: + """ + Given a tuple of tensors, which is how Pytorch represents parameters of a model, + flattens it into a single tensor. This is useful if we want to do matrix operations + on the parameters of a model, i.e. invert its Hessian, or compute dot-product of + parameter-gradients. Note that flattening and then passing to standard linear + algebra operations may not be the most efficient way to perform them. 
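
A small illustration of `_params_to_names` defined above; the toy `nn.Sequential` model is a placeholder, and the helper recovers the registered names of parameters passed by object:

```python
import torch.nn as nn
from captum.influence._utils.common import _params_to_names

model = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# parameters are matched to their registered names by identity, regardless of order
names = _params_to_names([model[2].weight, model[0].bias], model)
assert names == ["2.weight", "0.bias"]
```
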
+ """ + return torch.cat([_param.view(-1) for _param in _params]) + + +def _unflatten_params_factory( + param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]] +): + """ + returns a function which is the inverse of `_flatten_params` + """ + + def _unflatten_params(flattened_params): + params = [] + offset = 0 + for shape in param_shapes: + length = 1 + for s in shape: + length *= s + params.append(flattened_params[offset : offset + length].view(shape)) + offset += length + return tuple(params) + + return _unflatten_params + + +def _influence_batch_intermediate_quantities_influence_function( + influence_inst: "IntermediateQuantitiesInfluenceFunction", + test_batch: Tuple[Any, ...], + train_batch: Tuple[Any, ...], +): + """ + computes influence of a test batch on a train batch, for implementations of + `IntermediateQuantitiesInfluenceFunction` + """ + return torch.matmul( + influence_inst.compute_intermediate_quantities(test_batch), + influence_inst.compute_intermediate_quantities(train_batch).T, + ) + + +def _influence_helper_intermediate_quantities_influence_function( + influence_inst: "IntermediateQuantitiesInfluenceFunction", + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool, +): + """ + Helper function that computes influence scores for implementations of + `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` + method returning "embedding" vectors, so that the influence score of one example + on another is the dot-product of their vectors. + """ + # If `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) + + inputs_intermediate_quantities = influence_inst.compute_intermediate_quantities( + inputs_dataset, + show_progress=show_progress, + test=True, + ) + + train_dataloader = influence_inst.train_dataloader + if show_progress: + train_dataloader = _progress_bar_constructor( + influence_inst, train_dataloader, "train_dataset", "influence scores" + ) + + return torch.cat( + [ + torch.matmul( + inputs_intermediate_quantities, + influence_inst.compute_intermediate_quantities(batch).T, + ) + for batch in train_dataloader + ], + dim=1, + ) + + +def _self_influence_helper_intermediate_quantities_influence_function( + influence_inst: "IntermediateQuantitiesInfluenceFunction", + inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]], + show_progress: bool, +): + """ + Helper function that computes self-influence scores for implementations of + `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` + method returning "embedding" vectors, so that the self-influence score of an + example is the squared norm of its vector. + """ + + inputs_dataset = ( + inputs_dataset + if inputs_dataset is not None + else influence_inst.train_dataloader + ) + + # If `inputs_dataset` is not a `DataLoader`, turn it into one. 
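
`_flatten_params` and `_unflatten_params_factory` above are inverses of each other (this is also what `test_flatten_unflattener` in the new test file checks); a standalone round-trip sketch:

```python
import torch
from captum.influence._utils.common import _flatten_params, _unflatten_params_factory

param_shapes = [(2, 3), (4, 5)]
params = tuple(torch.randn(shape) for shape in param_shapes)

# flattening concatenates the parameters into one 1D tensor of length 2*3 + 4*5 = 26
flat = _flatten_params(params)
assert flat.shape == (26,)

# unflattening with the original shapes recovers the original tuple exactly
unflatten = _unflatten_params_factory(param_shapes)
assert all(torch.equal(p, r) for p, r in zip(params, unflatten(flat)))
```
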
+ inputs_dataset = _format_inputs_dataset(inputs_dataset) + + if show_progress: + inputs_dataset = _progress_bar_constructor( + influence_inst, inputs_dataset, "inputs_dataset", "self influence scores" + ) + + return torch.cat( + [ + torch.sum( + influence_inst.compute_intermediate_quantities( + batch, + show_progress=False, + ) + ** 2, + dim=1, + ) + for batch in inputs_dataset + ] + ) + + +def _eig_helper(H: Tensor): + """ + wrapper around `torch.linalg.eig` that sorts eigenvalues / eigenvectors by + ascending eigenvalues, like `torch.linalg.eigh`, and returns the real component + (since `H` is never complex, there should never be a complex component. however, + `torch.linalg.eig` always returns a complex tensor, which in this case would + actually have no complex component) + """ + version = _parse_version(torch.__version__) + if version < (1, 9): + ls, vs = torch.eig(H, eigenvectors=True) + ls = ls[:, 0] + else: + ls, vs = torch.linalg.eig(H) + ls, vs = ls.real, vs.real + + ls_argsort = torch.argsort(ls) + vs = vs[:, ls_argsort] + ls = ls[ls_argsort] + return ls, vs + + +def _top_eigen( + H: Tensor, k: Optional[int], hessian_reg: float, hessian_inverse_tol: float +) -> Tuple[Tensor, Tensor]: + """ + This is a wrapper around `torch.linalg.eig` that performs some pre / + post-processing to make it suitable for computing the low-rank + "square root" of a matrix, i.e. given square matrix H, find tall and + skinny L such that LL' approximates H. This function returns eigenvectors (as the + columns of a matrix Q) and corresponding eigenvectors (as diagonal entries in + a matrix V), and we can then let L=QV^{1/2}Q'. However, doing so requires the + eigenvalues in V to be positive. Thus, this function does pre-processing (adds + an entry to the diagonal of H) and post-processing (returns only the top-k + eigenvectors / eigenvalues where the eigenvalues are above a positive tolerance) + to encourage and guarantee, respectively, that the returned eigenvalues be + positive. The pre-processing shifts the eigenvalues up by a constant, and the + post-processing effectively replaces H with the most similar matrix (in terms of + Frobenius norm) whose eigenvalues are above the tolerance, see + https://nhigham.com/2021/01/26/what-is-the-nearest-positive-semidefinite-matrix/. + + Args: + H (Tensor): a 2D square Tensor for which the top eigenvectors / eigenvalues + will be computed. + k (int): how many eigenvectors / eigenvalues to return (before dropping pairs + whose eigenvalue is below the tolerance). + hessian_reg (float): We add an entry to the diagonal of `H` to encourage it to + be positive definite. This is that entry. + hessian_inverse_tol (float): To compute the "square root" of `H` using the top + eigenvectors / eigenvalues, the eigenvalues should be positive, and + furthermore if above a tolerance, the inversion will be more + numerically stable. Therefore, we only return eigenvectors / + eigenvalues where the eigenvalue is above a tolerance. This argument + specifies that tolerance. + + Returns: + (eigenvalues, eigenvectors) (tuple of tensors): Mimicking the output of + `torch.linalg.eigh`, `eigenvalues` is a 1D tensor of the top-k + eigenvalues of the regularized `H` that are additionally above + `hessian_inverse_tol`, and `eigenvectors` is a 2D tensor whose columns + contain the corresponding eigenvectors. The eigenvalues are in + ascending order. 
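
A standalone illustration of the pre- and post-processing `_top_eigen` performs (regularization, tolerance filtering, top-k selection, and deterministic eigenvector signs). The random symmetric `H` is a placeholder and is deliberately not positive definite:

```python
import torch
from captum.influence._utils.common import _top_eigen

torch.manual_seed(0)
A = torch.randn(5, 5)
H = (A + A.T) / 2  # symmetric, but generally indefinite

ls, vs = _top_eigen(H, k=3, hessian_reg=1e-6, hessian_inverse_tol=1e-6)

# at most k eigenpairs survive, the eigenvalues are above the tolerance,
# and they are returned in ascending order
assert len(ls) <= 3
assert (ls > 1e-6).all() and (ls[1:] >= ls[:-1]).all()

# columns of `vs` are (approximate) eigenvectors of the regularized Hessian
H_reg = H + 1e-6 * torch.eye(5)
assert torch.allclose(H_reg @ vs, vs * ls, atol=1e-4)
```
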
+ """ + # add regularization to hopefully make H positive definite + H = H + (torch.eye(len(H)).to(device=H.device) * hessian_reg) + + # find eigvectors / eigvals of H + # ls are eigenvalues, in ascending order + # columns of vs are corresponding eigenvectors + ls, vs = _eig_helper(H) + + # despite adding regularization to the hessian, it may still not be positive + # definite. we can get rid of negative eigenvalues, but for numerical stability + # can get rid of eigenvalues below a tolerance + keep = ls > hessian_inverse_tol + + ls = ls[keep] + vs = vs[:, keep] + + # only keep the top `k` eigvals / eigvectors + if not (k is None): + ls = ls[-k:] + vs = vs[:, -k:] + + # `torch.linalg.eig` is not deterministic in that you can multiply an eigenvector + # by -1, and it is still an eigenvector. to make eigenvectors deterministic, + # we multiply an eigenvector according to some rule that flips if you multiply + # the eigenvector by -1. in this case, that rule is whether the sum of the + # entries of the eigenvector are > 0 + rule = torch.sum(vs, dim=0) > 0 # entries are 0/1 + rule_multiplier = (2 * rule) - 1 # entries are -1/1 + vs = vs * rule_multiplier.unsqueeze(0) + + return ls, vs + + +class KMostInfluentialResults(NamedTuple): + """ + This namedtuple stores the results of using the `influence` method. This method + is implemented by all subclasses of `TracInCPBase` to calculate + proponents / opponents. The `indices` field stores the indices of the + proponents / opponents for each example in the test batch. For example, if finding + opponents, `indices[i][j]` stores the index in the training data of the example + with the `j`-th highest influence score on the `i`-th example in the test batch. + Similarly, the `influence_scores` field stores the actual influence scores, so that + `influence_scores[i][j]` is the influence score of example `indices[i][j]` in the + training data on example `i` of the test batch. Please see `TracInCPBase.influence` + for more details. + """ + + indices: Tensor + influence_scores: Tensor + + +def _influence_route_to_helpers( + influence_instance: Union["TracInCPBase", "InfluenceFunctionBase"], + inputs: Union[Tuple[Any, ...], DataLoader], + k: Optional[int] = None, + proponents: bool = True, + **kwargs, +) -> Union[Tensor, KMostInfluentialResults]: + """ + This is a helper function called by `TracInCPBase` and `InfluenceFunctionBase` + implementations. Those methods share a common logic in that they assume + an instance of their respective classes implement 2 private methods + (``_influence`, `_get_k_most_influential`), and the logic of + which private method to call is common, as described in the documentation of the + `influence` method. The arguments and return values of this function are the exact + same as the `influence` method. Note that `influence_instance` refers to the + instance for which the `influence` method was called. 
+ """ + if k is None: + return influence_instance._influence(inputs, **kwargs) + else: + return influence_instance._get_k_most_influential( + inputs, + k, + proponents, + **kwargs, + ) + + +def _compute_jacobian_sample_wise_grads_per_batch( + influence_inst: Union["TracInCP", "InfluenceFunctionBase"], + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = "none", +) -> Tuple[Tensor, ...]: + """ + `TracInCP`, `InfluenceFunction`, and `ArnoldiInfluenceFunction` all compute + jacobians, depending on their `sample_wise_grads_per_batch` attribute. this helper + wraps that logic. + """ + + if influence_inst.sample_wise_grads_per_batch: + return _compute_jacobian_wrt_params_with_sample_wise_trick( + influence_inst.model, + inputs, + targets, + loss_fn, + reduction_type, + influence_inst.layer_modules, + ) + return _compute_jacobian_wrt_params( + influence_inst.model, + inputs, + targets, + loss_fn, + influence_inst.layer_modules, + ) + + +def _compute_batch_loss_influence_function_base( + loss_fn: Optional[Union[Module, Callable]], + input: Any, + target: Any, + reduction_type: str, +): + """ + In implementations of `InfluenceFunctionBase`, we need to compute the total loss + for a batch given `loss_fn`, whose reduction can either be 'none', 'sum', or + 'mean', and whose output requires different scaling based on the reduction. This + helper houses that common logic, and returns the total loss for a batch given the + predictions (`inputs`) and labels (`targets`) for it. We compute the total loss + in order to compute the Hessian. + """ + if loss_fn is not None: + _loss = loss_fn(input, target) + else: + # following convention of `_compute_jacobian_wrt_params`, is no loss function is + # provided, the quantity backpropped is the output of the forward function. + assert reduction_type == "none" + _loss = input + + if reduction_type == "none": + # if loss_fn is a "reduction='none'" loss function, need to sum + # up the per-example losses. + return torch.sum(_loss) + elif reduction_type == "mean": + # in this case, we want the total loss for the batch, and should + # multiply the mean loss for the batch by the batch size. however, + # we can only infer the batch size if `_output` is a Tensor, and + # we assume the 0-th dimension to be the batch dimension. + if isinstance(input, Tensor): + multiplier = input.shape[0] + else: + multiplier = 1 + msg = ( + "`loss_fn` was inferred to behave as a `reduction='mean'` " + "loss function. however, the batch size of batches could not " + "be inferred. therefore, the total loss of a batch, which is " + "needed to compute the Hessian, is approximated as the output " + "of `loss_fn` for the batch. if this approximation is not " + "accurate, please change `loss_fn` to behave as a " + "`reduction='sum'` loss function, or a `reduction='none'` " + "and set `sample_grads_per_batch` to false." 
+ ) + warnings.warn(msg) + return _loss * multiplier + elif reduction_type == "sum": + return _loss + else: + # currently, only support `reduction_type` to be + # 'none', 'sum', or 'mean' for + # `InfluenceFunctionBase` implementations + raise Exception + + +def _set_attr(obj, names, val): + if len(names) == 1: + setattr(obj, names[0], val) + else: + _set_attr(getattr(obj, names[0]), names[1:], val) + + +def _del_attr(obj, names): + if len(names) == 1: + delattr(obj, names[0]) + else: + _del_attr(getattr(obj, names[0]), names[1:]) + + +def _model_make_functional(model, param_names, params): + params = tuple([param.detach().requires_grad_() for param in params]) + + for param_name in param_names: + _del_attr(model, param_name.split(".")) + + return params + + +def _model_reinsert_params(model, param_names, params, register=False): + for (param_name, param) in zip(param_names, params): + _set_attr( + model, + param_name.split("."), + torch.nn.Parameter(param) if register else param, + ) + + +def _custom_functional_call(model, d, features): + param_names, params = zip(*list(d.items())) + _params = _model_make_functional(model, param_names, params) + _model_reinsert_params(model, param_names, params) + out = model(*features) + _model_reinsert_params(model, param_names, _params, register=True) + return out + + +def _functional_call(model, d, features): + """ + Makes a call to `model.forward`, which is treated as a function of the parameters + in `d`, a dict from parameter name to parameter, instead of as a function of + `features`, the argument that is unpacked to `model.forward` (i.e. + `model.forward(*features)`). Depending on what version of PyTorch is available, + we either use our own implementation, or directly use `torch.nn.utils.stateless` + or `torch.func.functional_call`. Put another way, this function mimics the latter + two implementations, using our own when the PyTorch version is too old. + """ + import torch + + version = _parse_version(torch.__version__) + if version < (1, 12, 0): + return _custom_functional_call(model, d, features) + elif version >= (1, 12, 0) and version < (2, 0, 0): + import torch.nn.utils.stateless + + return torch.nn.utils.stateless.functional_call(model, d, features) + else: + import torch.func + + return torch.func.functional_call(model, d, features) + + +def _dataset_fn(dataloader, batch_fn, reduce_fn, *batch_fn_args, **batch_fn_kwargs): + """ + Applies `batch_fn` to each batch in `dataloader`, reducing the results using + `reduce_fn`. This is useful for computing Hessians over an entire dataloader. 
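
A small sketch of the reduce pattern `_dataset_fn` implements, using a toy dataloader; `batch_fn` maps a batch to a per-batch quantity and `reduce_fn` combines the per-batch results:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from captum.influence._utils.common import _dataset_fn

dataloader = DataLoader(TensorDataset(torch.arange(10.0).unsqueeze(1)), batch_size=3)

def batch_fn(batch):
    (features,) = batch
    return features.sum()

# summing the per-batch sums with `torch.add` gives the sum over the whole dataloader;
# the same pattern can be used to sum per-batch Hessians over `hessian_dataset`
total = _dataset_fn(dataloader, batch_fn, torch.add)
assert torch.allclose(total, torch.tensor(45.0))
```
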
+ """ + _dataloader = iter(dataloader) + + def _reduce_fn(_result, _batch): + return reduce_fn(_result, batch_fn(_batch, *batch_fn_args, **batch_fn_kwargs)) + + result = batch_fn(next(_dataloader), *batch_fn_args, **batch_fn_kwargs) + return reduce(_reduce_fn, _dataloader, result) diff --git a/tests/influence/_core/test_naive_influence.py b/tests/influence/_core/test_naive_influence.py new file mode 100644 index 0000000000..9d2b6ad63e --- /dev/null +++ b/tests/influence/_core/test_naive_influence.py @@ -0,0 +1,281 @@ +import tempfile +from typing import Callable, List, Tuple + +import torch + +import torch.nn as nn +from captum._utils.common import _parse_version +from captum.influence._core.influence_function import NaiveInfluenceFunction +from captum.influence._utils.common import ( + _custom_functional_call, + _flatten_params, + _functional_call, + _unflatten_params_factory, +) +from parameterized import parameterized +from tests.helpers.basic import ( + assertTensorAlmostEqual, + assertTensorTuplesAlmostEqual, + BaseTest, +) +from tests.influence._utils.common import ( + _format_batch_into_tuple, + build_test_name_func, + DataInfluenceConstructor, + ExplicitDataset, + get_random_model_and_data, + Linear, + UnpackDataset, + USE_GPU_LIST, +) +from torch.utils.data import DataLoader + + +class TestNaiveInfluence(BaseTest): + @parameterized.expand( + [ + (param_shape,) + for param_shape in [ + [(2, 3), (4, 5)], + [(3, 2), (4, 2), (1, 5)], + ] + ], + name_func=build_test_name_func(), + ) + def test_flatten_unflattener(self, param_shapes: List[Tuple[int, ...]]): + # unflatten and flatten should be inverses of each other. check this holds. + _unflatten_params = _unflatten_params_factory(param_shapes) + params = tuple(torch.randn(shape) for shape in param_shapes) + assertTensorTuplesAlmostEqual( + self, + params, + _unflatten_params(_flatten_params(params)), + delta=1e-4, + mode="max", + ) + + @parameterized.expand( + [ + ( + reduction, + influence_constructor, + delta, + mode, + unpack_inputs, + use_gpu, + ) + for reduction in ["none", "sum", "mean"] + for use_gpu in USE_GPU_LIST + for (influence_constructor, delta) in [ + ( + DataInfluenceConstructor( + NaiveInfluenceFunction, + layers=["module.linear"] + if use_gpu == "cuda_dataparallel" + else ["linear"], + projection_dim=None, + # letting projection_dim is None means no projection is done, + # in which case exact influence is returned + show_progress=False, + ), + 1e-3, + ), + ( + DataInfluenceConstructor( + NaiveInfluenceFunction, + layers=None, + # this tests that not specifyiing layers still works + projection_dim=None, + show_progress=False, + name="NaiveInfluenceFunction_all_layers", + ), + 1e-3, + ), + ] + for mode in [ + "influence", + "self_influence", + ] + for unpack_inputs in [ + False, + True, + ] + ], + name_func=build_test_name_func(), + ) + def test_matches_linear_regression( + self, + reduction: str, + influence_constructor: Callable, + delta: float, + mode: str, + unpack_inputs: bool, + use_gpu: bool, + ): + """ + this tests that `NaiveInfluence`, the simplest implementation, agree with the + analytically calculated solution for influence and self-influence for a model + where we can calculate that solution - linear regression trained with squared + error loss. 
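
The analytical solution this test compares against follows from standard least-squares identities. Below is a compact, standalone sketch of the closed-form influence computation, independent of the captum helpers and of the test fixtures; for simplicity the Hessian here is taken over the same examples as the training gradients, whereas the test uses a separate `hessian_dataset`:

```python
import torch

torch.manual_seed(0)
n_train, n_test, d = 20, 4, 3
X_train, X_test = torch.randn(n_train, d), torch.randn(n_test, d)
w_true = torch.randn(d, 1)
y_train, y_test = X_train @ w_true, X_test @ w_true

# parameters of a linear model (not necessarily the optimum)
w = torch.randn(d, 1)

# for squared error loss (x'w - y)^2:
#   per-example gradient w.r.t. w: 2 * (x'w - y) * x
#   Hessian w.r.t. w:              2 * X'X
train_grads = 2 * (X_train @ w - y_train) * X_train  # (n_train, d)
test_grads = 2 * (X_test @ w - y_test) * X_test      # (n_test, d)
H = 2 * X_train.T @ X_train                          # (d, d)

# influence of training example j on test example i: grad_test_i' H^{-1} grad_train_j
influences = test_grads @ torch.linalg.inv(H) @ train_grads.T  # (n_test, n_train)
assert influences.shape == (n_test, n_train)
```
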
+ """ + with tempfile.TemporaryDirectory() as tmpdir: + ( + net, + train_dataset, + hessian_samples, + hessian_labels, + test_samples, + test_labels, + ) = get_random_model_and_data( + tmpdir, + unpack_inputs, + return_test_data=True, + use_gpu=use_gpu, + return_hessian_data=True, + model_type="trained_linear", + ) + + train_dataset = DataLoader(train_dataset, batch_size=5) + + hessian_dataset = ( + ExplicitDataset(hessian_samples, hessian_labels, use_gpu) + if not unpack_inputs + else UnpackDataset(hessian_samples, hessian_labels, use_gpu) + ) + hessian_dataset = DataLoader(hessian_dataset, batch_size=5) + + criterion = nn.MSELoss(reduction=reduction) + batch_size = None + + # set `sample_grads_per_batch` based on `reduction` to be compatible + sample_wise_grads_per_batch = False if reduction == "none" else True + + influence = influence_constructor( + net, + train_dataset, + tmpdir, + batch_size, + criterion, + sample_wise_grads_per_batch=sample_wise_grads_per_batch, + hessian_dataset=hessian_dataset, + ) + + # since the model is a linear regression model trained with MSE loss, we + # can calculate the hessian and per-example parameter gradients + # analytically + tensor_hessian_samples = ( + hessian_samples + if not unpack_inputs + else torch.cat(hessian_samples, dim=1) + ) + # hessian at optimal parameters is 2 * X'X, where X is the feature matrix + # of the examples used for calculating the hessian. + # this is based on https://math.stackexchange.com/questions/2864585/hessian-on-linear-least-squares-problem # noqa: E501 + # and multiplying by 2, since we optimize squared error, + # not 1/2 squared error. + hessian = torch.matmul(tensor_hessian_samples.T, tensor_hessian_samples) * 2 + hessian = hessian + ( + torch.eye(len(hessian)).to(device=hessian.device) * 1e-4 + ) + version = _parse_version(torch.__version__) + if version < (1, 8): + hessian_inverse = torch.pinverse(hessian, rcond=1e-4) + else: + hessian_inverse = torch.linalg.pinv(hessian, rcond=1e-4) + + # gradient for an example is 2 * features * error + + # compute train gradients + tensor_train_samples = torch.cat( + [torch.cat(batch[:-1], dim=1) for batch in train_dataset], dim=0 + ) + train_predictions = torch.cat( + [net(*batch[:-1]) for batch in train_dataset], dim=0 + ) + train_labels = torch.cat([batch[-1] for batch in train_dataset], dim=0) + train_gradients = ( + (train_predictions - train_labels) * tensor_train_samples * 2 + ) + + # compute test gradients + tensor_test_samples = ( + test_samples if not unpack_inputs else torch.cat(test_samples, dim=1) + ) + test_predictions = ( + net(test_samples) if not unpack_inputs else net(*test_samples) + ) + test_gradients = (test_predictions - test_labels) * tensor_test_samples * 2 + + if mode == "influence": + # compute pairwise influences, analytically + analytical_train_test_influences = torch.matmul( + torch.matmul(test_gradients, hessian_inverse), train_gradients.T + ) + # compute pairwise influences using influence implementation + influence_train_test_influences = influence.influence( + _format_batch_into_tuple(test_samples, test_labels, unpack_inputs) + ) + # check error + assertTensorAlmostEqual( + self, + influence_train_test_influences, + analytical_train_test_influences, + delta=delta, + mode="max", + ) + elif mode == "self_influence": + # compute self influence, analytically + analytical_self_influences = torch.diag( + torch.matmul( + torch.matmul(train_gradients, hessian_inverse), + train_gradients.T, + ) + ) + # compute pairwise influences using influence 
implementation + influence_self_influences = influence.self_influence(train_dataset) + # check error + assertTensorAlmostEqual( + self, + influence_self_influences, + analytical_self_influences, + delta=delta, + mode="max", + ) + else: + raise Exception("unknown test mode") + + @parameterized.expand( + [(_custom_functional_call,), (_functional_call,)], + name_func=build_test_name_func(), + ) + def test_functional_call(self, method): + """ + tests `influence._utils.common._functional_call` for a simple case where the + model and loss are linear regression and squared error. `method` can either be + `_custom_functional_call`, which uses the custom implementation that is used + if pytorch does not provide one, or `_functional_call`, which uses a pytorch + implementation if available. + """ + # get linear model and a batch + batch_size = 25 + num_features = 5 + batch_samples = torch.normal(0, 1, (batch_size, num_features)) + batch_labels = torch.normal(0, 1, (batch_size, 1)) + net = Linear(num_features) + + # get the analytical gradient wrt to model parameters + batch_predictions = net(batch_samples) + analytical_grad = 2 * torch.sum( + (batch_predictions - batch_labels) * batch_samples, dim=0 + ) + + # get gradient as computed using `_functional_call` + param = net.linear.weight.detach().clone().requires_grad_(True) + _batch_predictions = method(net, {"linear.weight": param}, (batch_samples,)) + loss = torch.sum((_batch_predictions - batch_labels) ** 2) + actual_grad = torch.autograd.grad(loss, param)[0][0] + + # they should be the same + assertTensorAlmostEqual( + self, actual_grad, analytical_grad, delta=1e-3, mode="max" + ) diff --git a/tests/influence/_core/test_tracin_intermediate_quantities.py b/tests/influence/_core/test_tracin_intermediate_quantities.py index b65cd0225c..10ad17be02 100644 --- a/tests/influence/_core/test_tracin_intermediate_quantities.py +++ b/tests/influence/_core/test_tracin_intermediate_quantities.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from captum.influence._core.influence_function import NaiveInfluenceFunction from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import ( TracInCPFast, @@ -27,6 +28,7 @@ class TestTracInIntermediateQuantities(BaseTest): for unpack_inputs in [True, False] for (reduction, constructor) in [ ("none", DataInfluenceConstructor(TracInCP)), + ("none", DataInfluenceConstructor(NaiveInfluenceFunction)), ] ], name_func=build_test_name_func(), @@ -83,6 +85,7 @@ def test_tracin_intermediate_quantities_aggregate( for (reduction, constructor) in [ ("sum", DataInfluenceConstructor(TracInCPFastRandProj)), ("none", DataInfluenceConstructor(TracInCP)), + ("none", DataInfluenceConstructor(NaiveInfluenceFunction)), ] ], name_func=build_test_name_func(), @@ -166,6 +169,11 @@ def test_tracin_intermediate_quantities_api( DataInfluenceConstructor(TracInCP), DataInfluenceConstructor(TracInCP), ), + ( + "none", + DataInfluenceConstructor(NaiveInfluenceFunction), + DataInfluenceConstructor(NaiveInfluenceFunction), + ), ] ], name_func=build_test_name_func(), @@ -190,7 +198,9 @@ def test_tracin_intermediate_quantities_consistent( methods for the 2 cases are different, we need to parametrize the test with 2 different tracin constructors. `tracin_constructor` is the constructor for the tracin implementation for case 1. `intermediate_quantities_tracin_constructor` - is the constructor for the tracin implementation for case 2. + is the constructor for the tracin implementation for case 2. 
Note that we also + use this test for implementations of `InfluenceFunctionBase`, where for the + same method, both ways should give the same result by definition. """ with tempfile.TemporaryDirectory() as tmpdir: ( diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py index 27b6ec9f5d..c16b1243f5 100644 --- a/tests/influence/_core/test_tracin_regression.py +++ b/tests/influence/_core/test_tracin_regression.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from captum.influence._core.influence_function import NaiveInfluenceFunction from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import ( TracInCPFast, @@ -340,6 +341,7 @@ def _test_tracin_identity_regression_setup(self, tmpdir: str): ("check_idx", "sum", DataInfluenceConstructor(TracInCPFastRandProj)), ("check_idx", "mean", DataInfluenceConstructor(TracInCPFast)), ("check_idx", "mean", DataInfluenceConstructor(TracInCPFastRandProj)), + ("check_idx", "none", DataInfluenceConstructor(NaiveInfluenceFunction)), ], name_func=build_test_name_func(), ) @@ -435,6 +437,7 @@ def test_tracin_identity_regression( ("mean", "mean", DataInfluenceConstructor(TracInCPFast)), ("sum", "sum", DataInfluenceConstructor(TracInCPFastRandProj)), ("mean", "mean", DataInfluenceConstructor(TracInCPFastRandProj)), + ("none", "none", DataInfluenceConstructor(NaiveInfluenceFunction)), ], name_func=build_test_name_func(), ) diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py index 767aed6b02..2b4c962161 100644 --- a/tests/influence/_core/test_tracin_self_influence.py +++ b/tests/influence/_core/test_tracin_self_influence.py @@ -3,7 +3,8 @@ import torch import torch.nn as nn -from captum.influence._core.tracincp import TracInCP +from captum.influence._core.influence_function import NaiveInfluenceFunction +from captum.influence._core.tracincp import TracInCP, TracInCPBase from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast from parameterized import parameterized from tests.helpers.basic import assertTensorAlmostEqual, BaseTest @@ -18,13 +19,18 @@ class TestTracInSelfInfluence(BaseTest): + param_list = [] + + # add the tests for `TracInCPBase` implementations and `InfluenceFunctionBase` + # implementations separately, because the latter does not support `DataParallel` + + # add tests for `TracInCPBase` implementations use_gpu_list = ( [False, "cuda", "cuda_data_parallel"] if torch.cuda.is_available() and torch.cuda.device_count() != 0 else [False] ) - param_list = [] for unpack_inputs in [True, False]: for use_gpu in use_gpu_list: for (reduction, constructor) in [ @@ -80,6 +86,40 @@ class TestTracInSelfInfluence(BaseTest): ): param_list.append((reduction, constructor, unpack_inputs, use_gpu)) + # add tests for `InfluenceFunctionBase` implementations + use_gpu_list = ( + [False, "cuda"] + if torch.cuda.is_available() and torch.cuda.device_count() != 0 + else [False] + ) + + for unpack_inputs in [True, False]: + for use_gpu in use_gpu_list: + for (reduction, constructor) in [ + ( + "none", + DataInfluenceConstructor( + NaiveInfluenceFunction, name="NaiveInfluenceFunction_all_layers" + ), + ), + ( + "none", + DataInfluenceConstructor( + NaiveInfluenceFunction, + name="NaiveInfluenceFunction_linear1", + layers=["module.linear1"] + if use_gpu == "cuda_data_parallel" + else ["linear1"], + ), + ), + ]: + if not ( + "sample_wise_grads_per_batch" in constructor.kwargs + and 
+                    constructor.kwargs["sample_wise_grads_per_batch"]
+                    and use_gpu
+                ):
+                    param_list.append((reduction, constructor, unpack_inputs, use_gpu))
+
     @parameterized.expand(
         param_list,
         name_func=build_test_name_func(),
@@ -117,9 +157,7 @@ def test_tracin_self_influence(
             k=None,
         )
         # calculate self_tracin_scores
-        self_tracin_scores = tracin.self_influence(
-            outer_loop_by_checkpoints=False,
-        )
+        self_tracin_scores = tracin.self_influence()

         # check that self_tracin scores equals the diagonal of influence scores
         assertTensorAlmostEqual(
@@ -132,17 +170,22 @@ def test_tracin_self_influence(

         # check that setting `outer_loop_by_checkpoints=False` and
         # `outer_loop_by_checkpoints=True` gives the same self influence scores
-        self_tracin_scores_by_checkpoints = tracin.self_influence(
-            DataLoader(train_dataset, batch_size=batch_size),
-            outer_loop_by_checkpoints=True,
-        )
-        assertTensorAlmostEqual(
-            self,
-            self_tracin_scores_by_checkpoints,
-            self_tracin_scores,
-            delta=0.01,
-            mode="max",
-        )
+        # this test is only relevant for implementations of `TracInCPBase`, as
+        # implementations of `InfluenceFunctionBase` do not use checkpoints.
+        if isinstance(tracin, TracInCPBase):
+            self_tracin_scores_by_checkpoints = (
+                tracin.self_influence(  # type: ignore
+                    DataLoader(train_dataset, batch_size=batch_size),
+                    outer_loop_by_checkpoints=True,
+                )
+            )
+            assertTensorAlmostEqual(
+                self,
+                self_tracin_scores_by_checkpoints,
+                self_tracin_scores,
+                delta=0.01,
+                mode="max",
+            )

     @parameterized.expand(
         [
diff --git a/tests/influence/_utils/common.py b/tests/influence/_utils/common.py
index 17fe5b46cb..f647d892de 100644
--- a/tests/influence/_utils/common.py
+++ b/tests/influence/_utils/common.py
@@ -2,12 +2,15 @@
 import os
 import unittest
 from functools import partial
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from inspect import isfunction
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from captum._utils.common import _parse_version
 from captum.influence import DataInfluence
+from captum.influence._core.influence_function import NaiveInfluenceFunction
 from captum.influence._core.tracincp_fast_rand_proj import (
     TracInCPFast,
     TracInCPFastRandProj,
@@ -180,10 +183,73 @@ def forward(self, *inputs):
         return torch.tanh(self.linear2(x))
+
+class Linear(nn.Module):
+    """
+    a wrapper around `nn.Linear`, whose purpose is to serve as an analogue to
+    `UnpackLinear`; both have 'linear' as their only parameter. "infinitesimal"
+    influence (i.e. that calculated by `InfluenceFunctionBase` implementations) for
+    this simple module can be analytically calculated, so its purpose is for testing
+    those implementations.
+    """
+
+    def __init__(self, in_features):
+        super().__init__()
+        self.linear = nn.Linear(in_features, 1, bias=False)
+
+    def forward(self, input):
+        return self.linear(input)
+
+
+class UnpackLinear(nn.Module):
+    """
+    the analogue of `Linear` which unpacks inputs, serving the same purpose.
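+    for example (a hypothetical usage, for illustration only):
+    `UnpackLinear(5, 2)(x1, x2)` computes the same thing as applying
+    `nn.Linear(10, 1, bias=False)` (with the same weights) to
+    `torch.cat([x1, x2], dim=1)`.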
+    """
+
+    def __init__(self, in_features, num_inputs) -> None:
+        super().__init__()
+        self.linear = nn.Linear(in_features * num_inputs, 1, bias=False)
+
+    def forward(self, *inputs):
+        return self.linear(torch.cat(inputs, dim=1))
+
+
 def get_random_model_and_data(
-    tmpdir, unpack_inputs, return_test_data=True, use_gpu=False
+    tmpdir,
+    unpack_inputs,
+    return_test_data=True,
+    use_gpu=False,
+    return_hessian_data=False,
+    model_type="random",
 ):
     """
+    returns a model, training data, and optionally data for computing the hessian
+    (needed for `InfluenceFunctionBase` implementations) as features / labels, and
+    optionally test data as features / labels.
+
+    the data is always generated the same way. however, depending on `model_type`,
+    a different model and checkpoints are returned.
+    - `model_type='random'`: the model is a 2-layer NN, and several checkpoints are
+    generated.
+    - `model_type='trained_linear'`: the model is a linear model, and assumed to be
+    eventually trained to optimality. therefore, we find the optimal parameters, and
+    save a single checkpoint containing them. the training is done using the Hessian
+    data, because the purpose of training the model is to make the Hessian positive
+    definite, and since the Hessian is calculated using the Hessian data, that data
+    should also be used for training. because the model is trained to optimality
+    using the Hessian data, we can guarantee that the Hessian is positive definite,
+    so that different implementations of `InfluenceFunctionBase` can be more easily
+    compared. (if the Hessian is not positive definite, we drop eigenvectors
+    corresponding to negative eigenvalues. since the eigenvectors dropped in
+    `ArnoldiInfluenceFunction` differ from those in `NaiveInfluenceFunction` due to
+    the former's use of Arnoldi iteration, we should only use models / data whose
+    Hessian is positive definite, so that no eigenvectors are dropped). in short,
+    this model / data are suitable for comparing different `InfluenceFunctionBase`
+    implementations.
+    - `model_type='trained_NN'`: the model is a 2-layer NN, and trained (not
+    necessarily to optimality) using the Hessian data. since it is trained, for the
+    same reasons as for `model_type='trained_linear'`, different implementations of
+    `InfluenceFunctionBase` can be more easily compared, due to the lack of numerical
+    issues.
+
     `use_gpu` can either be
     - `False`: returned model is on cpu
     - `'cuda'`: returned model is on gpu
@@ -192,57 +258,54 @@ def get_random_model_and_data(
     is that sometimes we may want to test a model that is on cpu, but is *not*
     wrapped in `DataParallel`.
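+
+    the returned tuple is ordered as: the model, the training dataset, then (only if
+    `return_hessian_data=True`) the Hessian features and labels, then (only if
+    `return_test_data=True`) the test features and labels.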
""" - assert use_gpu in [False, "cuda", "cuda_data_parallel"] - - in_features, hidden_nodes, out_features = 5, 4, 3 + in_features, hidden_nodes = 5, 4 num_inputs = 2 - net = ( - BasicLinearNet(in_features, hidden_nodes, out_features) - if not unpack_inputs - else MultLinearNet(in_features, hidden_nodes, out_features, num_inputs) - ).double() - - num_checkpoints = 5 - - for i in range(num_checkpoints): - net.linear1.weight.data = torch.normal( - 3, 4, (hidden_nodes, in_features) - ).double() - net.linear2.weight.data = torch.normal( - 5, 6, (out_features, hidden_nodes) - ).double() - if unpack_inputs: - net.pre.weight.data = torch.normal( - 3, 4, (in_features, in_features * num_inputs) - ) - if hasattr(net, "pre"): - net.pre.weight.data = net.pre.weight.data.double() - checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"]) - net_adjusted = ( - _wrap_model_in_dataparallel(net) - if use_gpu == "cuda_data_parallel" - else (net.to(device="cuda") if use_gpu == "cuda" else net) - ) - torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name)) + # generate data. regardless the model, the data is always generated the same way + # the only exception is if the `model_type` is 'trained_linear', i.e. a simple + # linear regression model. this is a simple model, and for simplicity, the + # number of `out_features` is 1 in this case. + if model_type == "trained_linear": + out_features = 1 + else: + out_features = 3 num_samples = 50 num_train = 32 + num_hessian = 22 # this needs to be high to prevent numerical issues all_labels = torch.normal(1, 2, (num_samples, out_features)).double() + all_labels = all_labels.cuda() if use_gpu else all_labels train_labels = all_labels[:num_train] test_labels = all_labels[num_train:] + hessian_labels = all_labels[:num_hessian] if unpack_inputs: all_samples = [ torch.normal(0, 1, (num_samples, in_features)).double() for _ in range(num_inputs) ] + all_samples = ( + _move_sample_to_cuda(all_samples) + if isinstance(all_samples, list) and use_gpu + else all_samples.cuda() + if use_gpu + else all_samples + ) train_samples = [ts[:num_train] for ts in all_samples] test_samples = [ts[num_train:] for ts in all_samples] + hessian_samples = [ts[:num_hessian] for ts in all_samples] else: all_samples = torch.normal(0, 1, (num_samples, in_features)).double() + all_samples = ( + _move_sample_to_cuda(all_samples) + if isinstance(all_samples, list) and use_gpu + else all_samples.cuda() + if use_gpu + else all_samples + ) train_samples = all_samples[:num_train] test_samples = all_samples[num_train:] + hessian_samples = all_samples[:num_hessian] dataset = ( ExplicitDataset(train_samples, train_labels, use_gpu) @@ -250,26 +313,129 @@ def get_random_model_and_data( else UnpackDataset(train_samples, train_labels, use_gpu) ) - if return_test_data: - return ( - _wrap_model_in_dataparallel(net) - if use_gpu == "cuda_data_parallel" - else (net.to(device="cuda") if use_gpu == "cuda" else net), - dataset, - _move_sample_to_cuda(test_samples) - if isinstance(test_samples, list) and use_gpu - else test_samples.cuda() - if use_gpu - else test_samples, - test_labels.cuda() if use_gpu else test_labels, + if model_type == "random": + net = ( + BasicLinearNet(in_features, hidden_nodes, out_features) + if not unpack_inputs + else MultLinearNet(in_features, hidden_nodes, out_features, num_inputs) + ).double() + + # generate checkpoints randomly + num_checkpoints = 5 + + for i in range(num_checkpoints): + net.linear1.weight.data = torch.normal( + 3, 4, (hidden_nodes, in_features) 
+            ).double()
+            net.linear2.weight.data = torch.normal(
+                5, 6, (out_features, hidden_nodes)
+            ).double()
+            if unpack_inputs:
+                net.pre.weight.data = torch.normal(
+                    3, 4, (in_features, in_features * num_inputs)
+                ).double()
+            checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"])
+            net_adjusted = (
+                _wrap_model_in_dataparallel(net)
+                if use_gpu == "cuda_data_parallel"
+                else (net.to(device="cuda") if use_gpu == "cuda" else net)
+            )
+            torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))
+
+    elif model_type == "trained_linear":
+        net = (
+            Linear(in_features)
+            if not unpack_inputs
+            else UnpackLinear(in_features, num_inputs)
+        ).double()
+
+        # regardless of `unpack_inputs`, the model is a linear regression, so that
+        # we can get the optimal trained parameters via least squares
+
+        # turn input into a single tensor for use by least squares
+        tensor_hessian_samples = (
+            hessian_samples if not unpack_inputs else torch.cat(hessian_samples, dim=1)
         )
-    else:
-        return (
+        version = _parse_version(torch.__version__)
+        if version < (1, 9):
+            theta = torch.lstsq(tensor_hessian_samples, hessian_labels).solution[0:1]
+        else:
+            # run least squares to get optimal trained parameters
+            theta = torch.linalg.lstsq(
+                hessian_labels,
+                tensor_hessian_samples,
+            ).solution
+            # the first `n` rows of `theta` contain the least squares solution, where
+            # `n` is the number of features in `tensor_hessian_samples`
+            theta = theta[: tensor_hessian_samples.shape[1]]
+
+        # save that trained parameter as a checkpoint
+        checkpoint_name = "checkpoint-final.pt"
+        net.linear.weight.data = theta.contiguous()
+        net_adjusted = (
             _wrap_model_in_dataparallel(net)
             if use_gpu == "cuda_data_parallel"
-            else (net.to(device="cuda") if use_gpu == "cuda" else net),
-            dataset,
+            else (net.to(device="cuda") if use_gpu == "cuda" else net)
         )
+        torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))
+
+    training_data = (
+        net_adjusted,
+        dataset,
+    )
+
+    hessian_data = (
+        _move_sample_to_cuda(hessian_samples)
+        if isinstance(hessian_samples, list) and use_gpu
+        else hessian_samples.cuda()
+        if use_gpu
+        else hessian_samples,
+        hessian_labels.cuda() if use_gpu else hessian_labels,
+    )
+
+    test_data = (
+        _move_sample_to_cuda(test_samples)
+        if isinstance(test_samples, list) and use_gpu
+        else test_samples.cuda()
+        if use_gpu
+        else test_samples,
+        test_labels.cuda() if use_gpu else test_labels,
+    )
+    if return_test_data:
+        if not return_hessian_data:
+            return (*training_data, *test_data)
+        else:
+            return (*training_data, *hessian_data, *test_data)
+    else:
+        if not return_hessian_data:
+            return training_data
+        else:
+            return (*training_data, *hessian_data)
+
+
+def generate_symmetric_matrix_given_eigenvalues(eigenvalues: List[float]):
+    """
+    following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L123 # noqa: E501
+    generate symmetric random matrix with specified eigenvalues
+    """
+    # generate random matrix, then apply Gram-Schmidt to get random orthonormal basis
+    D = len(eigenvalues)
+    Q, _ = torch.linalg.qr(torch.randn((D, D)))
+    return torch.matmul(Q, torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q.T))
+
+
+def generate_assymetric_matrix_given_eigenvalues(eigenvalues: List[float]):
+    """
+    following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501
+    generate asymmetric random matrix with specified eigenvalues
+    """
+    # the matrix M, given eigenvectors Q and eigenvalues L, should satisfy MQ = QL,
+    # or equivalently, Q'M' = LQ'.
+    D = len(eigenvalues)
+    Q_T = torch.randn((D, D))
+    return torch.linalg.solve(
+        Q_T, torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q_T)
+    ).T


 class DataInfluenceConstructor:
@@ -295,11 +461,14 @@ def __init__(
     def __repr__(self) -> str:
         return self.name

+    def __name__(self) -> str:
+        return self.name
+
     def __call__(
         self,
         net: Module,
         dataset: Union[Dataset, DataLoader],
-        tmpdir: Union[str, List[str], Iterator],
+        tmpdir: str,
         batch_size: Union[int, None],
         loss_fn: Optional[Union[Module, Callable]],
         **kwargs,
@@ -324,6 +493,25 @@ def __call__(
                 batch_size=batch_size,
                 **constructor_kwargs,
             )
+        elif self.data_influence_class in [
+            NaiveInfluenceFunction,
+        ]:
+            # for these implementations, only a single checkpoint is needed, not
+            # a directory containing several checkpoints. therefore, given
+            # directory `tmpdir`, we do not pass it directly to the constructor,
+            # but instead find 1 checkpoint in it, and pass that to the
+            # constructor
+            checkpoint_name = sorted(os.listdir(tmpdir))[-1]
+            checkpoint = os.path.join(tmpdir, checkpoint_name)
+
+            return self.data_influence_class(
+                net,
+                dataset,
+                checkpoint,
+                loss_fn=loss_fn,
+                batch_size=batch_size,
+                **constructor_kwargs,
+            )
         else:
             return self.data_influence_class(
                 net,
@@ -371,6 +559,8 @@ def generate_test_name(
         if isinstance(arg, bool):
             if arg:
                 param_strs.append(func_param_names[i])
+        elif isfunction(arg):
+            param_strs.append(arg.__name__)
         else:
             args_str = str(arg)
             if args_str.isnumeric():
@@ -397,3 +587,10 @@ def _format_batch_into_tuple(
         return (*inputs, targets)
     else:
         return (inputs, targets)
+
+
+USE_GPU_LIST = (
+    [False, "cuda"]
+    if torch.cuda.is_available() and torch.cuda.device_count() != 0
+    else [False]
+)
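For reference, the spectral properties that the two matrix generators above are meant to guarantee can be checked with a short standalone sketch. It assumes only `torch`, re-derives the two constructions inline rather than importing the test helpers, and uses arbitrary example eigenvalues and tolerances:

```python
import torch

torch.manual_seed(0)
eigenvalues = torch.tensor([1.0, 4.0, 10.0])
D = len(eigenvalues)

# symmetric construction: M = Q diag(L) Q', with Q a random orthonormal basis from QR
Q, _ = torch.linalg.qr(torch.randn((D, D)))
M_sym = Q @ torch.diag(eigenvalues) @ Q.T
# for a symmetric matrix, `eigvalsh` returns eigenvalues in ascending order,
# so they should match the requested (already sorted) eigenvalues
assert torch.allclose(torch.linalg.eigvalsh(M_sym), eigenvalues, atol=1e-4)

# asymmetric construction: pick a random (generally non-orthogonal) Q', then solve
# Q' M' = diag(L) Q' for M', so that M Q = Q diag(L) with Q = Q'.T
Q_T = torch.randn((D, D))
M_asym = torch.linalg.solve(Q_T, torch.diag(eigenvalues) @ Q_T).T
# the columns of Q_T.T are eigenvectors of M_asym with the requested eigenvalues
assert torch.allclose(M_asym @ Q_T.T, Q_T.T @ torch.diag(eigenvalues), atol=1e-4)
```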