Skip to content

Commit b4ac463

Browse files
committed
only A770 will fallback to fp16
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent 4e2774b commit b4ac463

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

vllm/platforms/xpu.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
9090
if model_config.dtype == torch.bfloat16:
9191
bf16_supported = cls.device_support_bf16()
9292
if not bf16_supported:
93-
logger.warning(
94-
"bfloat16 is only supported on Intel Data Center GPU, "
95-
"Intel Arc GPU is not supported yet. Your device is %s,"
96-
" which is not supported. will fallback to float16",
97-
cls.get_device_name())
9893
model_config.dtype = torch.float16
9994
if not model_config.enforce_eager:
10095
logger.warning(
@@ -162,24 +157,26 @@ def get_current_memory_usage(cls,
162157
@classmethod
163158
def device_support_bf16(cls) -> bool:
164159
device_name = cls.get_device_name().lower()
165-
if cls.is_client_gpu():
160+
if cls.is_client_gpu_a770():
161+
logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
162+
" fallback to float16")
166163
return False
167-
elif cls.is_data_center_gpu():
168-
return True
169164
else:
170-
logger.warning("Unknown device name %s, always use float16",
171-
device_name)
172-
return False
165+
logger.info(
166+
"Device name %s supports bfloat16. Please file an issue "
167+
"if you encounter any accuracy problems with bfloat16.",
168+
device_name)
169+
return True
173170

174171
@classmethod
175172
def is_data_center_gpu(cls) -> bool:
176173
device_name = cls.get_device_name().lower()
177174
return device_name.count("data center gpu") > 0
178175

179176
@classmethod
180-
def is_client_gpu(cls) -> bool:
177+
def is_client_gpu_a770(cls) -> bool:
181178
device_name = cls.get_device_name().lower()
182-
return device_name.count("arc") > 0
179+
return device_name.count("a770") > 0
183180

184181
@classmethod
185182
def get_device_communicator_cls(cls) -> str:

0 commit comments

Comments
 (0)