File tree Expand file tree Collapse file tree 2 files changed +12
-6
lines changed Expand file tree Collapse file tree 2 files changed +12
-6
lines changed Original file line number Diff line number Diff line change 4545 is_sagemaker_mp_enabled ,
4646 is_torch_available ,
4747 is_torch_bf16_gpu_available ,
48+ is_torch_cuda_available ,
4849 is_torch_hpu_available ,
4950 is_torch_mlu_available ,
5051 is_torch_mps_available ,
@@ -1683,11 +1684,12 @@ def __post_init__(self):
16831684 # cpu
16841685 raise ValueError ("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10" )
16851686 elif not self .use_cpu :
1686- if torch .cuda .is_available () and not is_torch_bf16_gpu_available ():
1687+ if not is_torch_bf16_gpu_available ():
1688+ error_message = "Your setup doesn't support bf16/gpu."
1689+ if is_torch_cuda_available ():
1690+ error_message += " You need Ampere+ GPU with cuda>=11.0"
16871691 # gpu
1688- raise ValueError (
1689- "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
1690- )
1692+ raise ValueError (error_message )
16911693
16921694 if self .fp16 and self .bf16 :
16931695 raise ValueError ("At most one of fp16 and bf16 can be True, but not both" )
Original file line number Diff line number Diff line change @@ -512,13 +512,17 @@ def is_torch_mps_available(min_version: Optional[str] = None):
512512 return False
513513
514514
515- def is_torch_bf16_gpu_available ():
515+ def is_torch_bf16_gpu_available () -> bool :
516516 if not is_torch_available ():
517517 return False
518518
519519 import torch
520520
521- return torch .cuda .is_available () and torch .cuda .is_bf16_supported ()
521+ if torch .cuda .is_available ():
522+ return torch .cuda .is_bf16_supported ()
523+ if torch .xpu .is_available ():
524+ return torch .xpu .is_bf16_supported ()
525+ return False
522526
523527
524528def is_torch_bf16_cpu_available () -> bool :
You can’t perform that action at this time.
0 commit comments