11 changes: 7 additions & 4 deletions captum/_utils/gradient.py
@@ -849,18 +849,21 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
if labels is not None and loss_fn is not None:
loss = loss_fn(out, labels)
# TODO: allow loss_fn to be Callable
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
if (isinstance(loss_fn, Module) or callable(loss_fn)) and hasattr(
loss_fn, "reduction"
):
reduction = loss_fn.reduction # type: ignore
msg0 = (
"Please ensure that loss_fn.reduction is set to `sum` or `mean`"
)

assert loss_fn.reduction != "none", msg0
assert reduction != "none", msg0
msg1 = (
f"loss_fn.reduction ({loss_fn.reduction}) does not match"
f"loss_fn.reduction ({reduction}) does not match"
f"reduction type ({reduction_type}). Please ensure they are"
" matching."
)
assert loss_fn.reduction == reduction_type, msg1
assert reduction == reduction_type, msg1
msg2 = (
"Please ensure custom loss function is applying either a "
"sum or mean reduction."
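The new guard means a plain callable, not just an `nn.Module` loss, can expose a `reduction` attribute and be validated the same way. A minimal sketch of such a callable (the function and attribute assignment here are illustrative, not part of the diff):

import torch

# Hypothetical custom loss: per-example squared errors, reduced by summation.
def my_loss(out: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return ((out - labels) ** 2).sum()

# Plain functions accept attributes, so a callable can advertise its
# reduction; the amended check above reads it via `loss_fn.reduction`.
my_loss.reduction = "sum"

Under the old `isinstance(loss_fn, Module)` guard, such a callable skipped the reduction checks entirely; with `callable(loss_fn)` added, it receives the same `sum`/`mean` validation.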
275 changes: 228 additions & 47 deletions captum/influence/_core/tracincp.py

Large diffs are not rendered by default.

113 changes: 83 additions & 30 deletions captum/influence/_core/tracincp_fast_rand_proj.py
@@ -16,6 +16,7 @@
TracInCPBase,
)
from captum.influence._utils.common import (
_check_loss_fn,
_DatasetFromList,
_format_inputs_dataset,
_get_k_most_influential_helper,
@@ -88,6 +89,7 @@ def __init__(
checkpoints_load_func: Callable = _load_flexible_state_dict,
loss_fn: Optional[Union[Module, Callable]] = None,
batch_size: Union[int, None] = 1,
test_loss_fn: Optional[Union[Module, Callable]] = None,
vectorize: bool = False,
) -> None:
r"""
@@ -153,6 +155,20 @@ def __init__(
`train_dataset` is a Dataset. If `train_dataset`
is a DataLoader, then `batch_size` is ignored as an argument.
Default: 1
test_loss_fn (Callable, optional): In some cases, one may want to use a
separate loss function for training examples, i.e. those in
`train_dataset`, and for test examples, i.e. those
represented by the `inputs` and `targets` arguments to the
`influence` method. For example, if one wants to calculate the
influence score of a training example on a test example's
prediction for a fixed class, `test_loss_fn` could map from the
logits for all classes to the logits for a fixed class.
`test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
Thus, the same checks that we apply to `loss_fn` are also applied
to `test_loss_fn`, if the latter is provided. If not provided, the
loss function for test examples is assumed to be the same as the
loss function for training examples, i.e. `loss_fn`.
Default: None
vectorize (bool, optional): Flag to use experimental vectorize functionality
for `torch.autograd.functional.jacobian`.
Default: False
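To make the `test_loss_fn` use case above concrete, here is a hedged sketch of a loss that scores only a fixed class's logit, as the docstring suggests; the class index, function name, and the commented-out construction are illustrative assumptions:

import torch

FIXED_CLASS = 3  # hypothetical index of the fixed class of interest

# Maps the logits for all classes to the (summed) logit for the fixed
# class, ignoring `targets`; summing keeps the reduction well-defined.
def fixed_class_loss(out: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return out[:, FIXED_CLASS].sum()

fixed_class_loss.reduction = "sum"  # so the usual reduction checks pass

# Illustrative construction (other arguments elided):
# tracin = TracInCPFast(
#     model, final_fc_layer, train_dataset, checkpoints,
#     loss_fn=nn.CrossEntropyLoss(reduction="sum"),
#     test_loss_fn=fixed_class_loss,
# )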
@@ -165,6 +181,7 @@ def __init__(
checkpoints_load_func,
loss_fn,
batch_size,
test_loss_fn,
)

self.vectorize = vectorize
@@ -179,29 +196,14 @@ def __init__(

assert loss_fn is not None, "loss function must not be none"

# If we are able to access the reduction used by `loss_fn`, we check whether
# the reduction is either 'sum' or 'mean', as required
if isinstance(loss_fn, Module) and hasattr(
loss_fn, "reduction"
): # TODO: allow loss_fn to be Callable
assert loss_fn.reduction in [
"sum",
"mean",
], 'reduction for `loss_fn` must be "sum" or "mean"'
self.reduction_type = str(loss_fn.reduction)
else:
# if we are unable to access the reduction used by `loss_fn`, we warn
# the user about the assumptions we are making regarding the reduction
# used by `loss_fn`
warnings.warn(
'Since `loss_fn` has no "reduction" attribute, the implementation '
'assumes that `loss_fn` is a "reduction" loss function that '
"reduces the per-example losses by taking their *sum*. If "
"`loss_fn` instead reduces the per-example losses by taking their "
'mean, please set the reduction attribute of `loss_fn` to "mean", '
'i.e. `loss_fn.reduction = "mean"`.'
)
self.reduction_type = "sum"
# check `loss_fn`
self.reduction_type = _check_loss_fn(self, loss_fn, "loss_fn")
# check `test_loss_fn` if it was provided
self.test_reduction_type = (
self.reduction_type
if test_loss_fn is None
else _check_loss_fn(self, test_loss_fn, "test_loss_fn")
)
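The inline reduction check that previously lived here (shown removed above) is now factored into the shared helper `_check_loss_fn`, whose return value is the reduction type. Its actual definition lives in captum/influence/_utils/common.py; the following is a plausible sketch reconstructed from the inline logic it replaces, not the helper's definitive implementation:

import warnings
from torch.nn import Module

def _check_loss_fn_sketch(influence_instance, loss_fn, loss_fn_name: str) -> str:
    # If the reduction is inspectable, require "sum" or "mean" and return it.
    if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
        assert loss_fn.reduction in [
            "sum",
            "mean",
        ], f'reduction for `{loss_fn_name}` must be "sum" or "mean"'
        return str(loss_fn.reduction)
    # Otherwise, warn about the assumption being made and default to "sum".
    warnings.warn(
        f'Since `{loss_fn_name}` has no "reduction" attribute, the '
        f"implementation assumes that `{loss_fn_name}` reduces the "
        "per-example losses by taking their *sum*. If it instead takes "
        f'their mean, please set `{loss_fn_name}.reduction = "mean"`.'
    )
    return "sum"

Returning the reduction type lets each caller store it separately, which is what allows `loss_fn` and `test_loss_fn` to carry different reductions.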

@log_usage()
def influence( # type: ignore[override]
@@ -340,10 +342,16 @@ def get_checkpoint_contribution(checkpoint):
self,
inputs,
targets,
self.test_loss_fn,
self.test_reduction_type,
)

src_jacobian, src_layer_input = _basic_computation_tracincp_fast(
self, batch[0:-1], batch[-1]
self,
batch[0:-1],
batch[-1],
self.loss_fn,
self.reduction_type,
)
return (
_tensor_batch_dot(
@@ -603,7 +611,11 @@ def get_checkpoint_contribution(checkpoint):
for batch in _inputs_dataset:

batch_jacobian, batch_layer_input = _basic_computation_tracincp_fast(
self, batch[0:-1], batch[-1]
self,
batch[0:-1],
batch[-1],
self.loss_fn,
self.reduction_type,
)

checkpoint_contribution.append(
@@ -640,6 +652,7 @@ def get_checkpoint_contribution(checkpoint):
checkpoints_progress.update()
return batches_self_tracin_scores

@log_usage()
def self_influence(
self,
inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,
@@ -722,11 +735,18 @@ def _basic_computation_tracincp_fast(
influence_instance: TracInCPFast,
inputs: Tuple[Any, ...],
targets: Tensor,
loss_fn: Optional[Union[Module, Callable]] = None,
reduction_type: Optional[str] = None,
):
"""
For instances of TracInCPFast and children classes, computation of influence scores
or self influence scores repeatedly calls this function for different checkpoints
and batches.
and batches. These computations involve a loss function, which is passed in
explicitly via the `loss_fn` argument: callers pass `self.test_loss_fn` when
handling test examples, and `self.loss_fn` when handling training examples.
These two attributes were set in initialization, with `self.loss_fn` equal to
the `loss_fn` initialization argument, and `self.test_loss_fn` equal to the
`test_loss_fn` initialization argument if it was provided, and to `loss_fn`
otherwise.

Args:

@@ -742,6 +762,11 @@ def _basic_computation_tracincp_fast(
that `model(*inputs)` produces the predictions for the batch.
targets (Tensor): If computing influence scores on a loss function,
these are the labels corresponding to the batch `inputs`.
loss_fn (Callable, optional): The loss function to use when computing the
jacobian.
reduction_type (str, optional): The reduction type of `loss_fn`. This argument
is only used if `sample_wise_grads_per_batch` was true in
initialization of `influence_instance`.

Returns:
(input_jacobians, layer_inputs) (tuple): `input_jacobians` is a 2D tensor,
@@ -773,17 +798,17 @@ def _capture_inputs(layer, input, output) -> None:
)
out = influence_instance.model(*inputs)

assert influence_instance.loss_fn is not None, "loss function is required"
assert influence_instance.reduction_type in [
assert loss_fn is not None, "loss function is required"
assert reduction_type in [
"sum",
"mean",
], 'reduction_type must be either "mean" or "sum"'
input_jacobians = _jacobian_loss_wrt_inputs(
influence_instance.loss_fn,
loss_fn,
out,
targets,
influence_instance.vectorize,
influence_instance.reduction_type,
reduction_type,
)
handle.remove()
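Passing `reduction_type` alongside `loss_fn` matters because the gradient of a reduced loss with respect to each example's output differs by a constant factor between the two reductions. A minimal sketch of the underlying arithmetic, assuming each example contributes one term to the reduction (this is not captum's `_jacobian_loss_wrt_inputs`, just the idea it relies on):

import torch

def per_example_grads_sketch(loss_fn, out, targets, reduction_type):
    # Gradient of the *reduced* (scalar) loss w.r.t. the batch of outputs.
    out = out.detach().requires_grad_(True)
    grads = torch.autograd.grad(loss_fn(out, targets), out)[0]
    if reduction_type == "mean":
        # A mean divides each per-example term by the batch size;
        # multiply it back out to recover per-example gradients.
        grads = grads * out.shape[0]
    return grads

With a "sum" reduction the rows of `grads` are already per-example gradients, which is why both the loss function and its reduction type must be validated together.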

@@ -863,6 +888,7 @@ def __init__(
checkpoints_load_func: Callable = _load_flexible_state_dict,
loss_fn: Optional[Union[Module, Callable]] = None,
batch_size: Union[int, None] = 1,
test_loss_fn: Optional[Union[Module, Callable]] = None,
vectorize: bool = False,
nearest_neighbors: Optional[NearestNeighbors] = None,
projection_dim: int = None,
@@ -927,6 +953,19 @@ def __init__(
`train_dataset` is a Dataset. If `train_dataset`
is a DataLoader, then `batch_size` is ignored as an argument.
Default: 1
test_loss_fn (Callable, optional): In some cases, one may want to use a
separate loss function for training examples, i.e. those in
`train_dataset`, and for test examples, i.e. those
represented by the `inputs` and `targets` arguments to the
`influence` method. For example, if one wants to calculate the
influence score of a training example on a test example's
prediction for a fixed class, `test_loss_fn` could map from the
logits for all classes to the logits for a fixed class.
`test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
Thus, the same checks that we apply to `loss_fn` are also applied
to `test_loss_fn`, if the latter is provided. If not provided, the
loss function for test examples is assumed to be the same as the
loss function for training examples, i.e. `loss_fn`.
Default: None
vectorize (bool): Flag to use experimental vectorize functionality
for `torch.autograd.functional.jacobian`.
Default: False
@@ -970,6 +1009,7 @@ def __init__(
checkpoints_load_func,
loss_fn,
batch_size,
test_loss_fn,
vectorize,
)

@@ -1038,6 +1078,7 @@ def _influence(  # type: ignore[override]
_DatasetFromList([inputs_batch]), shuffle=False, batch_size=None
),
self.projection_quantities,
test=True,
)

src_projections = self.src_intermediate_quantities
@@ -1088,6 +1129,7 @@ def _get_k_most_influential(  # type: ignore[override]
_DatasetFromList([inputs_batch]), shuffle=False, batch_size=None
),
self.projection_quantities,
test=True,
)
multiplier = 1 if proponents else -1

@@ -1101,6 +1143,7 @@

return KMostInfluentialResults(indices, distances)

@log_usage()
def self_influence(
self,
inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,
@@ -1326,6 +1369,8 @@ def _set_projections_tracincp_fast_rand_proj(
self,
batch[0:-1],
batch[-1],
self.loss_fn,
self.reduction_type,
)

jacobian_dim = batch_jacobians.shape[
@@ -1398,6 +1443,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
self,
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
projection_quantities: Optional[Tuple[torch.Tensor, torch.Tensor]],
test: bool = False,
) -> torch.Tensor:
r"""
This method computes vectors that can be used to compute influence. (see
@@ -1422,6 +1468,10 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
projection_quantities (tuple or None): Is either the two tensors defining
the randomized projections to apply, or None, which means no
projection is to be applied.
test (bool): If True, the intermediate quantities are computed using
`self.test_loss_fn`. Otherwise, they are computed using
`self.loss_fn`.
Default: False

Returns:
intermediate_quantities (Tensor): A tensor of dimension
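Note that the visible hunks pass `self.test_loss_fn` unconditionally, so the branch on the `test` flag presumably sits in lines this truncated diff does not show. A hedged sketch of the dispatch such a flag implies (an assumption, not code from the diff):

def _choose_loss_sketch(self, test: bool = False):
    # Pick the loss/reduction pair that the `test` flag selects; both
    # pairs were validated and stored during initialization.
    if test:
        return self.test_loss_fn, self.test_reduction_type
    return self.loss_fn, self.reduction_type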
Expand Down Expand Up @@ -1490,6 +1540,8 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
self,
batch[0:-1],
batch[-1],
self.test_loss_fn,
self.test_reduction_type,
)

# if doing projection, project those two quantities
@@ -1539,6 +1591,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
# each row in this result is the "embedding" vector for an example in `batch`
return torch.cat(checkpoint_contributions, dim=1) # type: ignore

@log_usage()
def compute_intermediate_quantities(
self,
inputs_dataset: Union[Tuple[Any, ...], DataLoader],