add NaiveInfluenceFunction (#1186) #1214

Closed. Wants to merge 1 commit.
2 changes: 2 additions & 0 deletions captum/influence/__init__.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

from captum.influence._core.influence import DataInfluence # noqa
from captum.influence._core.influence_function import NaiveInfluenceFunction # noqa
from captum.influence._core.similarity_influence import SimilarityInfluence # noqa
from captum.influence._core.tracincp import TracInCP, TracInCPBase # noqa
from captum.influence._core.tracincp_fast_rand_proj import (
@@ -15,4 +16,5 @@
"TracInCP",
"TracInCPFast",
"TracInCPFastRandProj",
"NaiveInfluenceFunction",
]
14 changes: 13 additions & 1 deletion captum/influence/_core/influence.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Type

from torch.nn import Module
from torch.utils.data import Dataset
@@ -42,3 +42,15 @@ def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
though this may change in the future.
"""
pass

@classmethod
def get_name(cls: Type["DataInfluence"]) -> str:
r"""
Create a readable class name. Because the names of `DataInfluence`
subclasses are already readable, this simply returns the class name. For
example, for a class called TracInCP, we return the string TracInCP.

Returns:
name (str): a readable class name
"""
return cls.__name__
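
A usage sketch: since `get_name` is a classmethod now defined on `DataInfluence`, every implementation reports its own class name and can be queried without constructing an instance.

from captum.influence import NaiveInfluenceFunction, TracInCP

# `get_name` is inherited from `DataInfluence`; calling it on the class
# itself works because it is a classmethod.
print(TracInCP.get_name())                # "TracInCP"
print(NaiveInfluenceFunction.get_name())  # "NaiveInfluenceFunction"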
1,314 changes: 1,314 additions & 0 deletions captum/influence/_core/influence_function.py

Large diffs are not rendered by default.
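
For orientation, since the new module is not shown: a "naive" influence function follows the classic recipe of Koh & Liang (2017), scoring a training example's effect on a test example as grad_test^T H^{-1} grad_train, where H is the Hessian of the training loss over the model parameters, materialized and solved explicitly rather than approximated. The sketch below is a hypothetical illustration of that formula, not the PR's implementation; every name in it is invented.

import torch

def naive_influence_scores(test_grads, train_grads, hessian, damping=1e-3):
    # test_grads:  (n_test, p)  per-example gradients of the test loss
    # train_grads: (n_train, p) per-example gradients of the training loss
    # hessian:     (p, p)       explicit Hessian of the total training loss
    # "Naive" here means factoring the full p x p Hessian, which is only
    # tractable for small parameter counts; damping keeps it invertible.
    damped = hessian + damping * torch.eye(hessian.shape[0])
    # One batched solve of H x = g_train for all training gradients.
    solved = torch.linalg.solve(damped, train_grads.T)  # shape (p, n_train)
    # Influence of training example j on test example i, up to the sign
    # convention, which varies across references.
    return test_grads @ solved  # shape (n_test, n_train)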

103 changes: 9 additions & 94 deletions captum/influence/_core/tracincp.py
@@ -4,34 +4,24 @@
import warnings
from abc import abstractmethod
from os.path import join
from typing import (
Any,
Callable,
Iterator,
List,
NamedTuple,
Optional,
Tuple,
Type,
Union,
)
from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union

import torch
from captum._utils.av import AV
from captum._utils.common import _get_module_from_name, _parse_version
from captum._utils.gradient import (
_compute_jacobian_wrt_params,
_compute_jacobian_wrt_params_with_sample_wise_trick,
)
from captum._utils.common import _parse_version
from captum._utils.progress import NullProgress, progress
from captum.influence._core.influence import DataInfluence
from captum.influence._utils.common import (
_check_loss_fn,
_compute_jacobian_sample_wise_grads_per_batch,
_format_inputs_dataset,
_get_k_most_influential_helper,
_gradient_dot_product,
_influence_route_to_helpers,
_load_flexible_state_dict,
_self_influence_by_batches_helper,
_set_active_parameters,
KMostInfluentialResults,
)
from captum.log import log_usage
from torch import Tensor
@@ -69,24 +59,6 @@
"""


class KMostInfluentialResults(NamedTuple):
"""
This namedtuple stores the results of using the `influence` method. This method
is implemented by all subclasses of `TracInCPBase` to calculate
proponents / opponents. The `indices` field stores the indices of the
proponents / opponents for each example in the test dataset. For example, if
finding opponents, `indices[i][j]` stores the index in the training data of the
example with the `j`-th highest influence score on the `i`-th example in the test
dataset. Similarly, the `influence_scores` field stores the actual influence
scores, so that `influence_scores[i][j]` is the influence score of example
`indices[i][j]` in the training data on example `i` of the test dataset.
Please see `TracInCPBase.influence` for more details.
"""

indices: Tensor
influence_scores: Tensor


class TracInCPBase(DataInfluence):
"""
To implement the `influence` method, classes inheriting from `TracInCPBase` will
@@ -448,34 +420,6 @@ def get_name(cls: Type["TracInCPBase"]) -> str:
return cls.__name__


def _influence_route_to_helpers(
influence_instance: TracInCPBase,
inputs: Union[Tuple[Any, ...], DataLoader],
k: Optional[int] = None,
proponents: bool = True,
**kwargs,
) -> Union[Tensor, KMostInfluentialResults]:
"""
This is a helper function called by `TracInCP.influence` and
`TracInCPFast.influence`. Those methods share common logic: each assumes
that an instance of its class implements two private methods
(`_influence`, `_get_k_most_influential`), and the logic for choosing
which private method to call is the same, as described in the documentation
of the `influence` method. The arguments and return values of this function
are exactly the same as those of the `influence` method. Note that
`influence_instance` refers to the instance on which the `influence` method
was called.
"""
if k is None:
return influence_instance._influence(inputs, **kwargs)
else:
return influence_instance._get_k_most_influential(
inputs,
k,
proponents,
**kwargs,
)


class TracInCP(TracInCPBase):
def __init__(
self,
@@ -630,23 +574,7 @@ def __init__(
"""
self.layer_modules = None
if layers is not None:
assert isinstance(layers, List), "`layers` should be a list!"
assert len(layers) > 0, "`layers` cannot be empty!"
assert isinstance(
layers[0], str
), "`layers` should contain str layer names."
self.layer_modules = [
_get_module_from_name(self.model, layer) for layer in layers
]
for layer, layer_module in zip(layers, self.layer_modules):
for name, param in layer_module.named_parameters():
if not param.requires_grad:
warnings.warn(
"Setting required grads for layer: {}, name: {}".format(
".".join(layer), name
)
)
param.requires_grad = True
self.layer_modules = _set_active_parameters(model, layers)
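
The validation and requires_grad logic deleted above is consolidated into `_set_active_parameters` in `captum.influence._utils.common`. Judging only from the deleted lines, its body is presumably close to the following reconstruction, which is an assumption rather than the PR's actual code:

import warnings
from typing import List

from captum._utils.common import _get_module_from_name
from torch.nn import Module

def _set_active_parameters(model: Module, layers: List[str]) -> List[Module]:
    # Presumed reconstruction of the consolidated helper, pieced together
    # from the inline logic deleted above; the real body may differ.
    assert isinstance(layers, list), "`layers` should be a list!"
    assert len(layers) > 0, "`layers` cannot be empty!"
    assert isinstance(layers[0], str), "`layers` should contain str layer names."
    layer_modules = [_get_module_from_name(model, layer) for layer in layers]
    for layer, layer_module in zip(layers, layer_modules):
        for name, param in layer_module.named_parameters():
            if not param.requires_grad:
                warnings.warn(
                    f"Setting requires_grad for layer: {layer}, name: {name}"
                )
                param.requires_grad = True
    return layer_modules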

@log_usage()
def influence( # type: ignore[override]
@@ -1463,19 +1391,6 @@ def _basic_computation_tracincp(
argument is only used if `sample_wise_grads_per_batch` was true in
initialization.
"""
if self.sample_wise_grads_per_batch:
return _compute_jacobian_wrt_params_with_sample_wise_trick(
self.model,
inputs,
targets,
loss_fn,
reduction_type,
self.layer_modules,
)
return _compute_jacobian_wrt_params(
self.model,
inputs,
targets,
loss_fn,
self.layer_modules,
return _compute_jacobian_sample_wise_grads_per_batch(
self, inputs, targets, loss_fn, reduction_type
)
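
Based on the deleted branch, the new `_compute_jacobian_sample_wise_grads_per_batch` helper presumably just centralizes this dispatch between the two Jacobian routines. A reconstruction under that assumption (not the PR's actual code):

from captum._utils.gradient import (
    _compute_jacobian_wrt_params,
    _compute_jacobian_wrt_params_with_sample_wise_trick,
)

def _compute_jacobian_sample_wise_grads_per_batch(
    influence_inst, inputs, targets, loss_fn, reduction_type
):
    # Presumed reconstruction: dispatch on the instance flag exactly as the
    # deleted inline branch did.
    if influence_inst.sample_wise_grads_per_batch:
        return _compute_jacobian_wrt_params_with_sample_wise_trick(
            influence_inst.model,
            inputs,
            targets,
            loss_fn,
            reduction_type,
            influence_inst.layer_modules,
        )
    return _compute_jacobian_wrt_params(
        influence_inst.model,
        inputs,
        targets,
        loss_fn,
        influence_inst.layer_modules,
    )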