@@ -26,6 +26,7 @@
 from captum._utils.progress import NullProgress, progress
 from captum.influence._core.influence import DataInfluence
 from captum.influence._utils.common import (
+    _check_loss_fn,
     _format_inputs_dataset,
     _get_k_most_influential_helper,
     _gradient_dot_product,
@@ -102,6 +103,7 @@ def __init__(
         checkpoints_load_func: Callable = _load_flexible_state_dict,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
     ) -> None:
         r"""
         Args:
@@ -152,6 +154,19 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                If not provided, the loss function for test examples is assumed to
+                be the same as the loss function for training examples, i.e.
+                `loss_fn`.
+                Default: None
         """

         self.model = model
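
As an aside on the fixed-class example in the new docstring above: below is a minimal sketch of what such a `test_loss_fn` could look like, assuming the model emits raw logits of shape (batch, num_classes). The function name, the class index, and the sign convention are illustrative only, not part of the API.

import torch

FIXED_CLASS = 3  # illustrative class index, not part of the API


def fixed_class_test_loss(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Hypothetical `test_loss_fn`: the negated logit of a fixed class, one value
    # per example (a "per-example" loss). Summing over the batch would instead
    # give a "reduction"-style loss. `targets` is accepted but ignored so the
    # signature matches the usual `loss_fn(outputs, targets)` call.
    return -outputs[:, FIXED_CLASS]
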
@@ -167,6 +182,8 @@ def __init__(

         self.checkpoints_load_func = checkpoints_load_func
         self.loss_fn = loss_fn
+        # If test_loss_fn is not provided, it is assumed to be the same as loss_fn
+        self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn
         self.batch_size = batch_size

         if not isinstance(train_dataset, DataLoader):
@@ -489,6 +506,7 @@ def __init__(
         layers: Optional[List[str]] = None,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
         sample_wise_grads_per_batch: bool = False,
     ) -> None:
         r"""
@@ -561,6 +579,24 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                Thus, the same checks that we apply to `loss_fn` are also applied
+                to `test_loss_fn`, if the latter is provided. Note that the
+                constraints on both `loss_fn` and `test_loss_fn` depend on
+                `sample_wise_grads_per_batch`. This means `loss_fn` and
+                `test_loss_fn` must either both be "per-example" loss functions,
+                or both be "reduction" loss functions. If not provided, the loss
+                function for test examples is assumed to be the same as the loss
+                function for training examples, i.e. `loss_fn`.
+                Default: None
             sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
                 computations w.r.t. model parameters aggregates the results for a
                 batch and does not allow to access sample-wise gradients w.r.t.
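
To illustrate the pairing constraint described in the `test_loss_fn` docstring above, here is a minimal sketch using standard PyTorch losses; which pairing applies depends on how `sample_wise_grads_per_batch` is set, and the variable names are illustrative.

import torch.nn as nn

# When `sample_wise_grads_per_batch` is False, both losses must be
# "per-example", i.e. return one loss value per example.
per_example_loss_fn = nn.CrossEntropyLoss(reduction="none")
per_example_test_loss_fn = nn.CrossEntropyLoss(reduction="none")

# When `sample_wise_grads_per_batch` is True, both losses must be "reduction"
# losses whose reduction is "sum" or "mean".
reduction_loss_fn = nn.CrossEntropyLoss(reduction="sum")
reduction_test_loss_fn = nn.CrossEntropyLoss(reduction="mean")
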
@@ -590,51 +626,23 @@ def __init__(
             checkpoints_load_func,
             loss_fn,
             batch_size,
+            test_loss_fn,
         )

         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

-        # If we are able to access the reduction used by `loss_fn`, we check whether
-        # the reduction is compatible with `sample_wise_grads_per_batch`
-        if isinstance(loss_fn, Module) and hasattr(
-            loss_fn, "reduction"
-        ):  # TODO: allow loss_fn to be Callable
-            if self.sample_wise_grads_per_batch:
-                assert loss_fn.reduction in ["sum", "mean"], (
-                    'reduction for `loss_fn` must be "sum" or "mean" when '
-                    "`sample_wise_grads_per_batch` is True"
-                )
-                self.reduction_type = str(loss_fn.reduction)
-            else:
-                assert loss_fn.reduction == "none", (
-                    'reduction for `loss_fn` must be "none" when '
-                    "`sample_wise_grads_per_batch` is False"
-                )
-        else:
-            # if we are unable to access the reduction used by `loss_fn`, we warn
-            # the user about the assumptions we are making regarding the reduction
-            # used by `loss_fn`
-            if self.sample_wise_grads_per_batch:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
-                    'that `loss_fn` is a "reduction" loss function that reduces the '
-                    "per-example losses by taking their *sum*. If `loss_fn` "
-                    "instead reduces the per-example losses by taking their mean, "
-                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
-                    '`loss_fn.reduction = "mean"`. Note that if '
-                    "`sample_wise_grads_per_batch` is True, the implementation "
-                    "assumes the reduction is either a sum or mean reduction."
-                )
-                self.reduction_type = "sum"
-            else:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is False, the implementation "
-                    'assumes that `loss_fn` is a "per-example" loss function (see '
-                    "documentation for `loss_fn` for details). Please ensure that "
-                    "this is the case."
-                )
+        # check `loss_fn`
+        self.reduction_type = _check_loss_fn(
+            self, loss_fn, "loss_fn", sample_wise_grads_per_batch
+        )
+        # check `test_loss_fn` if it was provided
+        self.test_reduction_type = (
+            self.reduction_type
+            if test_loss_fn is None
+            else _check_loss_fn(
+                self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch
+            )
+        )

         r"""
         TODO: Either restore model state after done (would have to place functionality
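
For reference, the inline checks removed above amount to roughly the following. This is a reconstruction from the deleted block, not the actual `_check_loss_fn` in `captum.influence._utils.common`; in particular, returning "none" when the reduction cannot be inspected and `sample_wise_grads_per_batch` is False is an assumption, and the warning text is abbreviated.

import warnings
from typing import Callable, Optional, Union

from torch.nn import Module


def _check_loss_fn_sketch(
    influence_instance,  # passed as `self` at the call sites above; unused in this sketch
    loss_fn: Optional[Union[Module, Callable]],
    loss_fn_name: str,
    sample_wise_grads_per_batch: bool,
) -> str:
    # Mirror of the deleted inline logic: validate the reduction of `loss_fn`
    # against `sample_wise_grads_per_batch` and return the reduction type.
    if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
        if sample_wise_grads_per_batch:
            assert loss_fn.reduction in ["sum", "mean"], (
                f'reduction for `{loss_fn_name}` must be "sum" or "mean" when '
                "`sample_wise_grads_per_batch` is True"
            )
            return str(loss_fn.reduction)
        assert loss_fn.reduction == "none", (
            f'reduction for `{loss_fn_name}` must be "none" when '
            "`sample_wise_grads_per_batch` is False"
        )
        return "none"
    # Reduction not accessible: warn and fall back to an assumed reduction,
    # as the deleted block did ("sum" when sample-wise gradients are requested).
    if sample_wise_grads_per_batch:
        warnings.warn(
            f'`{loss_fn_name}` has no "reduction" attribute; assuming a "sum" reduction.'
        )
        return "sum"
    warnings.warn(
        f'`{loss_fn_name}` has no "reduction" attribute; assuming it is a '
        '"per-example" loss function.'
    )
    return "none"  # assumption: the deleted block set no reduction type in this branch
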
@@ -790,11 +798,15 @@ def get_checkpoint_contribution(checkpoint):
             input_jacobians = self._basic_computation_tracincp(
                 inputs,
                 targets,
+                self.test_loss_fn,
+                self.test_reduction_type,
             )
             return (
                 _gradient_dot_product(
                     input_jacobians,
-                    self._basic_computation_tracincp(batch[0:-1], batch[-1]),
+                    self._basic_computation_tracincp(
+                        batch[0:-1], batch[-1], self.loss_fn, self.reduction_type
+                    ),
                 )
                 * learning_rate
             )
@@ -1042,7 +1054,10 @@ def get_checkpoint_contribution(checkpoint):
             for batch in _inputs_dataset:

                 layer_jacobians = self._basic_computation_tracincp(
-                    batch[0:-1], batch[-1]
+                    batch[0:-1],
+                    batch[-1],
+                    self.loss_fn,
+                    self.reduction_type,
                 )

                 # Note that all variables in this function are for an entire batch.
@@ -1179,11 +1194,14 @@ def _basic_computation_tracincp(
         self,
         inputs: Tuple[Any, ...],
         targets: Optional[Tensor] = None,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
     ) -> Tuple[Tensor, ...]:
         """
         For instances of TracInCP, computation of influence scores or self influence
         scores repeatedly calls this function for different checkpoints
-        and batches.
+        and batches. In particular, this function computes the jacobian of a loss
+        function w.r.t. parameters in the `layers` initialization argument.

         Args:

@@ -1193,20 +1211,26 @@ def _basic_computation_tracincp(
                 that `model(*inputs)` produces the predictions for the batch.
             targets (tensor or None): If computing influence scores on a loss function,
                 these are the labels corresponding to the batch `inputs`.
+                Default: None
+            loss_fn (Callable, optional): The loss function to use when computing the
+                jacobian.
+            reduction_type (str, optional): The reduction type of `loss_fn`. This
+                argument is only used if `sample_wise_grads_per_batch` was True in
+                initialization.
         """
         if self.sample_wise_grads_per_batch:
             return _compute_jacobian_wrt_params_with_sample_wise_trick(
                 self.model,
                 inputs,
                 targets,
-                self.loss_fn,
-                self.reduction_type,
+                loss_fn,
+                reduction_type,
                 self.layer_modules,
             )
         return _compute_jacobian_wrt_params(
             self.model,
             inputs,
             targets,
-            self.loss_fn,
+            loss_fn,
             self.layer_modules,
         )
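
Putting the pieces together, a caller could now supply separate training and test loss functions at construction time. A hypothetical usage sketch follows, assuming the usual positional order of model, training data, and checkpoints, with `net`, `train_dataset`, and `checkpoint_paths` standing in for objects the caller already has.

import torch.nn as nn
from captum.influence import TracInCP

tracin = TracInCP(
    net,
    train_dataset,
    checkpoint_paths,
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    test_loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    sample_wise_grads_per_batch=True,
)

Because `sample_wise_grads_per_batch` is True here, both losses use a "sum" or "mean" reduction, matching the check above; with it False, both would instead need `reduction="none"`.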