Skip to content

Commit 4e5e50c

Browse files
authored
Optim-wip: Improve ModuleOutputsHook, testing coverage (#834)
* Improve ModuleOutputsHook, testing coverage, & fix bug * Added the `_remove_all_forward_hooks` function for easy cleanup and removal of hooks without requiring their handles. * Changed `ModuleOutputHook`'s forward hook function name from `forward_hook` to `module_outputs_forward_hook` to allow for easy removal of only hooks using that hook function. * `ModuleOutputHook`'s initialization function now runs the `_remove_all_forward_hooks` function on targets, and only removes the hooks created by `ModuleOutputHook` to avoid breaking PyTorch. * Added the `_count_forward_hooks` function for easy testing of hook creation & removal functionality. * Added tests for verifying that the 'ghost hook' bug has been fixed, and that the new function is working correctly. * Added tests for `ModuleOutputsHook`. Previously we had no tests for this module. * Make hook fix optional * Remove hacky hook fix * Lint: Fix import order
1 parent 5e18711 commit 4e5e50c

File tree

2 files changed

+207
-7
lines changed

2 files changed

+207
-7
lines changed

captum/optim/_core/output_hook.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def is_ready(self) -> bool:
3232

3333
def _forward_hook(self) -> Callable:
3434
"""
35-
Return the forward_hook function.
35+
Return the module_outputs_forward_hook forward hook function.
3636
3737
Returns:
38-
forward_hook (Callable): The forward_hook function.
38+
forward_hook (Callable): The module_outputs_forward_hook function.
3939
"""
4040

41-
def forward_hook(
41+
def module_outputs_forward_hook(
4242
module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
4343
) -> None:
4444
assert module in self.outputs.keys()
@@ -56,7 +56,7 @@ def forward_hook(
5656
"that you are passing model layers in your losses."
5757
)
5858

59-
return forward_hook
59+
return module_outputs_forward_hook
6060

6161
def consume_outputs(self) -> ModuleOutputMapping:
6262
"""
Lines changed: 203 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,219 @@
11
#!/usr/bin/env python3
2-
from typing import cast
2+
from collections import OrderedDict
3+
from typing import List, Optional, cast
34

45
import captum.optim._core.output_hook as output_hook
56
import torch
67
from captum.optim.models import googlenet
7-
from tests.helpers.basic import BaseTest
8+
from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
9+
10+
11+
def _count_forward_hooks(
12+
module: torch.nn.Module, hook_fn_name: Optional[str] = None
13+
) -> int:
14+
"""
15+
Count the number of active forward hooks on the specified module instance.
16+
17+
Args:
18+
19+
module (nn.Module): The model module instance to count the number of
20+
forward hooks on.
21+
name (str, optional): Optionally only count specific forward hooks based on
22+
their function's __name__ attribute.
23+
Default: None
24+
25+
Returns:
26+
num_hooks (int): The number of active hooks in the specified module.
27+
"""
28+
29+
num_hooks: List[int] = [0]
30+
31+
def _count_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None:
32+
if hasattr(m, "_forward_hooks"):
33+
if m._forward_hooks != OrderedDict():
34+
dict_items = list(m._forward_hooks.items())
35+
for i, fn in dict_items:
36+
if hook_fn_name is None or fn.__name__ == name:
37+
num_hooks[0] += 1
38+
39+
def _count_child_hooks(
40+
target_module: torch.nn.Module,
41+
hook_name: Optional[str] = None,
42+
) -> None:
43+
44+
for name, child in target_module._modules.items():
45+
if child is not None:
46+
_count_hooks(child, hook_name)
47+
_count_child_hooks(child, hook_name)
48+
49+
_count_child_hooks(module, hook_fn_name)
50+
_count_hooks(module, hook_fn_name)
51+
return num_hooks[0]
52+
53+
54+
class TestModuleOutputsHook(BaseTest):
55+
def test_init_single_target(self) -> None:
56+
model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
57+
target_modules = [model[0]]
58+
59+
hook_module = output_hook.ModuleOutputsHook(target_modules)
60+
self.assertEqual(len(hook_module.hooks), len(target_modules))
61+
62+
n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
63+
self.assertEqual(n_hooks, len(target_modules))
64+
65+
outputs = dict.fromkeys(target_modules, None)
66+
self.assertEqual(outputs, hook_module.outputs)
67+
self.assertEqual(list(hook_module.targets), target_modules)
68+
self.assertFalse(hook_module.is_ready)
69+
70+
def test_init_multiple_targets(self) -> None:
71+
model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
72+
target_modules = [model[0], model[1]]
73+
74+
hook_module = output_hook.ModuleOutputsHook(target_modules)
75+
self.assertEqual(len(hook_module.hooks), len(target_modules))
76+
77+
n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
78+
self.assertEqual(n_hooks, len(target_modules))
79+
80+
outputs = dict.fromkeys(target_modules, None)
81+
self.assertEqual(outputs, hook_module.outputs)
82+
self.assertEqual(list(hook_module.targets), target_modules)
83+
self.assertFalse(hook_module.is_ready)
84+
85+
def test_init_multiple_targets_remove_hooks(self) -> None:
86+
model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
87+
target_modules = [model[0], model[1]]
88+
89+
hook_module = output_hook.ModuleOutputsHook(target_modules)
90+
91+
n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
92+
self.assertEqual(n_hooks, len(target_modules))
93+
94+
hook_module.remove_hooks()
95+
96+
n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
97+
self.assertEqual(n_hooks, 0)
98+
99+
def test_reset_outputs_multiple_targets(self) -> None:
100+
model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
101+
target_modules = [model[0], model[1]]
102+
test_input = torch.randn(1, 3, 4, 4)
103+
104+
hook_module = output_hook.ModuleOutputsHook(target_modules)
105+
self.assertFalse(hook_module.is_ready)
106+
107+
_ = model(test_input)
108+
109+
self.assertTrue(hook_module.is_ready)
110+
111+
outputs_dict = hook_module.outputs
112+
i = 0
113+
for target, activations in outputs_dict.items():
114+
self.assertEqual(target, target_modules[i])
115+
assertTensorAlmostEqual(self, activations, test_input)
116+
i += 1
117+
118+
hook_module._reset_outputs()
119+
120+
self.assertFalse(hook_module.is_ready)
121+
122+
expected_outputs = dict.fromkeys(target_modules, None)
123+
self.assertEqual(hook_module.outputs, expected_outputs)
124+
125+
def test_consume_outputs_multiple_targets(self) -> None:
126+
model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
127+
target_modules = [model[0], model[1]]
128+
test_input = torch.randn(1, 3, 4, 4)
129+
130+
hook_module = output_hook.ModuleOutputsHook(target_modules)
131+
self.assertFalse(hook_module.is_ready)
132+
133+
_ = model(test_input)
134+
135+
self.assertTrue(hook_module.is_ready)
136+
137+
test_outputs_dict = hook_module.outputs
138+
self.assertIsInstance(test_outputs_dict, dict)
139+
self.assertEqual(len(test_outputs_dict), len(target_modules))
140+
141+
i = 0
142+
for target, activations in test_outputs_dict.items():
143+
self.assertEqual(target, target_modules[i])
144+
assertTensorAlmostEqual(self, activations, test_input)
145+
i += 1
146+
147+
test_output = hook_module.consume_outputs()
148+
149+
self.assertFalse(hook_module.is_ready)
150+
151+
i = 0
152+
for target, activations in test_output.items():
153+
self.assertEqual(target, target_modules[i])
154+
assertTensorAlmostEqual(self, activations, test_input)
155+
i += 1
156+
157+
expected_outputs = dict.fromkeys(target_modules, None)
158+
self.assertEqual(hook_module.outputs, expected_outputs)
159+
160+
def test_consume_outputs_warning(self) -> None:
161+
model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
162+
target_modules = [model[0], model[1]]
163+
test_input = torch.randn(1, 3, 4, 4)
164+
165+
hook_module = output_hook.ModuleOutputsHook(target_modules)
166+
self.assertFalse(hook_module.is_ready)
167+
168+
_ = model(test_input)
169+
170+
self.assertTrue(hook_module.is_ready)
171+
172+
hook_module._reset_outputs()
173+
174+
self.assertFalse(hook_module.is_ready)
175+
176+
with self.assertWarns(Warning):
177+
_ = hook_module.consume_outputs()
8178

9179

10180
class TestActivationFetcher(BaseTest):
11-
def test_activation_fetcher(self) -> None:
181+
def test_activation_fetcher_simple_model(self) -> None:
182+
model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
183+
184+
catch_activ = output_hook.ActivationFetcher(model, targets=[model[0]])
185+
test_input = torch.randn(1, 3, 224, 224)
186+
activ_out = catch_activ(test_input)
187+
188+
self.assertIsInstance(activ_out, dict)
189+
self.assertEqual(len(activ_out), 1)
190+
activ = activ_out[model[0]]
191+
assertTensorAlmostEqual(self, activ, test_input)
192+
193+
def test_activation_fetcher_single_target(self) -> None:
12194
model = googlenet(pretrained=True)
13195

14196
catch_activ = output_hook.ActivationFetcher(model, targets=[model.mixed4d])
15197
activ_out = catch_activ(torch.zeros(1, 3, 224, 224))
16198

17199
self.assertIsInstance(activ_out, dict)
200+
self.assertEqual(len(activ_out), 1)
18201
m4d_activ = activ_out[model.mixed4d]
19202
self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14])
203+
204+
def test_activation_fetcher_multiple_targets(self) -> None:
205+
model = googlenet(pretrained=True)
206+
207+
catch_activ = output_hook.ActivationFetcher(
208+
model, targets=[model.mixed4d, model.mixed5b]
209+
)
210+
activ_out = catch_activ(torch.zeros(1, 3, 224, 224))
211+
212+
self.assertIsInstance(activ_out, dict)
213+
self.assertEqual(len(activ_out), 2)
214+
215+
m4d_activ = activ_out[model.mixed4d]
216+
self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14])
217+
218+
m5b_activ = activ_out[model.mixed5b]
219+
self.assertEqual(list(cast(torch.Tensor, m5b_activ).shape), [1, 1024, 7, 7])

0 commit comments

Comments
 (0)