From aa0969c2f3cf5cf7ce5690e7b77f95b1622988d6 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Mon, 2 Nov 2020 10:12:13 -0500 Subject: [PATCH] Fix bad import with PyTorch <= 1.4.1 --- src/transformers/trainer_pt_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index f4cf668e427ed1..f19edba609ab7e 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -23,7 +23,7 @@ import numpy as np import torch -from torch.optim.lr_scheduler import SAVE_STATE_WARNING +from packaging import version from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler @@ -34,6 +34,11 @@ if is_torch_tpu_available(): import torch_xla.core.xla_model as xm +if version.parse(torch.__version__) <= version.parse("1.4.1"): + SAVE_STATE_WARNING = "" +else: + from torch.optim.lr_scheduler import SAVE_STATE_WARNING + logger = logging.get_logger(__name__)