@@ -90,11 +90,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if model_config.dtype == torch.bfloat16:
             bf16_supported = cls.device_support_bf16()
             if not bf16_supported:
-                logger.warning(
-                    "bfloat16 is only supported on Intel Data Center GPU, "
-                    "Intel Arc GPU is not supported yet. Your device is %s,"
-                    " which is not supported. will fallback to float16",
-                    cls.get_device_name())
                 model_config.dtype = torch.float16
         if not model_config.enforce_eager:
             logger.warning(
@@ -162,24 +157,26 @@ def get_current_memory_usage(cls,
     @classmethod
     def device_support_bf16(cls) -> bool:
         device_name = cls.get_device_name().lower()
-        if cls.is_client_gpu():
+        if cls.is_client_gpu_a770():
+            logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
+                           " fallback to float16")
             return False
-        elif cls.is_data_center_gpu():
-            return True
         else:
-            logger.warning("Unknown device name %s, always use float16",
-                           device_name)
-            return False
+            logger.info(
+                "Device name %s supports bfloat16. Please file an issue "
+                "if you encounter any accuracy problems with bfloat16.",
+                device_name)
+            return True

     @classmethod
     def is_data_center_gpu(cls) -> bool:
         device_name = cls.get_device_name().lower()
         return device_name.count("data center gpu") > 0

     @classmethod
-    def is_client_gpu(cls) -> bool:
+    def is_client_gpu_a770(cls) -> bool:
         device_name = cls.get_device_name().lower()
-        return device_name.count("arc ") > 0
+        return device_name.count("a770 ") > 0

     @classmethod
     def get_device_communicator_cls(cls) -> str:
0 commit comments