Skip to content

Commit 954f31c

Browse files
cyyever and SunMarc
authored
Add XPU case to is_torch_bf16_gpu_available (#37132)
* Add xpu case to is_torch_bf16_gpu_available Signed-off-by: cyy <cyyever@outlook.com> * Refine error messages Signed-off-by: cyy <cyyever@outlook.com> --------- Signed-off-by: cyy <cyyever@outlook.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
1 parent 28eae8b commit 954f31c

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/transformers/training_args.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
is_sagemaker_mp_enabled,
4646
is_torch_available,
4747
is_torch_bf16_gpu_available,
48+
is_torch_cuda_available,
4849
is_torch_hpu_available,
4950
is_torch_mlu_available,
5051
is_torch_mps_available,
@@ -1683,11 +1684,12 @@ def __post_init__(self):
16831684
# cpu
16841685
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
16851686
elif not self.use_cpu:
1686-
if torch.cuda.is_available() and not is_torch_bf16_gpu_available():
1687+
if not is_torch_bf16_gpu_available():
1688+
error_message = "Your setup doesn't support bf16/gpu."
1689+
if is_torch_cuda_available():
1690+
error_message += " You need Ampere+ GPU with cuda>=11.0"
16871691
# gpu
1688-
raise ValueError(
1689-
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
1690-
)
1692+
raise ValueError(error_message)
16911693

16921694
if self.fp16 and self.bf16:
16931695
raise ValueError("At most one of fp16 and bf16 can be True, but not both")

src/transformers/utils/import_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,13 +512,17 @@ def is_torch_mps_available(min_version: Optional[str] = None):
512512
return False
513513

514514

515-
def is_torch_bf16_gpu_available():
515+
def is_torch_bf16_gpu_available() -> bool:
516516
if not is_torch_available():
517517
return False
518518

519519
import torch
520520

521-
return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
521+
if torch.cuda.is_available():
522+
return torch.cuda.is_bf16_supported()
523+
if torch.xpu.is_available():
524+
return torch.xpu.is_bf16_supported()
525+
return False
522526

523527

524528
def is_torch_bf16_cpu_available() -> bool:

0 commit comments

Comments (0)