88from captum .optim ._utils .typing import ModuleOutputMapping , TupleOfTensorsOrTensorType
99
1010
11- class ModuleReuseException (Exception ):
12- pass
13-
14-
1511class ModuleOutputsHook :
1612 def __init__ (self , target_modules : Iterable [nn .Module ]) -> None :
13+ """
14+ Args:
15+
16+ target_modules (Iterable of nn.Module): A list of nn.Module targets.
17+ """
1718 self .outputs : ModuleOutputMapping = dict .fromkeys (target_modules , None )
1819 self .hooks = [
1920 module .register_forward_hook (self ._forward_hook ())
2021 for module in target_modules
2122 ]
2223
2324 def _reset_outputs (self ) -> None :
25+ """
26+ Delete captured activations.
27+ """
2428 self .outputs = dict .fromkeys (self .outputs .keys (), None )
2529
2630 @property
2731 def is_ready (self ) -> bool :
2832 return all (value is not None for value in self .outputs .values ())
2933
3034 def _forward_hook (self ) -> Callable :
35+ """
36+ Return the forward_hook function.
37+
38+ Returns:
39+ forward_hook (Callable): The forward_hook function.
40+ """
41+
3142 def forward_hook (
3243 module : nn .Module , input : Tuple [torch .Tensor ], output : torch .Tensor
3344 ) -> None :
@@ -49,6 +60,12 @@ def forward_hook(
4960 return forward_hook
5061
5162 def consume_outputs (self ) -> ModuleOutputMapping :
63+ """
64+ Collect target activations and return them.
65+
66+ Returns:
67+ outputs (ModuleOutputMapping): The captured outputs.
68+ """
5269 if not self .is_ready :
5370 warn (
5471 "Consume captured outputs, but not all requested target outputs "
@@ -63,11 +80,16 @@ def targets(self) -> Iterable[nn.Module]:
6380 return self .outputs .keys ()
6481
6582 def remove_hooks (self ) -> None :
83+ """
84+ Remove hooks.
85+ """
6686 for hook in self .hooks :
6787 hook .remove ()
6888
6989 def __del__ (self ) -> None :
70- # print(f"DEL HOOKS!: {list(self.outputs.keys())}")
90+ """
91+ Ensure that using 'del' properly deletes hooks.
92+ """
7193 self .remove_hooks ()
7294
7395
@@ -77,16 +99,34 @@ class ActivationFetcher:
7799 """
78100
79101 def __init__ (self , model : nn .Module , targets : Iterable [nn .Module ]) -> None :
102+ """
103+ Args:
104+
105+ model (nn.Module): The reference to PyTorch model instance.
106+ targets (nn.Module or list of nn.Module): The target layers to
107+ collect activations from.
108+ """
80109 super (ActivationFetcher , self ).__init__ ()
81110 self .model = model
82111 self .layers = ModuleOutputsHook (targets )
83112
84113 def __call__ (self , input_t : TupleOfTensorsOrTensorType ) -> ModuleOutputMapping :
114+ """
115+ Args:
116+
117+ input_t (tensor or tuple of tensors, optional): The input to use
118+ with the specified model.
119+
120+ Returns:
121+ activations_dict: An dict containing the collected activations. The keys
122+ for the returned dictionary are the target layers.
123+ """
124+
85125 try :
86126 with warnings .catch_warnings ():
87127 warnings .simplefilter ("ignore" )
88128 self .model (input_t )
89- activations = self .layers .consume_outputs ()
129+ activations_dict = self .layers .consume_outputs ()
90130 finally :
91131 self .layers .remove_hooks ()
92- return activations
132+ return activations_dict
0 commit comments