diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index cb3d4a5bfe5b7b..2ca37180a70a4d 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -34,7 +34,8 @@ if is_torch_tpu_available(): import torch_xla.core.xla_model as xm -if version.parse(torch.__version__) <= version.parse("1.4.1"): +# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0 +if version.parse(torch.__version__) <= version.parse("1.4.1") or version.parse(torch.__version__) > version.parse("1.7.0"): SAVE_STATE_WARNING = "" else: from torch.optim.lr_scheduler import SAVE_STATE_WARNING