Only import torch.distributed if it is available (#35133)
GaetanLepage authored and ArthurZucker committed Dec 10, 2024
1 parent 4995230 commit bf5d7c3
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/pytorch_utils.py
@@ -38,8 +38,10 @@
 is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
 
+# Cache this result, as it's a C FFI call which can be pretty time-consuming
+_torch_distributed_available = torch.distributed.is_available()
+
-if is_torch_greater_or_equal("2.5"):
+if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import Replicate
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,
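
For readers outside the transformers codebase, the pattern applied here is: cache torch.distributed.is_available() once at module import time (it goes through a C FFI call), then use the cached flag to guard the tensor-parallel imports. Below is a minimal standalone sketch of that pattern, assuming PyTorch and packaging are installed; the _parsed_torch_version helper is an illustrative stand-in for, not the actual, parsed_torch_version_base from pytorch_utils.py.

import torch
from packaging import version

# Cache this once: torch.distributed.is_available() goes through a C FFI call,
# so repeated lookups would be comparatively expensive.
_torch_distributed_available = torch.distributed.is_available()

# Illustrative stand-in for pytorch_utils.parsed_torch_version_base.
_parsed_torch_version = version.parse(version.parse(torch.__version__).base_version)

# torch.distributed.tensor and torch.distributed.tensor.parallel are public
# from PyTorch 2.5 and only importable when the build has distributed support,
# hence both checks.
if _parsed_torch_version >= version.parse("2.5") and _torch_distributed_available:
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor.parallel import ColwiseParallel

On a CPU-only or otherwise distributed-less PyTorch build, the guarded block is simply skipped, so importing the module no longer raises.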
