
Commit d934493

Bug fixes for autograd profiler in Pytorch hook. (#50)
* fixed pytorch hook
* fixed merge conflict
* fixed bug in hook
1 parent 726ef04 commit d934493

File tree

1 file changed: +21 -13 lines

smdebug/pytorch/hook.py

Lines changed: 21 additions & 13 deletions
@@ -105,6 +105,7 @@ def __init__(
             if torch.cuda.is_available()
             else torch.autograd.ProfilerState.CPU
         )
+        self.use_cuda = torch.cuda.is_available()
 
     def log_trace_event(self, event):
         self.record_trace_events(
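
The new self.use_cuda flag caches the same CUDA-availability check that selects the profiler state just above it. For context, the public profiler API exposes the same switch; a minimal standalone sketch, independent of smdebug (toy workload, illustrative only):

import torch

# Profile CUDA kernels only when a GPU is actually available; asking for
# CUDA events on a CPU-only machine is at best wasted work.
use_cuda = torch.cuda.is_available()

with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
    x = torch.randn(128, 128)
    y = torch.relu(x @ x)

# Summarize the recorded ops.
print(prof.key_averages().table(sort_by="cpu_time_total"))
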
@@ -224,7 +225,7 @@ def forward_pre_hook(self, module, inputs):
         elif self.autograd_profiler_enabled:
             records = torch.autograd._disable_profiler()
             self.function_events = torch.autograd.profiler.EventList(
-                torch.autograd.profiler.parse_cpu_trace(records), use_cuda=True
+                torch.autograd.profiler.parse_cpu_trace(records), use_cuda=self.use_cuda
             )
             for index, event in enumerate(self.function_events):
                 self.record_trace_events(
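
The EventList is now built with the cached flag instead of a hard-coded use_cuda=True, which could misbehave on machines without a GPU. A hedged sketch of how the per-event fields the hook records (name, thread id, kernel list) look through the public profiler API (toy workload, not smdebug code):

import torch

use_cuda = torch.cuda.is_available()

# Collect a short trace, then walk the resulting events much like the hook's
# forward_pre_hook does. On a CPU-only run every event.kernels list is empty,
# which is one reason use_cuda should reflect the machine, not be hard-coded.
with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
    x = torch.randn(64, 64)
    for _ in range(3):
        x = torch.relu(x @ x)

for index, event in enumerate(prof.function_events):
    print(index, event.name, event.thread, len(event.kernels))
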
@@ -237,17 +238,16 @@ def forward_pre_hook(self, module, inputs):
                     tid=event.thread,
                     step_num=self.step,
                 )
-            for k in event.kernels:
-                self.record_trace_events(
-                    op_name=k.name,
-                    phase="X",
-                    # k.interval.start is in microseconds
-                    timestamp=(k.interval.start + self.start_profiler_time_us)
-                    / float(CONVERT_TO_MICROSECS),
-                    duration=k.interval.elapsed_us() / float(CONVERT_TO_MICROSECS),
-                    tid=k.device,
-                    step_num=self.step,
-                )
+                for k in event.kernels:
+                    self.record_trace_events(
+                        op_name=k.name,
+                        phase="X",
+                        timestamp=(k.interval.start + self.start_profiler_time_us)
+                        / float(CONVERT_TO_MICROSECS),
+                        duration=k.interval.elapsed_us() / float(CONVERT_TO_MICROSECS),
+                        tid=k.device,
+                        step_num=self.step,
+                    )
             self.autograd_profiler_enabled = False
         # Write the gradients of the past step if the writer is still available.
         if self.writer is not None:
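
The rest of the hunk appears to only re-indent the kernel loop so that it runs inside the enumerate(self.function_events) loop rather than once after it, where it would only see the last event through the leaked loop variable. A hedged sketch of the corrected nesting, with emit() standing in for the hook's record_trace_events and CONVERT_TO_MICROSECS assumed to be the microseconds-per-second factor:

# Hedged sketch of the corrected nesting: the kernel loop belongs inside the
# per-event loop so that every event's kernels are emitted, not just those of
# the final event. emit() is a stand-in for record_trace_events;
# CONVERT_TO_MICROSECS is assumed here to be 1_000_000 (microseconds per second).
CONVERT_TO_MICROSECS = 1000000

def emit(op_name, phase, timestamp, duration, tid, step_num):
    print(f"{op_name} [{phase}] tid={tid} step={step_num} "
          f"start={timestamp:.6f}s dur={duration:.6f}s")

def record_kernel_events(function_events, start_profiler_time_us, step):
    for event in function_events:
        for k in event.kernels:  # nested: runs for every event
            emit(
                op_name=k.name,
                phase="X",
                # k.interval.start is in microseconds
                timestamp=(k.interval.start + start_profiler_time_us)
                / float(CONVERT_TO_MICROSECS),
                duration=k.interval.elapsed_us() / float(CONVERT_TO_MICROSECS),
                tid=k.device,
                step_num=step,
            )
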
@@ -419,7 +419,15 @@ def register_module(self, module):
             raise ValueError(
                 f"Module type {module.__class__.__name__} must be type torch.nn.Module"
             )
-
+        # in case GPU is available but model has been loaded on CPU
+        for parameter in module.parameters():
+            self.profiler = (
+                torch.autograd.ProfilerState.CUDA
+                if parameter.is_cuda
+                else torch.autograd.ProfilerState.CPU
+            )
+            self.use_cuda = parameter.is_cuda
+            break
         # Create an attribute and store the module name in the object
         # So that it is available in the forward hook.
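
The register_module change handles the case where a GPU exists but the model was left on the CPU: the first parameter's device, rather than torch.cuda.is_available(), now decides the profiler state. A minimal sketch of that check (profiling_uses_cuda is an illustrative helper, not a hook method):

import torch
import torch.nn as nn

def profiling_uses_cuda(module: nn.Module) -> bool:
    # Decide from the first parameter, as the hook does; a module with no
    # parameters falls back to plain availability, matching the default the
    # hook keeps from __init__ in that case.
    for parameter in module.parameters():
        return parameter.is_cuda
    return torch.cuda.is_available()

model = nn.Linear(4, 2)            # constructed on the CPU
print(profiling_uses_cuda(model))  # False, even on a machine with a GPU
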
