@@ -105,6 +105,7 @@ def __init__(
             if torch.cuda.is_available()
             else torch.autograd.ProfilerState.CPU
         )
+        self.use_cuda = torch.cuda.is_available()
 
     def log_trace_event(self, event):
         self.record_trace_events(
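
The new flag above still mirrors `torch.cuda.is_available()`, which reports whether CUDA is usable at all, not whether the model actually lives on the GPU; the `register_module` hunk further down refines it from the parameters' device. A minimal sketch of that distinction (the printed values assume a CUDA-capable machine):

```python
import torch

# torch.cuda.is_available() answers "could this process use CUDA?",
# not "is this tensor/model on the GPU?". A CPU-resident tensor on a
# GPU machine still sees the flag as True.
flag = torch.cuda.is_available()
t = torch.zeros(3)       # allocated on the CPU regardless of the flag
print(flag, t.is_cuda)   # e.g. "True False" on a GPU machine
```
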
@@ -224,7 +225,7 @@ def forward_pre_hook(self, module, inputs):
         elif self.autograd_profiler_enabled:
             records = torch.autograd._disable_profiler()
             self.function_events = torch.autograd.profiler.EventList(
-                torch.autograd.profiler.parse_cpu_trace(records), use_cuda=True
+                torch.autograd.profiler.parse_cpu_trace(records), use_cuda=self.use_cuda
             )
             for index, event in enumerate(self.function_events):
                 self.record_trace_events(
@@ -237,17 +238,16 @@ def forward_pre_hook(self, module, inputs):
                     tid=event.thread,
                     step_num=self.step,
                 )
-            for k in event.kernels:
-                self.record_trace_events(
-                    op_name=k.name,
-                    phase="X",
-                    # k.interval.start is in microseconds
-                    timestamp=(k.interval.start + self.start_profiler_time_us)
-                    / float(CONVERT_TO_MICROSECS),
-                    duration=k.interval.elapsed_us() / float(CONVERT_TO_MICROSECS),
-                    tid=k.device,
-                    step_num=self.step,
-                )
+                for k in event.kernels:
+                    self.record_trace_events(
+                        op_name=k.name,
+                        phase="X",
+                        timestamp=(k.interval.start + self.start_profiler_time_us)
+                        / float(CONVERT_TO_MICROSECS),
+                        duration=k.interval.elapsed_us() / float(CONVERT_TO_MICROSECS),
+                        tid=k.device,
+                        step_num=self.step,
+                    )
             self.autograd_profiler_enabled = False
         # Write the gradients of the past step if the writer is still available.
         if self.writer is not None:
@@ -419,7 +419,15 @@ def register_module(self, module):
             raise ValueError(
                 f"Module type {module.__class__.__name__} must be type torch.nn.Module"
             )
-
+        # In case a GPU is available but the model has been loaded on the CPU.
+        for parameter in module.parameters():
+            self.profiler = (
+                torch.autograd.ProfilerState.CUDA
+                if parameter.is_cuda
+                else torch.autograd.ProfilerState.CPU
+            )
+            self.use_cuda = parameter.is_cuda
+            break
         # Create an attribute and store the module name in the object
         # So that it is available in the forward hook.
 
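The loop-with-`break` above reads only the first parameter, assuming the whole module sits on one device (true for typical unsharded models; a model split across devices would need a more careful check). An equivalent standalone probe, with a hypothetical helper name that is not part of this codebase:

```python
import torch
import torch.nn as nn

# Hypothetical helper mirroring the first-parameter probe above; assumes
# every parameter of the module lives on the same device. Falls back to
# False for parameterless modules.
def module_is_cuda(module: nn.Module) -> bool:
    parameter = next(module.parameters(), None)
    return parameter is not None and parameter.is_cuda

model = nn.Linear(4, 2)
print(module_is_cuda(model))              # False: constructed on the CPU
if torch.cuda.is_available():
    print(module_is_cuda(model.cuda()))   # True once moved to the GPU
```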