
Commit 8cd6b0c

Some cleanup, adding total time in cprofile
1 parent 1e5bd48 commit 8cd6b0c

2 files changed: 29 additions & 4 deletions

smdebug/profiler/python_profiler.py

Lines changed: 9 additions & 1 deletion
@@ -124,11 +124,19 @@ def _reset_profiler(self):
         """Reset profiler and corresponding attributes to defaults
         """
         super()._reset_profiler()
-        self._profiler = cProfileProfiler()
+        self._profiler = cProfileProfiler(self._total_time)

     def _name(self):
         return "cProfile"

+    def _total_time(self):
+        times = os.times()
+        return times.elapsed
+
+    def _off_cpu_time(self):
+        times = os.times()
+        return times.elapsed - (times.system + times.user)
+
     def _stats_filename(self):
         # this is default value
         return "python_stats"

smdebug/pytorch/hook.py

Lines changed: 20 additions & 3 deletions
@@ -1,6 +1,7 @@
 # Standard Library
 import os
 import time
+from packaging import version

 # Third Party
 import torch
@@ -257,13 +258,13 @@ def _collect_torch_profiling_data_if_profiler_enabled(self):
             cpu_thread_start_time=event.cpu_interval.start + self.start_profiler_time_us,
         )

+
     # This hook is invoked by trainer prior to running the forward pass.
     def forward_pre_hook(self, module, inputs):
         # Disable pre-step 0 python profiling if profiling is enabled and if this is step 0.
         # below we say step+1 because step is incremented further in this method
         if python_profiler:
             python_profiler.stop_profiling(step_phase="startstep" + str(self.step + 1))
-
         # Write the gradients of the past step if the writer is still available.
         if self.writer is not None:
             self._close_writers()
@@ -287,9 +288,13 @@ def forward_pre_hook(self, module, inputs):
             python_profiler.start_profiling(self.step, step_phase="startstep" + str(self.step))

         if not self.autograd_profiler_enabled:
-            torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False))
+            if version.parse(torch.__version__) <= version.parse("1.5.1"):
+                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False))
+            elif version.parse(torch.__version__) >= version.parse("1.6"):
+                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False, False))
             self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
             self.autograd_profiler_enabled = True
+
         if self._get_collections_to_save_for_step():
             self._initialize_writers()
             self._log_params(module)
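An aside on the new `packaging` dependency used in the version gate above: `version.parse` gives PEP 440-aware ordering, which plain string comparison gets wrong once a version component reaches two digits. (The extra `False` on the 1.6+ branch presumably matches the additional flag PyTorch 1.6 added to `ProfilerConfig`, likely `profile_memory`; treat that reading as an assumption here.)

```python
# Why version.parse rather than plain string comparison: PEP 440-aware
# parsing orders "1.10" after "1.6"; lexicographic comparison does not.
from packaging import version

assert version.parse("1.10.0") > version.parse("1.6.0")  # correct ordering
assert "1.10.0" < "1.6.0"  # naive string comparison gets this backwards
```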
@@ -426,7 +431,7 @@ def bhook(self, module, grad_input, grad_output):
         if self.step_event:
             self.step_event.update_end_time(now)

-        # if this is first forward we will use start time of parent as start time, and end time as now
+        # if this is not first backward we will use start time of parent as start time, and end time as now
         if len(self.backward_modules_profile_stats) > 0:
             # this child start_time is approximated as last child end time
             child_start_time = self.backward_modules_profile_stats[-1].end_time
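For readers new to this file, a self-contained sketch (toy model, not smdebug code) of the hook mechanism `bhook` and `forward_pre_hook` plug into: PyTorch fires a forward pre-hook before a module's forward pass, and a backward hook as gradients flow through it.

```python
# Minimal sketch, not smdebug code: the two PyTorch hook types used here.
import torch
import torch.nn as nn


def forward_pre_hook(module, inputs):
    print(f"forward_pre_hook on {module.__class__.__name__}")


def bhook(module, grad_input, grad_output):
    print(f"backward hook on {module.__class__.__name__}")


layer = nn.Linear(4, 1)
layer.register_forward_pre_hook(forward_pre_hook)
layer.register_backward_hook(bhook)  # the call made per-module by _closure_for_registering_backward_hook

layer(torch.randn(2, 4)).sum().backward()  # fires both hooks
```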
@@ -453,6 +458,16 @@ def _closure_for_registering_backward_hook(self, module):
     def _closure_for_registering_backward_hook(self, module):
         module.register_backward_hook(self.bhook)

+    def count_parameters(self, model):
+        total_params = 0
+        for name, parameter in model.named_parameters():
+            if not parameter.requires_grad: continue
+            param = parameter.numel()
+            self.logger.info(f"name:{name} count_params:{param}")
+            total_params += param
+        self.logger.info(f"Total Trainable Params: {total_params}")
+        return total_params
+
     def register_module(self, module):
         """
         This function registers the forward hook. If user wants to register the hook
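A standalone version of the same parameter count, for reference: `Parameter.numel()` returns the element count of each tensor, so summing it over parameters with `requires_grad` set gives the trainable total. The toy model below is illustrative, not from smdebug.

```python
# Standalone sketch of the same computation on a toy model.
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
total = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Trainable Params: {total}")  # (10*20 + 20) + (20*2 + 2) = 262
```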
@@ -496,6 +511,8 @@ def register_module(self, module):
         module.apply(self._closure_for_registering_backward_hook)

         self.has_registered_module = True
+        self.count_parameters(module)
+

     def register_loss(self, loss_module):
         """Register something like `criterion = nn.CrossEntropyLoss()`."""
