 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import os
 from typing import TYPE_CHECKING, Optional

 import torch

+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None

 logger = init_logger(__name__)
@@ -35,8 +38,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              use_mla: bool) -> str:
         if selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
-        logger.info("Using IPEX attention backend.")
-        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
+        use_v1 = envs.VLLM_USE_V1
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        else:
+            logger.info("Using IPEX attention backend.")
+            return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

     @classmethod
     def get_device_capability(
@@ -67,25 +75,27 @@ def inference_mode(cls):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
+        # In V1 (or with IPEX chunked prefill), block_size is 64.
         if cache_config and cache_config.block_size is None:
-            cache_config.block_size = 16
-
-        # check and update model config
-        model_config = vllm_config.model_config
-        if model_config.dtype == torch.bfloat16:
-            bf16_supported = cls.device_support_bf16()
-            if not bf16_supported:
+            if envs.VLLM_USE_V1:
+                cache_config.block_size = 64
+            else:
+                cache_config.block_size = 16
+
+        # Instances created via VllmConfig() may have model_config set to
+        # None, so guard against that before checking and updating the model
+        # config.
+        if vllm_config.model_config is not None:
+            model_config = vllm_config.model_config
+            if model_config.dtype == torch.bfloat16:
+                bf16_supported = cls.device_support_bf16()
+                if not bf16_supported:
+                    model_config.dtype = torch.float16
+            if not model_config.enforce_eager:
                 logger.warning(
-                    "bfloat16 is only supported on Intel Data Center GPU, "
-                    "Intel Arc GPU is not supported yet. Your device is %s,"
-                    " which is not supported. will fallback to float16",
-                    cls.get_device_name())
-                model_config.dtype = torch.float16
-        if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on XPU, fallback to the eager "
-                "mode.")
-            model_config.enforce_eager = True
+                    "CUDA graph is not supported on XPU, fallback to the eager "
+                    "mode.")
+                model_config.enforce_eager = True

         if vllm_config.speculative_config is not None:
             raise NotImplementedError(
@@ -96,21 +106,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         # check and update parallel config
         parallel_config = vllm_config.parallel_config
-        if parallel_config.worker_cls == "auto":
+        if envs.VLLM_USE_V1:
+            parallel_config.worker_cls = \
+                "vllm.v1.worker.xpu_worker.XPUWorker"
+        else:
             parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

         if parallel_config.distributed_executor_backend is None:
-            parallel_config.distributed_executor_backend = "ray"
+            if parallel_config.world_size > 1:
+                parallel_config.distributed_executor_backend = "ray"
+            else:
+                parallel_config.distributed_executor_backend = "uni"
         elif parallel_config.distributed_executor_backend == "mp":
             # FIXME(kunshang):
             # spawn needs calling `if __name__ == '__main__':`
             # fork is not supported for starting a new process on XPU.
-            logger.error(
-                "Both start methods (spawn and fork) have issue "
-                "on XPU if you use mp backend, setting it to ray instead.")
-            parallel_config.distributed_executor_backend = "ray"
-
-        elif parallel_config.distributed_executor_backend != "ray":
+            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
+                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+                logger.warning(
+                    "Please use spawn as the start method to use mp.")
+        elif parallel_config.distributed_executor_backend != "ray" and \
+                parallel_config.distributed_executor_backend != "uni":
             logger.warning(
                 "%s is not supported on XPU, fallback to ray distributed"
                 " executor backend.",
@@ -142,15 +158,35 @@ def get_current_memory_usage(cls,
     @classmethod
     def device_support_bf16(cls) -> bool:
         device_name = cls.get_device_name().lower()
-        if device_name.count("arc") > 0:
+        if cls.is_client_gpu_a770():
+            logger.warning("Intel Arc A770 has a known bfloat16 accuracy "
+                           "issue; falling back to float16.")
             return False
-        elif device_name.count("data center gpu") > 0:
-            return True
         else:
-            logger.warning("Unknown device name %s, always use float16",
-                           device_name)
-            return False
+            logger.info(
+                "Device name %s supports bfloat16. Please file an issue "
+                "if you encounter any accuracy problems with bfloat16.",
+                device_name)
+            return True
+
+    @classmethod
+    def is_data_center_gpu(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("data center gpu") > 0
+
+    @classmethod
+    def is_client_gpu_a770(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("a770") > 0

     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        return True
+
+    @classmethod
+    def device_count(cls) -> int:
+        return torch.xpu.device_count()
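
A minimal usage sketch, assuming the class defined in this file is XPUPlatform and that a bare VllmConfig() can be constructed with model_config left as None (as the comment in check_and_update_config notes); the call sequence below is an illustration, not part of this diff:

    # Hypothetical sketch: exercise the XPU platform hook directly.
    # Assumed names: XPUPlatform in vllm.platforms.xpu, VllmConfig in vllm.config.
    from vllm.config import VllmConfig
    from vllm.platforms.xpu import XPUPlatform

    config = VllmConfig()  # model_config may be None here
    XPUPlatform.check_and_update_config(config)
    # With VLLM_USE_V1 enabled, block_size defaults to 64 and the V1 worker
    # ("vllm.v1.worker.xpu_worker.XPUWorker") is selected.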