Skip to content

Commit 39f9919

Browse files
Reorder events in pytorch hook (#60)
1 parent ea4ac88 commit 39f9919

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

smdebug/pytorch/hook.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,11 @@ class Hook(CallbackHook):
4646
"""
4747

4848
class _TraceEventData:
49-
def __init__(self, phase, op_name, start_time, dur, pid, **kwargs):
49+
def __init__(self, phase, op_name, start_time, dur, **kwargs):
5050
self.training_phase = phase
5151
self.end_time = start_time + dur
5252
self.start_time = start_time
5353
self.op_name = op_name
54-
self.pid = pid
5554
self.kwargs = kwargs
5655

5756
def update_end_time(self, end_time=time.time()):
@@ -281,6 +280,10 @@ def forward_pre_hook(self, module, inputs):
281280

282281
self._increment_step()
283282

283+
## prepararing for step metrics
284+
# last operation can be forward( eval loop is running or multiple forward for example RNN can have multiple call to forward of module)
285+
# or last operation can be backward (train backward loop just finished and we are at forward again)
286+
284287
# we will log all outstanding forward and backward events
285288
self.log_outstanding_timeline_metrics()
286289

@@ -292,6 +295,14 @@ def forward_pre_hook(self, module, inputs):
292295
pid=os.getpid(),
293296
step_num=str(self.mode_steps[self.mode]),
294297
)
298+
self.parent_forward_event = self._TraceEventData(
299+
phase="Forward",
300+
op_name=module._module_name,
301+
start_time=time.time(),
302+
dur=0, # end time of parent_forward_event will be updated every time a forward event is called after this
303+
pid=os.getpid(),
304+
step_num=str(self.mode_steps[self.mode]),
305+
)
295306

296307
# Disable python profiling if the python profiler is currently profiling.
297308
if python_profiler:
@@ -337,18 +348,6 @@ def forward_pre_hook(self, module, inputs):
337348
self.export_collections()
338349
self.exported_collections = True
339350

340-
## prepararing for step metrics
341-
# last operation can be forward( eval loop is running or multiple forward for example RNN can have multiple call to forward of module)
342-
# or last operation can be backward (train backward loop just finished and we are at forward again)
343-
344-
self.parent_forward_event = self._TraceEventData(
345-
phase="Forward",
346-
op_name=module._module_name,
347-
start_time=time.time(),
348-
dur=0, # end time of parent_forward_event will be updated every time a forward event is called after this
349-
pid=os.getpid(),
350-
step_num=str(self.mode_steps[self.mode]),
351-
)
352351
self.first_forward_submodule_name = None
353352

354353
def record_tensor_value(self, tensor_name: str, tensor_value: torch.Tensor) -> None:

tests/sagemaker/pytorch_profiler_tests_config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
ProfilerEnabled: "True"
5656
MetricsConfig: "{\"PythonProfilingConfig\": {\"NumSteps\": \"3\", \"ProfilerName\": \"pyinstrument\"}}"
5757
RotateMaxFileSizeInBytes: "12485760.34"
58-
RotateFileCloseIntervalInSeconds: "100.45"
5958
FileOpenFailThreshold: "35"
6059
- expected_values_in_test_artifacts:
6160
python_trace_file_count: 7

0 commit comments

Comments
 (0)