Skip to content

Commit

Permalink
Safe import of LRScheduler (#29919)
Browse files Browse the repository at this point in the history
* Safe import of LRScheduler

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix up

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent bfb3a20 commit 5b6d5e7
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler

from .integrations.deepspeed import is_deepspeed_zero3_enabled
from .tokenization_utils_base import BatchEncoding
from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging
from .utils import (
is_sagemaker_mp_enabled,
is_torch_available,
is_torch_xla_available,
is_training_run_on_sagemaker,
logging,
)


if is_training_run_on_sagemaker():
Expand All @@ -49,6 +54,15 @@
if is_torch_xla_available():
import torch_xla.core.xla_model as xm

if is_torch_available():
from .pytorch_utils import is_torch_greater_or_equal_than_2_0

if is_torch_greater_or_equal_than_2_0:
from torch.optim.lr_scheduler import LRScheduler
else:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
try:
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
Expand Down

0 comments on commit 5b6d5e7

Please sign in to comment.