
Commit 5b588db

Merge pull request #13 from awslabs/pre_commit_fix
running pre-commit
2 parents 4f650b7 + 0905968

4 files changed: 14 additions, 10 deletions


smdebug/profiler/python_profiler.py

Lines changed: 1 addition & 0 deletions

@@ -157,6 +157,7 @@ def _dump_stats(self, stats_file_path):
         get_logger("smdebug-profiler").info(f"Dumping cProfile stats to {stats_file_path}.")
         pstats.Stats(self._profiler).dump_stats(stats_file_path)
 
+
 class PyinstrumentPythonProfiler(PythonProfiler):
     """Higher level class to oversee profiling specific to Pyinstrument, a third party Python profiler.
     """

smdebug/pytorch/hook.py

Lines changed: 11 additions & 8 deletions

@@ -1,11 +1,11 @@
 # Standard Library
 import os
 import time
-from packaging import version
 
 # Third Party
 import torch
 import torch.distributed as dist
+from packaging import version
 
 # First Party
 from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
@@ -258,7 +258,6 @@ def _collect_torch_profiling_data_if_profiler_enabled(self):
                 cpu_thread_start_time=event.cpu_interval.start + self.start_profiler_time_us,
             )
 
-
     # This hook is invoked by trainer prior to running the forward pass.
     def forward_pre_hook(self, module, inputs):
         # Disable pre-step 0 python profiling if profiling is enabled and if this is step 0.
@@ -289,9 +288,13 @@ def forward_pre_hook(self, module, inputs):
 
         if not self.autograd_profiler_enabled:
             if version.parse(torch.__version__) <= version.parse("1.5.1"):
-                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False))
-            elif version.parse(torch.__version__) >= version.parse("1.6"):
-                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False, False))
+                torch.autograd._enable_profiler(
+                    torch.autograd.ProfilerConfig(self.profiler, False)
+                )
+            elif version.parse(torch.__version__) >= version.parse("1.6"):
+                torch.autograd._enable_profiler(
+                    torch.autograd.ProfilerConfig(self.profiler, False, False)
+                )
             self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
             self.autograd_profiler_enabled = True
 
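
The hunk above gates a private profiler API on the installed torch version. A minimal sketch of that version-gating pattern, assuming only that torch and packaging are installed; the ProfilerConfig argument counts are taken from the diff itself (the private signature has changed again in much newer torch releases):

import torch
from packaging import version

TORCH_VERSION = version.parse(torch.__version__)

def make_profiler_config(profiler_kind):
    if TORCH_VERSION <= version.parse("1.5.1"):
        # Older signature takes two arguments, per the removed lines above.
        return torch.autograd.ProfilerConfig(profiler_kind, False)
    # torch >= 1.6 expects one more boolean argument, per the added lines above.
    return torch.autograd.ProfilerConfig(profiler_kind, False, False)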

@@ -461,10 +464,11 @@ def _closure_for_registering_backward_hook(self, module):
     def count_parameters(self, model):
         total_params = 0
         for name, parameter in model.named_parameters():
-            if not parameter.requires_grad: continue
+            if not parameter.requires_grad:
+                continue
             param = parameter.numel()
             self.logger.info(f"name:{name} count_params:{param}")
-            total_params+=param
+            total_params += param
         self.logger.info(f"Total Trainable Params: {total_params}")
         return total_params
 
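
For illustration, the counting pattern in count_parameters applied to a small standalone model; the model here is made up for the example and is not from this repository:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

total_params = 0
for name, parameter in model.named_parameters():
    if not parameter.requires_grad:
        continue
    param = parameter.numel()
    print(f"name:{name} count_params:{param}")
    total_params += param
# Linear(4, 8) has 4*8 + 8 = 40 params; Linear(8, 2) has 8*2 + 2 = 18.
print(f"Total Trainable Params: {total_params}")  # 58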

@@ -513,7 +517,6 @@ def register_module(self, module):
         self.has_registered_module = True
         self.count_parameters(module)
 
-
     def register_loss(self, loss_module):
         """Register something like `criterion = nn.CrossEntropyLoss()`."""
         # Typechecking

tests/profiler/pytorch/test_pytorch_profiler.py

Lines changed: 1 addition & 1 deletion

@@ -81,4 +81,4 @@ def test_pytorch_profiler(pytorch_profiler_config_parser, out_dir):
     lt.refresh_event_file_list()
     events = lt.get_events(0, time.time() * 1000000)
     print(f"Number of events {len(events)}")
-    assert len(events) == 379
+    assert len(events) == 447

tests/profiler/pytorch/test_pytorch_profiler_rnn.py

Lines changed: 1 addition & 1 deletion

@@ -72,5 +72,5 @@ def test_pytorch_profiler_rnn(pytorch_profiler_config_parser, out_dir):
     lt.refresh_event_file_list()
     events = lt.get_events(0, time.time() * 1000000)
     print(f"Number of events {len(events)}")
-    assert len(events) == 59
+    assert len(events) == 61
     shutil.rmtree(out_dir, ignore_errors=True)
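
Both tests build their query window with get_events(0, time.time() * 1000000). A minimal sketch of just that window computation; the reader object lt and the get_events(start_us, end_us) call are as shown in the tests, and everything else is illustrative:

import time

start_us = 0                    # from the epoch...
end_us = time.time() * 1000000  # ...to now, wall-clock seconds converted to microseconds
print(f"Querying trace events in [{start_us}, {end_us:.0f}] us")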
