|
16 | 16 | ) |
17 | 17 |
|
18 | 18 | from captum.optim._core.loss import default_loss_summarize |
19 | | -from captum.optim._core.output_hook import ModuleOutputsHook |
| 19 | +from captum.optim._core.output_hook import ModuleOutputsHook, _remove_all_forward_hooks |
20 | 20 | from captum.optim._param.image.images import InputParameterization, NaturalImage |
21 | 21 | from captum.optim._param.image.transforms import RandomScale, RandomSpatialJitter |
22 | 22 | from captum.optim._utils.typing import ( |
@@ -196,6 +196,29 @@ def continue_while( |
196 | 196 | return continue_while |
197 | 197 |
|
198 | 198 |
|
| 199 | +def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]) -> None: |
| 200 | + """ |
| 201 | + Remove any InputOptimization hooks from the specified modules. This may be useful |
| 202 | + in the event that something goes wrong in between creating the InputOptimization |
| 203 | + instance and running the optimization function, or if InputOptimization fails |
| 204 | + without properly removing it's hooks. |
| 205 | +
|
| 206 | + Warning: This function will remove all the hooks placed by InputOptimization |
| 207 | + instances on the target modules, and thus can interfere with using multiple |
| 208 | + InputOptimization instances. |
| 209 | +
|
| 210 | + Args: |
| 211 | +
|
| 212 | + modules (nn.Module or list of nn.Module): Any module instances that contain |
| 213 | + hooks created by InputOptimization, for which the removal of the hooks is |
| 214 | + required. |
| 215 | + """ |
| 216 | + if not hasattr(modules, "__iter__"): |
| 217 | + modules = [modules] |
| 218 | + # Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions |
| 219 | + [_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in modules] |
| 220 | + |
| 221 | + |
199 | 222 | __all__ = [ |
200 | 223 | "InputOptimization", |
201 | 224 | "n_steps", |
|
0 commit comments