|
1 | 1 | #!/usr/bin/env python3 |
2 | | -from typing import cast |
| 2 | +from collections import OrderedDict |
| 3 | +from typing import List, Optional, cast |
3 | 4 |
|
4 | 5 | import captum.optim._core.output_hook as output_hook |
5 | 6 | import torch |
6 | 7 | from captum.optim.models import googlenet |
7 | | -from tests.helpers.basic import BaseTest |
| 8 | +from tests.helpers.basic import BaseTest, assertTensorAlmostEqual |
| 9 | + |
| 10 | + |
| 11 | +def _count_forward_hooks( |
| 12 | + module: torch.nn.Module, hook_fn_name: Optional[str] = None |
| 13 | +) -> int: |
| 14 | + """ |
| 15 | + Count the number of active forward hooks on the specified module instance. |
| 16 | +
|
| 17 | + Args: |
| 18 | +
|
| 19 | + module (nn.Module): The model module instance to count the number of |
| 20 | + forward hooks on. |
| 21 | + name (str, optional): Optionally only count specific forward hooks based on |
| 22 | + their function's __name__ attribute. |
| 23 | + Default: None |
| 24 | +
|
| 25 | + Returns: |
| 26 | + num_hooks (int): The number of active hooks in the specified module. |
| 27 | + """ |
| 28 | + |
| 29 | + num_hooks: List[int] = [0] |
| 30 | + |
| 31 | + def _count_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: |
| 32 | + if hasattr(m, "_forward_hooks"): |
| 33 | + if m._forward_hooks != OrderedDict(): |
| 34 | + dict_items = list(m._forward_hooks.items()) |
| 35 | + for i, fn in dict_items: |
| 36 | + if hook_fn_name is None or fn.__name__ == name: |
| 37 | + num_hooks[0] += 1 |
| 38 | + |
| 39 | + def _count_child_hooks( |
| 40 | + target_module: torch.nn.Module, |
| 41 | + hook_name: Optional[str] = None, |
| 42 | + ) -> None: |
| 43 | + |
| 44 | + for name, child in target_module._modules.items(): |
| 45 | + if child is not None: |
| 46 | + _count_hooks(child, hook_name) |
| 47 | + _count_child_hooks(child, hook_name) |
| 48 | + |
| 49 | + _count_child_hooks(module, hook_fn_name) |
| 50 | + _count_hooks(module, hook_fn_name) |
| 51 | + return num_hooks[0] |
| 52 | + |
| 53 | + |
| 54 | +class TestModuleOutputsHook(BaseTest): |
| 55 | + def test_init_single_target(self) -> None: |
| 56 | + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) |
| 57 | + target_modules = [model[0]] |
| 58 | + |
| 59 | + hook_module = output_hook.ModuleOutputsHook(target_modules) |
| 60 | + self.assertEqual(len(hook_module.hooks), len(target_modules)) |
| 61 | + |
| 62 | + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") |
| 63 | + self.assertEqual(n_hooks, len(target_modules)) |
| 64 | + |
| 65 | + outputs = dict.fromkeys(target_modules, None) |
| 66 | + self.assertEqual(outputs, hook_module.outputs) |
| 67 | + self.assertEqual(list(hook_module.targets), target_modules) |
| 68 | + self.assertFalse(hook_module.is_ready) |
| 69 | + |
| 70 | + def test_init_multiple_targets(self) -> None: |
| 71 | + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) |
| 72 | + target_modules = [model[0], model[1]] |
| 73 | + |
| 74 | + hook_module = output_hook.ModuleOutputsHook(target_modules) |
| 75 | + self.assertEqual(len(hook_module.hooks), len(target_modules)) |
| 76 | + |
| 77 | + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") |
| 78 | + self.assertEqual(n_hooks, len(target_modules)) |
| 79 | + |
| 80 | + outputs = dict.fromkeys(target_modules, None) |
| 81 | + self.assertEqual(outputs, hook_module.outputs) |
| 82 | + self.assertEqual(list(hook_module.targets), target_modules) |
| 83 | + self.assertFalse(hook_module.is_ready) |
| 84 | + |
| 85 | + def test_init_multiple_targets_remove_hooks(self) -> None: |
| 86 | + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) |
| 87 | + target_modules = [model[0], model[1]] |
| 88 | + |
| 89 | + hook_module = output_hook.ModuleOutputsHook(target_modules) |
| 90 | + |
| 91 | + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") |
| 92 | + self.assertEqual(n_hooks, len(target_modules)) |
| 93 | + |
| 94 | + hook_module.remove_hooks() |
| 95 | + |
| 96 | + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") |
| 97 | + self.assertEqual(n_hooks, 0) |
| 98 | + |
| 99 | + def test_reset_outputs_multiple_targets(self) -> None: |
| 100 | + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) |
| 101 | + target_modules = [model[0], model[1]] |
| 102 | + test_input = torch.randn(1, 3, 4, 4) |
| 103 | + |
| 104 | + hook_module = output_hook.ModuleOutputsHook(target_modules) |
| 105 | + self.assertFalse(hook_module.is_ready) |
| 106 | + |
| 107 | + _ = model(test_input) |
| 108 | + |
| 109 | + self.assertTrue(hook_module.is_ready) |
| 110 | + |
| 111 | + outputs_dict = hook_module.outputs |
| 112 | + i = 0 |
| 113 | + for target, activations in outputs_dict.items(): |
| 114 | + self.assertEqual(target, target_modules[i]) |
| 115 | + assertTensorAlmostEqual(self, activations, test_input) |
| 116 | + i += 1 |
| 117 | + |
| 118 | + hook_module._reset_outputs() |
| 119 | + |
| 120 | + self.assertFalse(hook_module.is_ready) |
| 121 | + |
| 122 | + expected_outputs = dict.fromkeys(target_modules, None) |
| 123 | + self.assertEqual(hook_module.outputs, expected_outputs) |
| 124 | + |
| 125 | + def test_consume_outputs_multiple_targets(self) -> None: |
| 126 | + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) |
| 127 | + target_modules = [model[0], model[1]] |
| 128 | + test_input = torch.randn(1, 3, 4, 4) |
| 129 | + |
| 130 | + hook_module = output_hook.ModuleOutputsHook(target_modules) |
| 131 | + self.assertFalse(hook_module.is_ready) |
| 132 | + |
| 133 | + _ = model(test_input) |
| 134 | + |
| 135 | + self.assertTrue(hook_module.is_ready) |
| 136 | + |
| 137 | + test_outputs_dict = hook_module.outputs |
| 138 | + self.assertIsInstance(test_outputs_dict, dict) |
| 139 | + self.assertEqual(len(test_outputs_dict), len(target_modules)) |
| 140 | + |
| 141 | + i = 0 |
| 142 | + for target, activations in test_outputs_dict.items(): |
| 143 | + self.assertEqual(target, target_modules[i]) |
| 144 | + assertTensorAlmostEqual(self, activations, test_input) |
| 145 | + i += 1 |
| 146 | + |
| 147 | + test_output = hook_module.consume_outputs() |
| 148 | + |
| 149 | + self.assertFalse(hook_module.is_ready) |
| 150 | + |
| 151 | + i = 0 |
| 152 | + for target, activations in test_output.items(): |
| 153 | + self.assertEqual(target, target_modules[i]) |
| 154 | + assertTensorAlmostEqual(self, activations, test_input) |
| 155 | + i += 1 |
| 156 | + |
| 157 | + expected_outputs = dict.fromkeys(target_modules, None) |
| 158 | + self.assertEqual(hook_module.outputs, expected_outputs) |
| 159 | + |
| 160 | + def test_consume_outputs_warning(self) -> None: |
| 161 | + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) |
| 162 | + target_modules = [model[0], model[1]] |
| 163 | + test_input = torch.randn(1, 3, 4, 4) |
| 164 | + |
| 165 | + hook_module = output_hook.ModuleOutputsHook(target_modules) |
| 166 | + self.assertFalse(hook_module.is_ready) |
| 167 | + |
| 168 | + _ = model(test_input) |
| 169 | + |
| 170 | + self.assertTrue(hook_module.is_ready) |
| 171 | + |
| 172 | + hook_module._reset_outputs() |
| 173 | + |
| 174 | + self.assertFalse(hook_module.is_ready) |
| 175 | + |
| 176 | + with self.assertWarns(Warning): |
| 177 | + _ = hook_module.consume_outputs() |
8 | 178 |
|
9 | 179 |
|
10 | 180 | class TestActivationFetcher(BaseTest): |
11 | | - def test_activation_fetcher(self) -> None: |
| 181 | + def test_activation_fetcher_simple_model(self) -> None: |
| 182 | + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) |
| 183 | + |
| 184 | + catch_activ = output_hook.ActivationFetcher(model, targets=[model[0]]) |
| 185 | + test_input = torch.randn(1, 3, 224, 224) |
| 186 | + activ_out = catch_activ(test_input) |
| 187 | + |
| 188 | + self.assertIsInstance(activ_out, dict) |
| 189 | + self.assertEqual(len(activ_out), 1) |
| 190 | + activ = activ_out[model[0]] |
| 191 | + assertTensorAlmostEqual(self, activ, test_input) |
| 192 | + |
| 193 | + def test_activation_fetcher_single_target(self) -> None: |
12 | 194 | model = googlenet(pretrained=True) |
13 | 195 |
|
14 | 196 | catch_activ = output_hook.ActivationFetcher(model, targets=[model.mixed4d]) |
15 | 197 | activ_out = catch_activ(torch.zeros(1, 3, 224, 224)) |
16 | 198 |
|
17 | 199 | self.assertIsInstance(activ_out, dict) |
| 200 | + self.assertEqual(len(activ_out), 1) |
18 | 201 | m4d_activ = activ_out[model.mixed4d] |
19 | 202 | self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14]) |
| 203 | + |
| 204 | + def test_activation_fetcher_multiple_targets(self) -> None: |
| 205 | + model = googlenet(pretrained=True) |
| 206 | + |
| 207 | + catch_activ = output_hook.ActivationFetcher( |
| 208 | + model, targets=[model.mixed4d, model.mixed5b] |
| 209 | + ) |
| 210 | + activ_out = catch_activ(torch.zeros(1, 3, 224, 224)) |
| 211 | + |
| 212 | + self.assertIsInstance(activ_out, dict) |
| 213 | + self.assertEqual(len(activ_out), 2) |
| 214 | + |
| 215 | + m4d_activ = activ_out[model.mixed4d] |
| 216 | + self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14]) |
| 217 | + |
| 218 | + m5b_activ = activ_out[model.mixed5b] |
| 219 | + self.assertEqual(list(cast(torch.Tensor, m5b_activ).shape), [1, 1024, 7, 7]) |
0 commit comments