
Commit 43c7d23

99warriors authored and facebook-github-bot committed
add compute_intermediate_quantities to TracInCP (pytorch#1068)
Summary:
Pull Request resolved: pytorch#1068

This diff adds the `compute_intermediate_quantities` method to `TracInCP`, which returns influence embeddings such that the influence of one example on another is the dot-product of their respective influence embeddings. In the case of `TracInCP`, the influence embedding of an example is simply its parameter-gradients, concatenated over the different checkpoints.

There is also an `aggregate` option which, if True, returns not the influence embeddings of each example in the given dataset, but their *sum*. This is useful for the validation diff workflow (the next diff in the stack), where we want to calculate the influence of a given training example on an entire validation dataset. That influence can be computed by taking the dot-product of the training example's influence embedding with the *sum* of the influence embeddings over the validation dataset (i.e., the result of calling the method with `aggregate=True`).

For tests, the tests currently used for `TracInCPFastRandProj.compute_intermediate_quantities` (`test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_api` and `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_consistent`) are now also applied to `TracInCP.compute_intermediate_quantities`. In addition, `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_aggregate` is added to cover the `aggregate=True` option, checking that the returned influence embedding is indeed the sum of the influence embeddings for the given dataset.

Differential Revision: https://internalfb.com/D40688327

fbshipit-source-id: d7a328a6227e5fae9e95b06188835524c2a5c86b
1 parent 03db921 · commit 43c7d23
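
To make the new API concrete, here is a minimal usage sketch (illustrative only, not part of this commit); `model`, `train_dataloader`, `val_dataloader`, and `checkpoints_dir` are placeholders assumed to come from your own training setup:

    import torch
    from captum.influence import TracInCP

    # Assumed to exist: a trained `model`, DataLoaders whose batches are
    # (features..., labels) tuples, and a directory of saved checkpoints.
    tracin = TracInCP(
        model,
        train_dataloader,
        checkpoints_dir,
        loss_fn=torch.nn.MSELoss(reduction="none"),
    )

    # One influence embedding per training example: shape (N_train, D * C).
    train_embeddings = tracin.compute_intermediate_quantities(train_dataloader)

    # Summed embedding over the validation set: shape (1, D * C).
    val_embedding_sum = tracin.compute_intermediate_quantities(
        val_dataloader, aggregate=True
    )

    # The influence of each training example on the total validation loss
    # is then a single matrix-vector product.
    influence_on_val = train_embeddings @ val_embedding_sum.squeeze(0)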

File tree

3 files changed: +222, -0 lines changed

captum/influence/_core/tracincp.py (+157)
@@ -777,6 +777,162 @@ def influence(  # type: ignore[override]
             show_progress,
         )
 
+    def _sum_jacobians(
+        self,
+        inputs_dataset: DataLoader,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
+    ):
+        """
+        sums the jacobians of all examples in `inputs_dataset`. result is of the
+        same format as layer_jacobians, but the batch dimension has size 1
+        """
+        inputs_dataset_iter = iter(inputs_dataset)
+
+        inputs_batch = next(inputs_dataset_iter)
+
+        def get_batch_contribution(inputs_batch):
+            _input_jacobians = self._basic_computation_tracincp(
+                inputs_batch[0:-1],
+                inputs_batch[-1],
+                loss_fn,
+                reduction_type,
+            )
+
+            return tuple(
+                torch.sum(jacobian, dim=0).unsqueeze(0) for jacobian in _input_jacobians
+            )
+
+        inputs_jacobians = get_batch_contribution(inputs_batch)
+
+        for inputs_batch in inputs_dataset_iter:
+            inputs_batch_jacobians = get_batch_contribution(inputs_batch)
+            inputs_jacobians = tuple(
+                [
+                    inputs_jacobian + inputs_batch_jacobian
+                    for (inputs_jacobian, inputs_batch_jacobian) in zip(
+                        inputs_jacobians, inputs_batch_jacobians
+                    )
+                ]
+            )
+
+        return inputs_jacobians
+
+    def _concat_jacobians(
+        self,
+        inputs_dataset: DataLoader,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
+    ):
+        all_inputs_batch_jacobians = [
+            self._basic_computation_tracincp(
+                inputs_batch[0:-1],
+                inputs_batch[-1],
+                loss_fn,
+                reduction_type,
+            )
+            for inputs_batch in inputs_dataset
+        ]
+
+        return tuple(
+            torch.cat(all_inputs_batch_jacobian, dim=0)
+            for all_inputs_batch_jacobian in zip(*all_inputs_batch_jacobians)
+        )
+
+    @log_usage()
+    def compute_intermediate_quantities(
+        self,
+        inputs_dataset: Union[Tuple[Any, ...], DataLoader],
+        aggregate: bool = False,
+    ) -> Tensor:
+        """
+        Computes "embedding" vectors for all examples in a single batch, or a
+        `DataLoader` that yields batches. These embedding vectors are constructed
+        so that the influence score of a training example on a test example is
+        simply the dot-product of their corresponding vectors. Allowing a
+        `DataLoader` yielding batches to be passed in (as opposed to a single
+        batch) gives the potential to improve efficiency, because we load each
+        checkpoint only once in this method call. Thus if a `DataLoader` yielding
+        batches is passed in, this reduces the total number of times each
+        checkpoint is loaded for a dataset, compared to if a single batch is
+        passed in. The reason we do not just increase the batch size is that for
+        large models, large batches do not fit in memory.
+
+        If `aggregate` is True, the *sum* of the vectors for all examples is
+        returned, instead of the vectors for each example. This can be useful
+        for computing the influence of a given training example on the total
+        loss over a validation dataset, because due to properties of the
+        dot-product, this influence is the dot-product of the training example's
+        vector with the sum of the vectors in the validation dataset. Also,
+        doing the sum aggregation within this method, as opposed to outside of
+        it (by computing all vectors for the validation dataset, then taking
+        the sum), allows memory usage to be reduced.
+
+        Args:
+            inputs_dataset (Tuple, or DataLoader): Either a single tuple of any,
+                    or a `DataLoader`, where each batch yielded is a tuple of
+                    any. In either case, the tuple represents a single batch,
+                    where the last element is assumed to be the labels for the
+                    batch. That is, `model(*batch[0:-1])` produces the output
+                    for `model`, and `batch[-1]` are the labels, if any. Here,
+                    `model` is the model provided in initialization. This is
+                    the same assumption made for each batch yielded by training
+                    dataset `train_dataset`.
+            aggregate (bool): Whether to return the sum of the vectors for all
+                    examples, as opposed to vectors for each example.
+
+        Returns:
+            intermediate_quantities (Tensor): A tensor of dimension
+                    (N, D * C). Here, N is the total number of examples in
+                    `inputs_dataset` if `aggregate` is False, and 1, otherwise
+                    (so that a 2D tensor is always returned). C is the number
+                    of checkpoints passed as the `checkpoints` argument of
+                    `TracInCP.__init__`, and each row represents the vector for
+                    an example. Regarding D: Let I be the dimension of the
+                    output of the last fully-connected layer times the
+                    dimension of the input of the last fully-connected layer.
+                    If `self.projection_dim` is specified in initialization,
+                    D = min(I * C, `self.projection_dim` * C). Otherwise,
+                    D = I * C. In summary, if `self.projection_dim` is None,
+                    the dimension of each vector will be determined by the size
+                    of the input and output of the last fully-connected layer
+                    of `model`. Otherwise, `self.projection_dim` must be an
+                    int, and random projection will be performed to ensure that
+                    the vector is of dimension no more than
+                    `self.projection_dim` * C. `self.projection_dim`
+                    corresponds to the variable d in the top of page 15 of the
+                    TracIn paper: https://arxiv.org/pdf/2002.08484.pdf.
+        """
+        # If `inputs_dataset` is not a `DataLoader`, turn it into one.
+        inputs_dataset = _format_inputs_dataset(inputs_dataset)
+
+        def get_checkpoint_contribution(checkpoint):
+            assert (
+                checkpoint is not None
+            ), "None returned from `checkpoints`, cannot load."
+
+            learning_rate = self.checkpoints_load_func(self.model, checkpoint)
+            # get jacobians as tuple of tensors
+            if aggregate:
+                inputs_jacobians = self._sum_jacobians(
+                    inputs_dataset, self.loss_fn, self.reduction_type
+                )
+            else:
+                inputs_jacobians = self._concat_jacobians(
+                    inputs_dataset, self.loss_fn, self.reduction_type
+                )
+            # flatten into single tensor
+            return learning_rate * torch.cat(
+                [
+                    input_jacobian.flatten(start_dim=1)
+                    for input_jacobian in inputs_jacobians
+                ],
+                dim=1,
+            )
+
+        return torch.cat(
+            [
+                get_checkpoint_contribution(checkpoint)
+                for checkpoint in self.checkpoints
+            ],
+            dim=1,
+        )
+
     def _influence_batch_tracincp(
         self,
         inputs: Tuple[Any, ...],
@@ -1113,6 +1269,7 @@ def get_checkpoint_contribution(checkpoint):
 
         return batches_self_tracin_scores
 
+    @log_usage()
    def self_influence(
        self,
        inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,
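
As an aside, the correctness of `aggregate=True` rests on the linearity of the dot-product: a training example's influence on a whole dataset is the sum of its pairwise influences, which equals the dot-product of its embedding with the summed embeddings. A self-contained sketch of that identity, using made-up tensors rather than Captum code:

    import torch

    torch.manual_seed(0)

    # Made-up influence embeddings: 10 training and 4 validation examples,
    # with embedding dimension D * C = 6.
    train_embeddings = torch.randn(10, 6)
    val_embeddings = torch.randn(4, 6)

    # Pairwise influence scores are dot-products: shape (10, 4).
    pairwise = train_embeddings @ val_embeddings.T

    # Influence on the whole validation set, computed two ways:
    # (1) sum the pairwise scores over validation examples;
    total_per_train = pairwise.sum(dim=1)
    # (2) dot each training embedding with the summed validation embedding,
    #     which is what `aggregate=True` precomputes.
    total_per_train_agg = train_embeddings @ val_embeddings.sum(dim=0)

    # The two agree up to floating-point error.
    assert torch.allclose(total_per_train, total_per_train_agg, atol=1e-5)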

captum/influence/_core/tracincp_fast_rand_proj.py (+3)
@@ -652,6 +652,7 @@ def get_checkpoint_contribution(checkpoint):
            checkpoints_progress.update()
        return batches_self_tracin_scores
 
+    @log_usage()
    def self_influence(
        self,
        inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,
@@ -1142,6 +1143,7 @@ def _get_k_most_influential(  # type: ignore[override]
 
        return KMostInfluentialResults(indices, distances)
 
+    @log_usage()
    def self_influence(
        self,
        inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None,
@@ -1589,6 +1591,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
        # each row in this result is the "embedding" vector for an example in `batch`
        return torch.cat(checkpoint_contributions, dim=1)  # type: ignore
 
+    @log_usage()
    def compute_intermediate_quantities(
        self,
        inputs_dataset: Union[Tuple[Any, ...], DataLoader],

tests/influence/_core/test_tracin_intermediate_quantities.py (+62)
@@ -4,6 +4,7 @@
 import torch
 
 import torch.nn as nn
+from captum.influence._core.tracincp import TracInCP
 from captum.influence._core.tracincp_fast_rand_proj import (
     TracInCPFast,
     TracInCPFastRandProj,
@@ -19,12 +20,68 @@
 
 
 class TestTracInIntermediateQuantities(BaseTest):
+    @parameterized.expand(
+        [
+            (reduction, constructor, unpack_inputs)
+            for unpack_inputs in [True, False]
+            for (reduction, constructor) in [
+                ("none", DataInfluenceConstructor(TracInCP)),
+            ]
+        ],
+        name_func=build_test_name_func(),
+    )
+    def test_tracin_intermediate_quantities_aggregate(
+        self, reduction: str, tracin_constructor: Callable, unpack_inputs: bool
+    ) -> None:
+        """
+        tests that calling `compute_intermediate_quantities` with `aggregate=True`
+        does give the same result as calling it with `aggregate=False`, and then
+        summing
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            (net, train_dataset,) = get_random_model_and_data(
+                tmpdir,
+                unpack_inputs,
+                return_test_data=False,
+            )
+
+            # create a dataloader that yields batches from the dataset
+            train_dataset = DataLoader(train_dataset, batch_size=5)
+
+            # create tracin instance
+            criterion = nn.MSELoss(reduction=reduction)
+            batch_size = 5
+
+            tracin = tracin_constructor(
+                net,
+                train_dataset,
+                tmpdir,
+                batch_size,
+                criterion,
+            )
+
+            intermediate_quantities = tracin.compute_intermediate_quantities(
+                train_dataset, aggregate=False
+            )
+            aggregated_intermediate_quantities = tracin.compute_intermediate_quantities(
+                train_dataset, aggregate=True
+            )
+
+            assertTensorAlmostEqual(
+                self,
+                torch.sum(intermediate_quantities, dim=0, keepdim=True),
+                aggregated_intermediate_quantities,
+                delta=1e-4,  # due to numerical issues, we can't set this to 0.0
+                mode="max",
+            )
+
     @parameterized.expand(
         [
             (reduction, constructor, unpack_inputs)
             for unpack_inputs in [True, False]
             for (reduction, constructor) in [
                 ("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+                ("none", DataInfluenceConstructor(TracInCP)),
             ]
         ],
         name_func=build_test_name_func(),
@@ -103,6 +160,11 @@ def test_tracin_intermediate_quantities_api(
                 DataInfluenceConstructor(TracInCPFast),
                 DataInfluenceConstructor(TracInCPFastRandProj),
             ),
+            (
+                "none",
+                DataInfluenceConstructor(TracInCP),
+                DataInfluenceConstructor(TracInCP),
+            ),
         ]
     ],
     name_func=build_test_name_func(),
