@@ -1,11 +1,11 @@
 # Standard Library
 import os
 import time
-from packaging import version
 
 # Third Party
 import torch
 import torch.distributed as dist
+from packaging import version
 
 # First Party
 from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
@@ -258,7 +258,6 @@ def _collect_torch_profiling_data_if_profiler_enabled(self):
                 cpu_thread_start_time=event.cpu_interval.start + self.start_profiler_time_us,
             )
 
-
     # This hook is invoked by trainer prior to running the forward pass.
     def forward_pre_hook(self, module, inputs):
         # Disable pre-step 0 python profiling if profiling is enabled and if this is step 0.
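
The comment in this hunk notes that `forward_pre_hook` is invoked by the trainer before a module's forward pass runs. As a point of reference, here is a minimal standalone sketch of attaching such a callback with PyTorch's public `nn.Module` API; the toy model and the print statement are illustrative and not part of smdebug's own registration code:

```python
# Minimal sketch (not smdebug's logic): attaching a pre-forward callback
# with PyTorch's stock nn.Module API.
import torch
import torch.nn as nn

def forward_pre_hook(module, inputs):
    # Called by PyTorch just before module.forward(*inputs) runs.
    print(f"about to run forward on {module.__class__.__name__}")

model = nn.Linear(4, 2)                              # illustrative model
handle = model.register_forward_pre_hook(forward_pre_hook)
model(torch.randn(1, 4))                             # triggers the pre-hook, then forward
handle.remove()                                      # detach when no longer needed
```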
@@ -289,9 +288,13 @@ def forward_pre_hook(self, module, inputs):
 
         if not self.autograd_profiler_enabled:
             if version.parse(torch.__version__) <= version.parse("1.5.1"):
-                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False))
-            elif version.parse(torch.__version__) >= version.parse("1.6"):
-                torch.autograd._enable_profiler(torch.autograd.ProfilerConfig(self.profiler, False, False))
+                torch.autograd._enable_profiler(
+                    torch.autograd.ProfilerConfig(self.profiler, False)
+                )
+            elif version.parse(torch.__version__) >= version.parse("1.6"):
+                torch.autograd._enable_profiler(
+                    torch.autograd.ProfilerConfig(self.profiler, False, False)
+                )
             self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
             self.autograd_profiler_enabled = True
 
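
The branch above gates on the installed PyTorch version because the private `_enable_profiler`/`ProfilerConfig` signature differs across releases. A small self-contained sketch of the `packaging.version` comparison pattern it relies on (the version strings are chosen purely for illustration):

```python
# Sketch of the version-gating pattern used in the hunk above: packaging.version
# parses version strings into comparable objects, unlike plain string comparison.
from packaging import version

assert version.parse("1.10.0") > version.parse("1.6")    # numeric-aware comparison
assert not ("1.10.0" > "1.6")                            # naive string comparison gets this wrong
assert version.parse("1.5.1") <= version.parse("1.5.1")  # the boundary check used above
```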
@@ -461,10 +464,11 @@ def _closure_for_registering_backward_hook(self, module):
     def count_parameters(self, model):
         total_params = 0
         for name, parameter in model.named_parameters():
-            if not parameter.requires_grad: continue
+            if not parameter.requires_grad:
+                continue
             param = parameter.numel()
             self.logger.info(f"name:{name} count_params:{param}")
-            total_params+=param
+            total_params += param
         self.logger.info(f"Total Trainable Params: {total_params}")
         return total_params
 
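
`count_parameters` sums `numel()` over parameters that require gradients. A standalone sketch of the same counting logic, using an illustrative model that is not part of smdebug:

```python
# Standalone sketch of the counting logic above: sum numel() over parameters
# that require gradients. The model here is illustrative only.
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

total_params = 0
for name, parameter in model.named_parameters():
    if not parameter.requires_grad:
        continue
    total_params += parameter.numel()
print(f"Total Trainable Params: {total_params}")  # 10*5 + 5 + 5*2 + 2 = 67
```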
@@ -513,7 +517,6 @@ def register_module(self, module):
         self.has_registered_module = True
         self.count_parameters(module)
 
-
     def register_loss(self, loss_module):
         """Register something like `criterion = nn.CrossEntropyLoss()`."""
         # Typechecking