Skip to content

Commit 1f6c2d5

Browse files
authored
Improve ModuleOutputsHook, testing coverage, & fix bug
* Added the `_remove_all_forward_hooks` function for easy cleanup and removal of hooks without requiring their handles. * Changed `ModuleOutputHook`'s forward hook function name from `forward_hook` to `module_outputs_forward_hook` to allow for easy removal of only hooks using that hook function. * `ModuleOutputHook`'s initialization function now runs the `_remove_all_forward_hooks` function on targets, and only removes the hooks created by `ModuleOutputHook` to avoid breaking PyTorch. * Added the `_count_forward_hooks` function for easy testing of hook creation & removal functionality. * Added tests for verifying that the 'ghost hook' bug has been fixed, and that the new function is working correctly. * Added tests for `ModuleOutputsHook`. Previously we had no tests for this module.
1 parent 6e7f0bd commit 1f6c2d5

File tree

2 files changed

+397
-8
lines changed

2 files changed

+397
-8
lines changed

captum/optim/_core/output_hook.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
2-
from typing import Callable, Iterable, Tuple
2+
from collections import OrderedDict
3+
from typing import Callable, Dict, Iterable, Optional, Tuple
34
from warnings import warn
45

56
import torch
@@ -15,6 +16,9 @@ def __init__(self, target_modules: Iterable[nn.Module]) -> None:
1516
1617
target_modules (Iterable of nn.Module): A list of nn.Module targets.
1718
"""
19+
for module in target_modules:
20+
# Clean up any old hooks that weren't properly deleted
21+
_remove_all_forward_hooks(module, "module_outputs_forward_hook")
1822
self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
1923
self.hooks = [
2024
module.register_forward_hook(self._forward_hook())
@@ -33,13 +37,13 @@ def is_ready(self) -> bool:
3337

3438
def _forward_hook(self) -> Callable:
3539
"""
36-
Return the forward_hook function.
40+
Return the module_outputs_forward_hook forward hook function.
3741
3842
Returns:
39-
forward_hook (Callable): The forward_hook function.
43+
forward_hook (Callable): The module_outputs_forward_hook function.
4044
"""
4145

42-
def forward_hook(
46+
def module_outputs_forward_hook(
4347
module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
4448
) -> None:
4549
assert module in self.outputs.keys()
@@ -57,7 +61,7 @@ def forward_hook(
5761
"that you are passing model layers in your losses."
5862
)
5963

60-
return forward_hook
64+
return module_outputs_forward_hook
6165

6266
def consume_outputs(self) -> ModuleOutputMapping:
6367
"""
@@ -130,3 +134,53 @@ def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
130134
finally:
131135
self.layers.remove_hooks()
132136
return activations_dict
137+
138+
139+
def _remove_all_forward_hooks(
140+
module: torch.nn.Module, hook_fn_name: Optional[str] = None
141+
) -> None:
142+
"""
143+
This function removes all forward hooks in the specified module, without requiring
144+
any hook handles. This lets us clean up & remove any hooks that weren't property
145+
deleted.
146+
147+
Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
148+
caution should be exercised when removing all hooks. Users are recommended to give
149+
their hook function a unique name that can be used to safely identify and remove
150+
the target forward hooks.
151+
152+
Args:
153+
154+
module (nn.Module): The module instance to remove forward hooks from.
155+
hook_fn_name (str, optional): Optionally only remove specific forward hooks
156+
based on their function's __name__ attribute.
157+
Default: None
158+
"""
159+
160+
if hook_fn_name is None:
161+
warn("Removing all active hooks will break some PyTorch modules & systems.")
162+
163+
def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None:
164+
if hasattr(module, "_forward_hooks"):
165+
if m._forward_hooks != OrderedDict():
166+
if name is not None:
167+
dict_items = list(m._forward_hooks.items())
168+
m._forward_hooks = OrderedDict(
169+
[(i, fn) for i, fn in dict_items if fn.__name__ != name]
170+
)
171+
else:
172+
m._forward_hooks: Dict[int, Callable] = OrderedDict()
173+
174+
def _remove_child_hooks(
175+
target_module: torch.nn.Module, hook_name: Optional[str] = None
176+
) -> None:
177+
for name, child in target_module._modules.items():
178+
if child is not None:
179+
_remove_hooks(child, hook_name)
180+
_remove_child_hooks(child, hook_name)
181+
182+
# Remove hooks from target submodules
183+
_remove_child_hooks(module, hook_fn_name)
184+
185+
# Remove hooks from the target module
186+
_remove_hooks(module, hook_fn_name)

0 commit comments

Comments
 (0)