Updated SparseML callback for latest PyTorch Lightning (#822)
* Updated callback
* Fix for `max_steps` (defaults to `-1`)
* Fix for deprecation of `num_gpus`, `num_processes`, and `tpu_cores`
* Avoid breaking changes for `max_steps=None`
* Added comments for clarity and future maintenance
clementpoiret authored May 4, 2022
1 parent 9711eef commit 1e962eb
Showing 1 changed file with 12 additions and 5 deletions.
pl_bolts/callbacks/sparseml.py (12 additions, 5 deletions)
@@ -66,15 +66,22 @@ def _num_training_steps_per_epoch(self, trainer: Trainer) -> int:
         else:
             dataset_size = len(trainer.datamodule.train_dataloader())
 
-        num_devices = max(1, trainer.num_gpus, trainer.num_processes)
-        if trainer.tpu_cores:
-            num_devices = max(num_devices, trainer.tpu_cores)
+        if hasattr(trainer, 'num_devices'):
+            # New behavior in Lightning
+            num_devices = max(1, trainer.num_devices)
+        else:
+            # Old behavior deprecated in v1.6
+            num_devices = max(1, trainer.num_gpus, trainer.num_processes)
+            if trainer.tpu_cores:
+                num_devices = max(num_devices, trainer.tpu_cores)
 
         effective_batch_size = trainer.accumulate_grad_batches * num_devices
         max_estimated_steps = dataset_size // effective_batch_size
 
-        if trainer.max_steps and trainer.max_steps < max_estimated_steps:
-            return trainer.max_steps
+        # To avoid breaking changes, max_steps is set to -1 if it is not defined
+        max_steps = -1 if not trainer.max_steps else trainer.max_steps
+        if max_steps != -1 and max_steps < max_estimated_steps:
+            return max_steps
         return max_estimated_steps
 
     @staticmethod
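
For reference, the `max_steps` handling above can be exercised in isolation. A minimal sketch of the normalization this commit introduces (the helper name is hypothetical, written here only to demonstrate the behavior):

# Older Lightning used `max_steps=None` for "no limit"; newer versions use `-1`.
def normalize_max_steps(max_steps):
    # Treat None (the old default) like -1 (the new default); keep real limits.
    return -1 if not max_steps else max_steps

assert normalize_max_steps(None) == -1   # old default: no step limit
assert normalize_max_steps(-1) == -1     # new default: no step limit
assert normalize_max_steps(500) == 500   # an explicit limit is preserved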

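For context, a minimal usage sketch of the callback this file defines, assuming a SparseML recipe file at `recipe.yaml` (the recipe path and Trainer arguments are illustrative, not part of this commit):

import pytorch_lightning as pl
from pl_bolts.callbacks import SparseMLCallback

# max_steps=-1 is the newer Lightning default meaning "no step limit",
# which the patched _num_training_steps_per_epoch now handles explicitly.
trainer = pl.Trainer(
    max_steps=-1,
    callbacks=[SparseMLCallback(recipe_path="recipe.yaml")],
)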