
Commit 03db921

99warriors authored and facebook-github-bot committed
add test loss (pytorch#1073)
Summary:
Pull Request resolved: pytorch#1073

- For all `TracInCPBase` implementations, this adds an additional `test_loss_fn` initialization argument, which is the loss function to apply to test examples when computing the influence of a training example on a test example. With this change, the influence score is a sum over terms for each checkpoint, where each term is the gradient of `loss_fn` for a given training example, multiplied with the gradient of `test_loss_fn` for a given test example. Before, `test_loss_fn` was assumed to be the same as `loss_fn`.
- Checks regarding the reduction type of both `loss_fn` and `test_loss_fn` are now handled by the helper functions `_check_tracincp_loss_fn` and `_check_tracincp_fast_loss_fn`.
- Documentation is updated. One detail: for `TracInCP`, we assume that `sample_wise_grads_per_batch` applies to both `loss_fn` and `test_loss_fn` (if provided), and this is mentioned in the documentation.
- `test_tracin_regression.test_tracin_regression` is slightly modified: `DataInfluenceConstructor` can now explicitly pass in the same loss function for both `loss_fn` and `test_loss_fn` (done when `duplicate_loss_fn=True`). Doing so has the same effect as not passing in `test_loss_fn`, so the original tests are also applied to the case when `duplicate_loss_fn=True`, where the expected behavior should be the same as before.
- A new test, `test_tracin_regression.test_tracin_constant_test_loss_fn`, is added. For all implementations of `TracInCPBase`, it checks that if `test_loss_fn` is a constant loss function, the influence scores are all 0's. This should be the case because a constant `test_loss_fn` has gradients that are all 0's, so training examples have 0 influence on test examples.

Differential Revision: https://internalfb.com/D41202866

fbshipit-source-id: 3e085383ef6a735695a190d4289dd4c702702f06
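To illustrate the new argument, here is a minimal usage sketch, not part of this commit: it builds a `TracInCP` instance with a separate `test_loss_fn`, following the `inputs`/`targets` calling convention of the `influence` method described above. The toy model, random dataset, checkpoint path, and function names are placeholders, and the example assumes the default `checkpoints_load_func` accepts a plain `state_dict` checkpoint.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from captum.influence import TracInCP

# Hypothetical toy setup; the model, data, and checkpoint path are placeholders.
model = nn.Linear(10, 3)
train_dataset = TensorDataset(torch.randn(32, 10), torch.randint(0, 3, (32,)))
torch.save(model.state_dict(), "checkpoint_0.pt")  # stands in for a real training checkpoint

# With sample_wise_grads_per_batch left at its default (False), both loss functions
# are treated as "per-example", so the training loss uses reduction="none".
train_loss_fn = nn.CrossEntropyLoss(reduction="none")


def fixed_class_test_loss_fn(outputs, targets):
    # The use case mentioned in the new docstring: measure influence on the
    # logit of a fixed class (class 0 here), returning one value per example.
    return outputs[:, 0]


tracin = TracInCP(
    model,
    train_dataset,
    ["checkpoint_0.pt"],
    loss_fn=train_loss_fn,
    test_loss_fn=fixed_class_test_loss_fn,  # new argument added by this commit
    batch_size=8,
)

test_inputs, test_targets = torch.randn(4, 10), torch.randint(0, 3, (4,))
# Influence of each training example on each test example: shape (4, 32).
scores = tracin.influence((test_inputs,), test_targets)
```

Because `sample_wise_grads_per_batch` is False here, both `loss_fn` and `test_loss_fn` must be "per-example" loss functions, which is why the test loss returns one value per example rather than a reduced scalar.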
1 parent c076410 commit 03db921

6 files changed: +363 −87 lines changed

captum/_utils/gradient.py
+7 −4

@@ -849,18 +849,21 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
             if labels is not None and loss_fn is not None:
                 loss = loss_fn(out, labels)
                 # TODO: allow loss_fn to be Callable
-                if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
+                if (isinstance(loss_fn, Module) or callable(loss_fn)) and hasattr(
+                    loss_fn, "reduction"
+                ):
+                    reduction = loss_fn.reduction  # type: ignore
                     msg0 = (
                         "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
                     )

-                    assert loss_fn.reduction != "none", msg0
+                    assert reduction != "none", msg0
                     msg1 = (
-                        f"loss_fn.reduction ({loss_fn.reduction}) does not match"
+                        f"loss_fn.reduction ({reduction}) does not match"
                         f"reduction type ({reduction_type}). Please ensure they are"
                         " matching."
                     )
-                    assert loss_fn.reduction == reduction_type, msg1
+                    assert reduction == reduction_type, msg1
                     msg2 = (
                         "Please ensure custom loss function is applying either a "
                         "sum or mean reduction."

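The change above widens the reduction check so that it also applies to a plain callable that exposes a `reduction` attribute, not only to `nn.Module` losses. A minimal sketch of such a callable (an illustrative example, not code from this commit):

```python
import torch


def summed_mse_loss(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Reduces the per-example squared errors by summing, so the declared
    # reduction below matches what the function actually computes.
    return ((outputs - targets) ** 2).sum()


# Because the callable carries a `reduction` attribute, the updated check reads it
# and asserts it is not "none" and that it matches the expected reduction_type.
summed_mse_loss.reduction = "sum"
```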
captum/influence/_core/tracincp.py
+71 −47

@@ -26,6 +26,7 @@
 from captum._utils.progress import NullProgress, progress
 from captum.influence._core.influence import DataInfluence
 from captum.influence._utils.common import (
+    _check_loss_fn,
     _format_inputs_dataset,
     _get_k_most_influential_helper,
     _gradient_dot_product,
@@ -102,6 +103,7 @@ def __init__(
         checkpoints_load_func: Callable = _load_flexible_state_dict,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
     ) -> None:
         r"""
         Args:
@@ -152,6 +154,19 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                If not provided, the loss function for test examples is assumed to
+                be the same as the loss function for training examples, i.e.
+                `loss_fn`.
+                Default: None
         """

         self.model = model
@@ -167,6 +182,8 @@ def __init__(

         self.checkpoints_load_func = checkpoints_load_func
         self.loss_fn = loss_fn
+        # If test_loss_fn not provided, it's assumed to be same as loss_fn
+        self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn
         self.batch_size = batch_size

         if not isinstance(train_dataset, DataLoader):
@@ -489,6 +506,7 @@ def __init__(
         layers: Optional[List[str]] = None,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
         sample_wise_grads_per_batch: bool = False,
     ) -> None:
         r"""
@@ -561,6 +579,24 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                Thus, the same checks that we apply to `loss_fn` are also applied
+                to `test_loss_fn`, if the latter is provided. Note that the
+                constraints on `loss_fn` and `test_loss_fn` both depend on
+                `sample_wise_grads_per_batch`. This means `loss_fn` and
+                `test_loss_fn` must either both be "per-example" loss functions,
+                or both be "reduction" loss functions. If not provided, the loss
+                function for test examples is assumed to be the same as the loss
+                function for training examples, i.e. `loss_fn`.
+                Default: None
             sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
                 computations w.r.t. model parameters aggregates the results for a
                 batch and does not allow to access sample-wise gradients w.r.t.
@@ -590,51 +626,23 @@ def __init__(
             checkpoints_load_func,
             loss_fn,
             batch_size,
+            test_loss_fn,
         )

         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

-        # If we are able to access the reduction used by `loss_fn`, we check whether
-        # the reduction is compatible with `sample_wise_grads_per_batch`
-        if isinstance(loss_fn, Module) and hasattr(
-            loss_fn, "reduction"
-        ):  # TODO: allow loss_fn to be Callable
-            if self.sample_wise_grads_per_batch:
-                assert loss_fn.reduction in ["sum", "mean"], (
-                    'reduction for `loss_fn` must be "sum" or "mean" when '
-                    "`sample_wise_grads_per_batch` is True"
-                )
-                self.reduction_type = str(loss_fn.reduction)
-            else:
-                assert loss_fn.reduction == "none", (
-                    'reduction for `loss_fn` must be "none" when '
-                    "`sample_wise_grads_per_batch` is False"
-                )
-        else:
-            # if we are unable to access the reduction used by `loss_fn`, we warn
-            # the user about the assumptions we are making regarding the reduction
-            # used by `loss_fn`
-            if self.sample_wise_grads_per_batch:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
-                    'that `loss_fn` is a "reduction" loss function that reduces the '
-                    "per-example losses by taking their *sum*. If `loss_fn` "
-                    "instead reduces the per-example losses by taking their mean, "
-                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
-                    '`loss_fn.reduction = "mean"`. Note that if '
-                    "`sample_wise_grads_per_batch` is True, the implementation "
-                    "assumes the reduction is either a sum or mean reduction."
-                )
-                self.reduction_type = "sum"
-            else:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is False, the implementation "
-                    'assumes that `loss_fn` is a "per-example" loss function (see '
-                    "documentation for `loss_fn` for details). Please ensure that "
-                    "this is the case."
-                )
+        # check `loss_fn`
+        self.reduction_type = _check_loss_fn(
+            self, loss_fn, "loss_fn", sample_wise_grads_per_batch
+        )
+        # check `test_loss_fn` if it was provided
+        self.test_reduction_type = (
+            self.reduction_type
+            if test_loss_fn is None
+            else _check_loss_fn(
+                self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch
+            )
+        )

         r"""
         TODO: Either restore model state after done (would have to place functionality
@@ -790,11 +798,15 @@ def get_checkpoint_contribution(checkpoint):
             input_jacobians = self._basic_computation_tracincp(
                 inputs,
                 targets,
+                self.test_loss_fn,
+                self.test_reduction_type,
             )
             return (
                 _gradient_dot_product(
                     input_jacobians,
-                    self._basic_computation_tracincp(batch[0:-1], batch[-1]),
+                    self._basic_computation_tracincp(
+                        batch[0:-1], batch[-1], self.loss_fn, self.reduction_type
+                    ),
                 )
                 * learning_rate
             )
@@ -1042,7 +1054,10 @@ def get_checkpoint_contribution(checkpoint):
             for batch in _inputs_dataset:

                 layer_jacobians = self._basic_computation_tracincp(
-                    batch[0:-1], batch[-1]
+                    batch[0:-1],
+                    batch[-1],
+                    self.loss_fn,
+                    self.reduction_type,
                 )

                 # Note that all variables in this function are for an entire batch.
@@ -1179,11 +1194,14 @@ def _basic_computation_tracincp(
         self,
         inputs: Tuple[Any, ...],
         targets: Optional[Tensor] = None,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
     ) -> Tuple[Tensor, ...]:
         """
         For instances of TracInCP, computation of influence scores or self influence
         scores repeatedly calls this function for different checkpoints
-        and batches.
+        and batches. In particular, this function computes the jacobian of a loss
+        function w.r.t. parameters in the `layers` initialization argument.

         Args:

@@ -1193,20 +1211,26 @@ def _basic_computation_tracincp(
                 that `model(*inputs)` produces the predictions for the batch.
             targets (tensor or None): If computing influence scores on a loss function,
                 these are the labels corresponding to the batch `inputs`.
+                Default: None
+            loss_fn (Callable, optional): The loss function to use when computing the
+                jacobian.
+            reduction_type (str, optional): The reduction type of `loss_fn`. This
+                argument is only used if `sample_wise_grads_per_batch` was True in
+                initialization.
         """
         if self.sample_wise_grads_per_batch:
             return _compute_jacobian_wrt_params_with_sample_wise_trick(
                 self.model,
                 inputs,
                 targets,
-                self.loss_fn,
-                self.reduction_type,
+                loss_fn,
+                reduction_type,
                 self.layer_modules,
             )
         return _compute_jacobian_wrt_params(
             self.model,
             inputs,
             targets,
-            self.loss_fn,
+            loss_fn,
             self.layer_modules,
         )
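The docstrings above describe mapping the logits for all classes to the logits for a fixed class, and note that `loss_fn` and `test_loss_fn` must either both be "per-example" or both be "reduction" loss functions, depending on `sample_wise_grads_per_batch`. A hypothetical sketch of both variants of such a `test_loss_fn` (illustrative only; the function names and class index are assumptions, not code from this commit):

```python
import torch

FIXED_CLASS = 2  # illustrative class index


def fixed_class_logit_per_example(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # "Per-example" variant for sample_wise_grads_per_batch=False: returns one
    # value per test example, analogous to a loss with reduction="none".
    return outputs[:, FIXED_CLASS]


def fixed_class_logit_sum(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # "Reduction" variant for sample_wise_grads_per_batch=True: reduces over the
    # batch by summing, analogous to a loss with reduction="sum".
    return outputs[:, FIXED_CLASS].sum()


# Declaring the reduction lets the reduction-type checks read the intended
# reduction instead of falling back to their default assumption.
fixed_class_logit_sum.reduction = "sum"
```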
