@@ -46,12 +46,11 @@ class Hook(CallbackHook):
4646 """
4747
4848 class _TraceEventData :
49- def __init__ (self , phase , op_name , start_time , dur , pid , ** kwargs ):
49+ def __init__ (self , phase , op_name , start_time , dur , ** kwargs ):
5050 self .training_phase = phase
5151 self .end_time = start_time + dur
5252 self .start_time = start_time
5353 self .op_name = op_name
54- self .pid = pid
5554 self .kwargs = kwargs
5655
5756 def update_end_time (self , end_time = time .time ()):
@@ -281,6 +280,10 @@ def forward_pre_hook(self, module, inputs):
281280
282281 self ._increment_step ()
283282
283+ ## preparing for step metrics
284+ # the last operation can be forward (an eval loop is running, or a module's forward is called multiple times, e.g. an RNN)
285+ # or the last operation can be backward (the training backward pass just finished and we are at forward again)
286+
284287 # we will log all outstanding forward and backward events
285288 self .log_outstanding_timeline_metrics ()
286289
@@ -292,6 +295,14 @@ def forward_pre_hook(self, module, inputs):
292295 pid = os .getpid (),
293296 step_num = str (self .mode_steps [self .mode ]),
294297 )
298+ self .parent_forward_event = self ._TraceEventData (
299+ phase = "Forward" ,
300+ op_name = module ._module_name ,
301+ start_time = time .time (),
302+ dur = 0 , # end time of parent_forward_event will be updated every time a forward event is called after this
303+ pid = os .getpid (),
304+ step_num = str (self .mode_steps [self .mode ]),
305+ )
295306
296307 # Disable python profiling if the python profiler is currently profiling.
297308 if python_profiler :
@@ -337,18 +348,6 @@ def forward_pre_hook(self, module, inputs):
337348 self .export_collections ()
338349 self .exported_collections = True
339350
340- ## prepararing for step metrics
341- # last operation can be forward( eval loop is running or multiple forward for example RNN can have multiple call to forward of module)
342- # or last operation can be backward (train backward loop just finished and we are at forward again)
343-
344- self .parent_forward_event = self ._TraceEventData (
345- phase = "Forward" ,
346- op_name = module ._module_name ,
347- start_time = time .time (),
348- dur = 0 , # end time of parent_forward_event will be updated every time a forward event is called after this
349- pid = os .getpid (),
350- step_num = str (self .mode_steps [self .mode ]),
351- )
352351 self .first_forward_submodule_name = None
353352
354353 def record_tensor_value (self , tensor_name : str , tensor_value : torch .Tensor ) -> None :
0 commit comments