@@ -1,6 +1,7 @@
 # Standard Library
 import os
 import time
+from packaging import version
 
 # Third Party
 import torch
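Note: the new packaging.version import is there so torch versions can be compared numerically rather than as strings. A quick, self-contained illustration of why that matters (the version strings below are arbitrary examples, not taken from this change):

    from packaging import version

    assert version.parse("1.10.0") > version.parse("1.6")  # component-wise numeric comparison
    assert not ("1.10.0" > "1.6")  # naive string comparison orders these the wrong way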
@@ -257,13 +258,13 @@ def _collect_torch_profiling_data_if_profiler_enabled(self):
                 cpu_thread_start_time=event.cpu_interval.start + self.start_profiler_time_us,
             )
 
+
     # This hook is invoked by trainer prior to running the forward pass.
     def forward_pre_hook(self, module, inputs):
         # Disable pre-step 0 python profiling if profiling is enabled and if this is step 0.
         # below we say step+1 because step is incremented further in this method
         if python_profiler:
             python_profiler.stop_profiling(step_phase="startstep" + str(self.step + 1))
-
         # Write the gradients of the past step if the writer is still available.
         if self.writer is not None:
             self._close_writers()
@@ -287,9 +288,13 @@ def forward_pre_hook(self, module, inputs):
             python_profiler.start_profiling(self.step, step_phase="startstep" + str(self.step))
 
         if not self.autograd_profiler_enabled:
-            torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False))
+            if version.parse(torch.__version__) <= version.parse("1.5.1"):
+                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False))
+            elif version.parse(torch.__version__) >= version.parse("1.6"):
+                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False, False))
             self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
             self.autograd_profiler_enabled = True
+
         if self._get_collections_to_save_for_step():
             self._initialize_writers()
             self._log_params(module)
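Note: a minimal standalone sketch of the version gate introduced above, assuming only what the diff shows, namely that ProfilerConfig takes two arguments up to torch 1.5.1 and one extra boolean flag from torch 1.6 onward; the helper name _make_profiler_config is illustrative and not part of this change:

    from packaging import version
    import torch

    def _make_profiler_config(profiler_state):
        if version.parse(torch.__version__) <= version.parse("1.5.1"):
            # torch <= 1.5.1: two-argument form, as in the removed line above
            return torch.autograd.ProfilerConfig(profiler_state, False)
        # torch >= 1.6: the constructor expects an additional boolean flag
        return torch.autograd.ProfilerConfig(profiler_state, False, False)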
@@ -426,7 +431,7 @@ def bhook(self, module, grad_input, grad_output):
         if self.step_event:
             self.step_event.update_end_time(now)
 
-        # if this is first forward we will use start time of parent as start time, and end time as now
+        # if this is not the first backward module we will use end time of the previous child as start time, and end time as now
         if len(self.backward_modules_profile_stats) > 0:
             # this child start_time is approximated as last child end time
             child_start_time = self.backward_modules_profile_stats[-1].end_time
@@ -453,6 +458,16 @@ def bhook(self, module, grad_input, grad_output):
     def _closure_for_registering_backward_hook(self, module):
         module.register_backward_hook(self.bhook)
 
+    def count_parameters(self, model):
+        total_params = 0
+        for name, parameter in model.named_parameters():
+            if not parameter.requires_grad: continue
+            param = parameter.numel()
+            self.logger.info(f"name:{name} count_params:{param}")
+            total_params += param
+        self.logger.info(f"Total Trainable Params: {total_params}")
+        return total_params
+
     def register_module(self, module):
         """
         This function registers the forward hook. If user wants to register the hook
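Note: a quick, self-contained sanity check of the figure the new count_parameters helper logs as "Total Trainable Params"; the toy model below is purely illustrative and not part of this change:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
    # same total the helper would report for this model
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(total)  # (10*5 + 5) + (5*1 + 1) = 61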
@@ -496,6 +511,8 @@ def register_module(self, module):
         module.apply(self._closure_for_registering_backward_hook)
 
         self.has_registered_module = True
+        self.count_parameters(module)
+
 
     def register_loss(self, loss_module):
         """Register something like `criterion = nn.CrossEntropyLoss()`."""