diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 394d29411d4da5..13745be6c1eb20 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -34,13 +34,18 @@
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
 
 from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
-from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging
+from .utils import (
+    is_sagemaker_mp_enabled,
+    is_torch_available,
+    is_torch_xla_available,
+    is_training_run_on_sagemaker,
+    logging,
+)
 
 
 if is_training_run_on_sagemaker():
@@ -49,6 +54,15 @@ if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
 
+if is_torch_available():
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+
+    if is_torch_greater_or_equal_than_2_0:
+        from torch.optim.lr_scheduler import LRScheduler
+    else:
+        from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+
 # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
 try:
     from torch.optim.lr_scheduler import SAVE_STATE_WARNING
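
For context, here is a minimal, self-contained sketch of the version-gated import pattern the diff introduces. PyTorch renamed the public scheduler base class from the private `_LRScheduler` to `LRScheduler` in 2.0, so aliasing the old name lets the rest of the module refer to a single symbol. The standalone version check below (using `packaging`, which torch already depends on) is an illustrative assumption, not the exact implementation of `is_torch_greater_or_equal_than_2_0` in transformers' pytorch_utils.

# Illustrative version-gated import; not the exact transformers implementation.
import torch
from packaging import version

# Compare against the base version so local/dev suffixes such as "+cu118" or
# ".dev" do not interfere with the comparison.
_torch_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_0 = _torch_base >= version.parse("2.0")

if is_torch_greater_or_equal_than_2_0:
    # torch >= 2.0 exposes the public base class directly.
    from torch.optim.lr_scheduler import LRScheduler
else:
    # Older torch only has the private name; alias it so callers see one symbol.
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

With an alias like this in place, code such as isinstance(scheduler, LRScheduler) works unchanged on both torch 1.x and 2.x.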