Commit a31f98b

Removing the functionality to attach the backward hook to the module (#125)
* Removing the functionality to attach the backward hook to the module
* Updated the number of trace events, as the backward hook is no longer registered.
1 parent a10db38 commit a31f98b

2 files changed: 9 additions, 2 deletions

smdebug/pytorch/hook.py

Lines changed: 8 additions & 1 deletion
@@ -581,7 +581,14 @@ def register_module(self, module):

         # Capture the gradient for each parameter in the net
         self._backward_apply(module)
-        module.apply(self._closure_for_registering_backward_hook)
+
+        # TODO: Registering the backward hook causes issues in certain cases. There is a ‘Warning’ (
+        # https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_backward_hook#torch.nn.Module.register_backward_hook) for using this hook in certain cases.
+        # The ‘__call_impl’ in the PyTorch Module class makes some assumptions about the ‘results’ returned from the forward pass of the module. It cannot operate correctly if the ‘forward’ pass returns anything other than a dictionary of torch.Tensors. Some of the torchvision.transform classes returned a ‘PIL’ image object, and the backward hook used to crash.
+        # In some cases, we have seen that the training hangs. Hence the following functionality is currently
+        # commented out. We can revisit it after understanding PyTorch's implementation of the backward hook.
+
+        # module.apply(self._closure_for_registering_backward_hook)

         self.has_registered_module = True
         self.count_parameters(module)
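
The failure mode described in the TODO can be illustrated with a small, self-contained sketch. It is a simplified model of the kind of assumption the comment attributes to the module call path, not PyTorch's actual code. `FakeTensor`, `FakeImage`, `first_tensor_unsafe`, and `first_tensor_safe` are hypothetical names introduced only for this illustration, so the sketch runs without PyTorch installed.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (keeps the sketch PyTorch-free)."""

    def __init__(self, grad_fn=None):
        self.grad_fn = grad_fn


class FakeImage:
    """Hypothetical stand-in for a non-tensor forward output, e.g. a PIL image."""


def first_tensor_unsafe(result):
    # Models the assumption from the TODO: keep unwrapping the forward
    # output until a tensor is found. This is an illustration only.
    var = result
    while not isinstance(var, FakeTensor):
        if isinstance(var, dict):
            var = next(v for v in var.values() if isinstance(v, FakeTensor))
        else:
            # Raises TypeError for objects that are not subscriptable.
            # For a str, "abc"[0][0]... never yields a tensor, so the
            # loop would not terminate.
            var = var[0]
    return var


def first_tensor_safe(result, max_depth=8):
    # Defensive variant: returns None instead of raising or looping forever.
    var = result
    for _ in range(max_depth):
        if isinstance(var, FakeTensor):
            return var
        if isinstance(var, dict):
            var = next((v for v in var.values() if isinstance(v, FakeTensor)), None)
        elif isinstance(var, (list, tuple)) and var:
            var = var[0]
        else:
            return None
    return None
```

Under these assumptions, a hook-registration path could call something like `first_tensor_safe` on the forward output and skip the backward hook when it returns `None`, rather than disabling the hook for every module. Whether that is viable in smdebug depends on PyTorch's real implementation, which this sketch does not reproduce.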

tests/profiler/pytorch/test_pytorch_profiler.py

Lines changed: 1 addition & 1 deletion
@@ -81,4 +81,4 @@ def test_pytorch_profiler(pytorch_profiler_config_parser, out_dir):
     lt.refresh_event_file_list()
     events = lt.get_events(0, time.time() * 1000000)
     print(f"Number of events {len(events)}")
-    assert len(events) == 496
+    assert len(events) == 386
