diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 8add251283..65395ba094 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -85,6 +85,7 @@
     read_ckpt_state_dict,
     load_checkpoint_to_model,
     load_pretrained_weights,
+    get_scheduler_state,
 )
 from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
 from super_gradients.training.utils.callbacks import (
@@ -666,7 +667,7 @@ def _save_checkpoint(
             state["processing_params"] = processing_params
 
         if self._torch_lr_scheduler is not None:
-            state["torch_scheduler_state_dict"] = self._torch_lr_scheduler.state_dict()
+            state["torch_scheduler_state_dict"] = get_scheduler_state(self._torch_lr_scheduler)
 
         # SAVES CURRENT MODEL AS ckpt_latest
         self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 46de06fb58..c524462c9e 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1,7 +1,7 @@
 import collections
 import os
 import tempfile
-from typing import Union, Mapping
+from typing import Union, Mapping, Dict
 
 import pkg_resources
 import torch
@@ -1629,3 +1629,19 @@ def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkp
             "predict make sure to call set_dataset_processing_params."
         )
     return False
+
+
+def get_scheduler_state(scheduler) -> Dict[str, Tensor]:
+    """
+    Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
+    (see https://github.com/pytorch/pytorch/pull/91400)
+    :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
+    :return: the scheduler's state_dict
+    """
+    from super_gradients.training.utils import torch_version_is_greater_or_equal
+    from torch.optim.lr_scheduler import CyclicLR
+
+    state = scheduler.state_dict()
+    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
+        del state["_scale_fn_ref"]
+    return state
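
Not part of the diff: a minimal usage sketch of the new helper, assuming a toy model and optimizer built only for illustration. It shows why checkpoint saving calls `get_scheduler_state` instead of `scheduler.state_dict()` when a `CyclicLR` scheduler is attached.

```python
import torch
from torch.optim.lr_scheduler import CyclicLR

from super_gradients.training.utils.checkpoint_utils import get_scheduler_state

# Hypothetical model/optimizer, only to construct a scheduler for the demo.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = CyclicLR(optimizer, base_lr=0.001, max_lr=0.01)

# On affected torch < 2.0 releases, CyclicLR's state_dict carries
# "_scale_fn_ref", a weak reference to a bound method that torch.save
# cannot pickle. The helper drops that key so the checkpoint serializes.
state = get_scheduler_state(scheduler)
torch.save({"torch_scheduler_state_dict": state}, "ckpt_latest.pth")

# Restoring still works: load_state_dict() updates the scheduler's __dict__,
# and a freshly constructed CyclicLR already holds its own scale-function
# reference, so the stripped key is not needed.
restored = CyclicLR(optimizer, base_lr=0.001, max_lr=0.01)
restored.load_state_dict(torch.load("ckpt_latest.pth")["torch_scheduler_state_dict"])
```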