diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index f4cf668e427ed1..f19edba609ab7e 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -23,7 +23,7 @@
 
 import numpy as np
 import torch
-from torch.optim.lr_scheduler import SAVE_STATE_WARNING
+from packaging import version
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
@@ -34,6 +34,11 @@
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
 
+if version.parse(torch.__version__) <= version.parse("1.4.1"):
+    SAVE_STATE_WARNING = ""
+else:
+    from torch.optim.lr_scheduler import SAVE_STATE_WARNING
+
 
 logger = logging.get_logger(__name__)
 
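For context: as the version guard implies, `SAVE_STATE_WARNING` only exists in `torch.optim.lr_scheduler` on PyTorch releases newer than 1.4.1, so the previous unconditional import broke at module load time on older torch. The empty-string fallback is a sentinel that never matches a real warning message. Below is a minimal, self-contained sketch of how such a guarded constant can be consumed when saving scheduler state; the `save_scheduler` helper is hypothetical and only illustrates the pattern, it is not part of this diff:

```python
import warnings

import torch
from packaging import version

# Same guard as in the diff: the constant only exists on newer torch,
# so older versions get a harmless empty-string stand-in.
if version.parse(torch.__version__) <= version.parse("1.4.1"):
    SAVE_STATE_WARNING = ""
else:
    from torch.optim.lr_scheduler import SAVE_STATE_WARNING


def save_scheduler(scheduler, path):
    """Hypothetical helper: save an LR scheduler's state while suppressing
    only the known save-state warning and re-emitting everything else."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning raised below
        torch.save(scheduler.state_dict(), path)
    for w in caught:
        if str(w.message) != SAVE_STATE_WARNING:
            # Not the save-state notice: pass it through to the caller.
            warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
```

Note that the filter compares against the message text, so on torch <= 1.4.1 the empty-string stub simply never matches anything and every caught warning is re-emitted, which is the intended no-op behavior on versions that cannot raise this warning in the first place.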