diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index 1f90ffe0d2a3..bd06f0ef3dc9 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -960,6 +960,33 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         return
 
+    def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
+        """Run the parent time check; if the trainer is stopping, save and exit.
+
+        After the parent ``_check_time_remaining`` returns, ``trainer.should_stop`` is
+        inspected. When it is set, the trainer's checkpoint callback (if there is one)
+        is asked to save via ``_save_last_checkpoint`` and ``_TunerExitException`` is
+        raised.
+
+        Args:
+            trainer: Trainer whose ``should_stop`` flag and checkpoint callback are used.
+
+        Raises:
+            _TunerExitException: If ``trainer.should_stop`` is set after the parent check.
+        """
+        super()._check_time_remaining(trainer)
+        if trainer.should_stop:
+            # NOTE(review): should_stop is read from the trainer, so this branch also runs
+            # if something other than the parent check set it -- confirm that is intended.
+            checkpoint_callback: Optional[NeMoModelCheckpoint] = trainer.checkpoint_callback
+            if checkpoint_callback is not None:
+                monitor_candidates = checkpoint_callback._monitor_candidates(trainer)
+                checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
+            # Throw this exception to signal to Lightning to terminate gracefully.
+            from pytorch_lightning.utilities.exceptions import _TunerExitException
+
+            raise _TunerExitException()
+
 
 def configure_no_restart_validation_training_loop(trainer: pytorch_lightning.Trainer) -> None:
     if type(trainer.fit_loop.epoch_loop) != _TrainingEpochLoop: