
Commit 93bbef1

99warriors authored and facebook-github-bot committed
add compute_intermediate_quantities to TracInCP (#1068)
Summary:
Pull Request resolved: #1068

This diff adds the `compute_intermediate_quantities` method to `TracInCP`, which returns influence embeddings such that the influence of one example on another is the dot-product of their respective influence embeddings. In the case of `TracInCP`, its influence embeddings are simply the parameter-gradients for an example, concatenated over the different checkpoints.

There is also an `aggregate` option that, if True, returns not the influence embeddings of each example in the given dataset, but instead their *sum*. This is useful for the validation diff workflow (the next diff in the stack), where we want to calculate the influence of a given training example on an entire validation dataset. This can be accomplished by taking the dot-product of the training example's influence embedding with the *sum* of the influence embeddings over the validation dataset (i.e. with `aggregate=True`).

For tests, the tests currently used for `TracInCPFastRandProj.compute_intermediate_quantities` (`test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_api`, `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_consistent`) are now also applied to `TracInCP.compute_intermediate_quantities`. In addition, `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_aggregate` is added to test the `aggregate=True` option, checking that with `aggregate=True`, the returned influence embedding is indeed the sum of the influence embeddings for the given dataset.

Differential Revision: https://internalfb.com/D40688327

fbshipit-source-id: 0f57a7a73d91a49f615a117a5ac3c2fdb2fb1bf8
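For context, here is a minimal sketch of the validation-diff workflow that `aggregate` enables. It assumes an already-constructed `TracInCP` instance `tracin` and dataloaders `train_loader` / `val_loader`; these names are illustrative and not part of this diff:

    # Hypothetical usage sketch; `tracin`, `train_loader`, and `val_loader`
    # are assumed to exist and are not part of this diff.
    train_embeddings = tracin.compute_intermediate_quantities(train_loader)  # (N, D * C)
    val_embedding_sum = tracin.compute_intermediate_quantities(
        val_loader, aggregate=True
    )  # (1, D * C)

    # By the dot-product property, row i is the influence of training example
    # i on the total loss over the entire validation dataset.
    influence_on_val = train_embeddings @ val_embedding_sum.T  # (N, 1)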
1 parent d5c717c commit 93bbef1

File tree: 3 files changed, +222 −0 lines


captum/influence/_core/tracincp.py

Lines changed: 157 additions & 0 deletions
@@ -773,6 +773,162 @@ def influence(  # type: ignore[override]
             show_progress,
         )
 
+    def _sum_jacobians(
+        self,
+        inputs_dataset: DataLoader,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
+    ):
+        """
+        Sums the jacobians of all examples in `inputs_dataset`. The result is
+        of the same format as `layer_jacobians`, but the batch dimension has
+        size 1.
+        """
+        inputs_dataset_iter = iter(inputs_dataset)
+
+        inputs_batch = next(inputs_dataset_iter)
+
+        def get_batch_contribution(inputs_batch):
+            _input_jacobians = self._basic_computation_tracincp(
+                inputs_batch[0:-1],
+                inputs_batch[-1],
+                loss_fn,
+                reduction_type,
+            )
+
+            return tuple(
+                torch.sum(jacobian, dim=0).unsqueeze(0) for jacobian in _input_jacobians
+            )
+
+        inputs_jacobians = get_batch_contribution(inputs_batch)
+
+        for inputs_batch in inputs_dataset_iter:
+            inputs_batch_jacobians = get_batch_contribution(inputs_batch)
+            inputs_jacobians = tuple(
+                [
+                    inputs_jacobian + inputs_batch_jacobian
+                    for (inputs_jacobian, inputs_batch_jacobian) in zip(
+                        inputs_jacobians, inputs_batch_jacobians
+                    )
+                ]
+            )
+
+        return inputs_jacobians
+
+    def _concat_jacobians(
+        self,
+        inputs_dataset: DataLoader,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
+    ):
+        all_inputs_batch_jacobians = [
+            self._basic_computation_tracincp(
+                inputs_batch[0:-1],
+                inputs_batch[-1],
+                loss_fn,
+                reduction_type,
+            )
+            for inputs_batch in inputs_dataset
+        ]
+
+        return tuple(
+            torch.cat(all_inputs_batch_jacobian, dim=0)
+            for all_inputs_batch_jacobian in zip(*all_inputs_batch_jacobians)
+        )
+
+    @log_usage()
+    def compute_intermediate_quantities(
+        self,
+        inputs_dataset: Union[Tuple[Any, ...], DataLoader],
+        aggregate: bool = False,
+    ) -> Tensor:
+        """
+        Computes "embedding" vectors for all examples in a single batch, or a
+        `DataLoader` that yields batches. These embedding vectors are
+        constructed so that the influence score of a training example on a
+        test example is simply the dot-product of their corresponding vectors.
+        Allowing a `DataLoader` yielding batches to be passed in (as opposed
+        to a single batch) gives the potential to improve efficiency, because
+        we load each checkpoint only once in this method call. Thus if a
+        `DataLoader` yielding batches is passed in, this reduces the total
+        number of times each checkpoint is loaded for a dataset, compared to
+        if a single batch is passed in. The reason we do not just increase the
+        batch size is that for large models, large batches do not fit in
+        memory.
+
+        If `aggregate` is True, the *sum* of the vectors for all examples is
+        returned, instead of the vectors for each example. This can be useful
+        for computing the influence of a given training example on the total
+        loss over a validation dataset, because due to properties of the
+        dot-product, this influence is the dot-product of the training
+        example's vector with the sum of the vectors in the validation
+        dataset. Also, doing the sum aggregation within this method, as
+        opposed to outside of it (by computing all vectors for the validation
+        dataset, then taking the sum), reduces memory usage.
+
+        Args:
+            inputs_dataset (Tuple, or DataLoader): Either a single tuple of
+                    any, or a `DataLoader`, where each batch yielded is a
+                    tuple of any. In either case, the tuple represents a
+                    single batch, where the last element is assumed to be the
+                    labels for the batch. That is, `model(*batch[0:-1])`
+                    produces the output for `model`, and `batch[-1]` are the
+                    labels, if any. Here, `model` is the model provided in
+                    initialization. This is the same assumption made for each
+                    batch yielded by training dataset `train_dataset`.
+            aggregate (bool): Whether to return the sum of the vectors for all
+                    examples, as opposed to vectors for each example.
+
+        Returns:
+            intermediate_quantities (Tensor): A tensor of dimension
+                    (N, D * C). Here, N is the total number of examples in
+                    `inputs_dataset` if `aggregate` is False, and 1 otherwise
+                    (so that a 2D tensor is always returned). C is the number
+                    of checkpoints passed as the `checkpoints` argument of
+                    `TracInCP.__init__`, and each row represents the vector
+                    for an example. Regarding D: let I be the dimension of the
+                    output of the last fully-connected layer times the
+                    dimension of the input of the last fully-connected layer.
+                    If `self.projection_dim` is specified in initialization,
+                    D = min(I * C, `self.projection_dim` * C). Otherwise,
+                    D = I * C. In summary, if `self.projection_dim` is None,
+                    the dimension of each vector will be determined by the
+                    size of the input and output of the last fully-connected
+                    layer of `model`. Otherwise, `self.projection_dim` must be
+                    an int, and random projection will be performed to ensure
+                    that the vector is of dimension no more than
+                    `self.projection_dim` * C. `self.projection_dim`
+                    corresponds to the variable d at the top of page 15 of the
+                    TracIn paper: https://arxiv.org/pdf/2002.08484.pdf.
+        """
+        # If `inputs_dataset` is not a `DataLoader`, turn it into one.
+        inputs_dataset = _format_inputs_dataset(inputs_dataset)
+
+        def get_checkpoint_contribution(checkpoint):
+            assert (
+                checkpoint is not None
+            ), "None returned from `checkpoints`, cannot load."
+
+            learning_rate = self.checkpoints_load_func(self.model, checkpoint)
+            # get jacobians as tuple of tensors
+            if aggregate:
+                inputs_jacobians = self._sum_jacobians(
+                    inputs_dataset, self.loss_fn, self.reduction_type
+                )
+            else:
+                inputs_jacobians = self._concat_jacobians(
+                    inputs_dataset, self.loss_fn, self.reduction_type
+                )
+            # flatten into single tensor
+            return learning_rate * torch.cat(
+                [
+                    input_jacobian.flatten(start_dim=1)
+                    for input_jacobian in inputs_jacobians
+                ],
+                dim=1,
+            )
+
+        return torch.cat(
+            [
+                get_checkpoint_contribution(checkpoint)
+                for checkpoint in self.checkpoints
+            ],
+            dim=1,
+        )
+
     def _influence_batch_tracincp(
         self,
         inputs: Tuple[Any, ...],

@@ -1109,6 +1265,7 @@ def get_checkpoint_contribution(checkpoint):
 
         return batches_self_tracin_scores
 
+    @log_usage()
     def self_influence(
         self,
         inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,
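As an aside, `_sum_jacobians` relies on flattening and concatenation being linear: summing per-parameter jacobians over examples and then flattening yields the same vector as flattening per-example jacobians and then summing rows. A small self-contained sanity check of this identity (toy shapes, not from this diff):

    import torch

    # Toy per-example jacobians for two parameters (shapes are arbitrary).
    jac_a = torch.randn(4, 3, 5)
    jac_b = torch.randn(4, 7)

    # Sum over the batch dimension first, then flatten and concatenate ...
    sum_then_flatten = torch.cat(
        [torch.sum(j, dim=0).unsqueeze(0).flatten(start_dim=1) for j in (jac_a, jac_b)],
        dim=1,
    )
    # ... versus flatten and concatenate first, then sum the rows.
    flatten_then_sum = torch.cat(
        [j.flatten(start_dim=1) for j in (jac_a, jac_b)], dim=1
    ).sum(dim=0, keepdim=True)

    assert torch.allclose(sum_then_flatten, flatten_then_sum, atol=1e-5)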

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 3 additions & 0 deletions
@@ -652,6 +652,7 @@ def get_checkpoint_contribution(checkpoint):
                     checkpoints_progress.update()
         return batches_self_tracin_scores
 
+    @log_usage()
     def self_influence(
         self,
         inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,

@@ -1181,6 +1182,7 @@ def _get_k_most_influential(  # type: ignore[override]
 
         return KMostInfluentialResults(indices, distances)
 
+    @log_usage()
     def self_influence(
         self,
         inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,

@@ -1628,6 +1630,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
         # each row in this result is the "embedding" vector for an example in `batch`
         return torch.cat(checkpoint_contributions, dim=1)  # type: ignore
 
+    @log_usage()
     def compute_intermediate_quantities(
         self,
         inputs_dataset: Union[Tuple[Any, ...], DataLoader],

tests/influence/_core/test_tracin_intermediate_quantities.py

Lines changed: 62 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torch
 
 import torch.nn as nn
+from captum.influence._core.tracincp import TracInCP
 from captum.influence._core.tracincp_fast_rand_proj import (
     TracInCPFast,
     TracInCPFastRandProj,

@@ -19,12 +20,68 @@
 
 
 class TestTracInIntermediateQuantities(BaseTest):
+    @parameterized.expand(
+        [
+            (reduction, constructor, unpack_inputs)
+            for unpack_inputs in [True, False]
+            for (reduction, constructor) in [
+                ("none", DataInfluenceConstructor(TracInCP)),
+            ]
+        ],
+        name_func=build_test_name_func(),
+    )
+    def test_tracin_intermediate_quantities_aggregate(
+        self, reduction: str, tracin_constructor: Callable, unpack_inputs: bool
+    ) -> None:
+        """
+        Tests that calling `compute_intermediate_quantities` with
+        `aggregate=True` gives the same result as calling it with
+        `aggregate=False` and then summing.
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            (net, train_dataset,) = get_random_model_and_data(
+                tmpdir,
+                unpack_inputs,
+                return_test_data=False,
+            )
+
+            # create a dataloader that yields batches from the dataset
+            train_dataset = DataLoader(train_dataset, batch_size=5)
+
+            # create tracin instance
+            criterion = nn.MSELoss(reduction=reduction)
+            batch_size = 5
+
+            tracin = tracin_constructor(
+                net,
+                train_dataset,
+                tmpdir,
+                batch_size,
+                criterion,
+            )
+
+            intermediate_quantities = tracin.compute_intermediate_quantities(
+                train_dataset, aggregate=False
+            )
+            aggregated_intermediate_quantities = tracin.compute_intermediate_quantities(
+                train_dataset, aggregate=True
+            )
+
+            assertTensorAlmostEqual(
+                self,
+                torch.sum(intermediate_quantities, dim=0, keepdim=True),
+                aggregated_intermediate_quantities,
+                delta=1e-4,  # due to numerical issues, we can't set this to 0.0
+                mode="max",
+            )
+
     @parameterized.expand(
         [
             (reduction, constructor, unpack_inputs)
             for unpack_inputs in [True, False]
             for (reduction, constructor) in [
                 ("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+                ("none", DataInfluenceConstructor(TracInCP)),
             ]
         ],
         name_func=build_test_name_func(),

@@ -103,6 +160,11 @@ def test_tracin_intermediate_quantities_api(
                 DataInfluenceConstructor(TracInCPFast),
                 DataInfluenceConstructor(TracInCPFastRandProj),
             ),
+            (
+                "none",
+                DataInfluenceConstructor(TracInCP),
+                DataInfluenceConstructor(TracInCP),
+            ),
         ]
     ],
     name_func=build_test_name_func(),
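The `delta=1e-4` tolerance in the new test is needed because floating-point summation is not associative: accumulating embeddings batch-by-batch (as `aggregate=True` does via `_sum_jacobians`) can sum values in a different order than `torch.sum` over the full embedding matrix. A quick illustration of the effect (illustrative snippet, not part of this diff):

    import torch

    torch.manual_seed(0)
    x = torch.randn(1000, 10)

    full_sum = x.sum(dim=0)
    # batch-by-batch accumulation, analogous to `aggregate=True`
    chunked_sum = sum(chunk.sum(dim=0) for chunk in x.split(5))
    # The discrepancy is tiny but often nonzero for float32.
    print((full_sum - chunked_sum).abs().max())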
