From aa9e94d150e8a524f3ec0af049d044ef2954978d Mon Sep 17 00:00:00 2001
From: "Ma, Liangliang"
Date: Tue, 8 Jul 2025 01:31:05 -0700
Subject: [PATCH 1/5] enhance xpu test support

Co-authored-by: chaojun-zhang
Co-authored-by: zufangzhu
Co-authored-by: zhenwei-intel
Signed-off-by: Ma, Liangliang
---
 tests/conftest.py                                |  5 +++--
 tests/utils.py                                   |  7 ++++++-
 .../device_communicators/xpu_communicator.py     |  5 +++++
 vllm/distributed/parallel_state.py               | 10 ++++++----
 vllm/platforms/xpu.py                            | 14 ++++++++++----
 vllm/v1/worker/xpu_model_runner.py               |  2 +-
 6 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index b294b50a5cdd..d976dee1b464 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -759,7 +759,8 @@ class VllmRunner:
     - `trust_remote_code`: Set to `True` instead of `False` for convenience.
     - `seed`: Set to `0` instead of `None` for test reproducibility.
     - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
-    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
+    - `block_size`: To reduce memory usage, set default to `64` if on XPU devices,
+      otherwise default to `16`.
     - `enable_chunked_prefill`: Set to `False` instead of `None` for test
       reproducibility.
     - `enforce_eager`: Set to `False` to test CUDA graph.
@@ -777,7 +778,7 @@ def __init__(
         self,
         dtype: str = "auto",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
-        block_size: int = 16,
+        block_size: int = 16 if not hasattr(torch, 'xpu') else 64,
         enable_chunked_prefill: Optional[bool] = False,
         swap_space: int = 4,
         enforce_eager: Optional[bool] = False,
diff --git a/tests/utils.py b/tests/utils.py
index a37872830dad..82455b840f29 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -730,6 +730,11 @@ def fork_new_process_for_each_test(
 
     @functools.wraps(f)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+        # To use XPU with multiprocessing, must use the 'spawn' start method via 'VLLM_WORKER_MULTIPROC_METHOD=spawn'
+        if current_platform.is_xpu():
+            f(*args, **kwargs)
+            return
+
         # Make the process the leader of its own process group
         # to avoid sending SIGTERM to the parent process
         os.setpgrp()
@@ -817,7 +822,7 @@ def create_new_process_for_each_test(
     """Creates a decorator that runs each test function in a new process.
 
     Args:
-        method: The process creation method. Can be either "spawn" or "fork". 
+        method: The process creation method. Can be either "spawn" or "fork".
            If not specified, it defaults to "spawn" on ROCm platforms and
            "fork" otherwise.
diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index 216ff85c8bb7..058f3ce3d4e1 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -53,3 +53,8 @@ def gather(self,
         else:
             output_tensor = None
         return output_tensor
+
+    def broadcast(self,
+                  input_: torch.Tensor,
+                  src: int = 0) -> None:
+        dist.broadcast(input_, src=src, group=self.device_group)
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index c53601a22f21..495a758e6069 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -240,6 +240,8 @@ def __init__(
 
         if current_platform.is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
+        elif current_platform.is_xpu():
+            self.device = torch.device(f"xpu:{local_rank}")
         elif current_platform.is_out_of_tree():
             self.device = torch.device(
                 f"{current_platform.device_name}:{local_rank}")
@@ -1317,13 +1319,13 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
 
 def is_global_first_rank() -> bool:
     """
-    Check if the current process is the first rank globally across all 
+    Check if the current process is the first rank globally across all
     parallelism strategies (PP, TP, DP, EP, etc.).
-    
+
     Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
     or `get_pp_group().is_first_rank`, this function checks the global rank
     across all parallelism dimensions.
-    
+
     Returns:
         bool: True if this is the global first rank (rank 0), False otherwise.
         Returns True if distributed is not initialized (single process).
@@ -1352,7 +1354,7 @@ def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
 
     Args:
         pg: The process group to analyze
-    
+
     Returns:
         int: The total number of nodes
     """
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index e2871c106492..6e8813a30608 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -78,6 +78,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64
 
+        # FIXME: Temporarily forcing eager mode;
+        # remove after torch.compile support stabilizes.
+        if envs.VLLM_USE_V1 and vllm_config.model_config is not None and \
+            not vllm_config.model_config.enforce_eager:
+            from vllm.config import CompilationLevel
+            vllm_config.compilation_config.level = \
+                CompilationLevel.NO_COMPILATION
+
         # Instances created using VllmConfig() typically have model_config as
         # None by default. The modification involves adding a check to prevent
         # potential null exceptions check and update model config.
@@ -93,9 +101,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "mode.")
             model_config.enforce_eager = True
 
-        if vllm_config.device_config is not None:
-            assert vllm_config.device_config.device_type == "xpu"
-
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
@@ -114,7 +119,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             logger.warning(
                 "Please use spawn as start method if you want to use mp.")
         elif parallel_config.distributed_executor_backend != "ray" and \
-                parallel_config.distributed_executor_backend != "uni":
+                parallel_config.distributed_executor_backend != "uni" and \
+                parallel_config.distributed_executor_backend != "external_launcher":
             logger.warning(
                 "%s is not supported on XPU, fallback to ray distributed"
                 " executor backend.",
diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py
index 4cedc913c2ab..59f8d0fcf5bd 100644
--- a/vllm/v1/worker/xpu_model_runner.py
+++ b/vllm/v1/worker/xpu_model_runner.py
@@ -27,7 +27,7 @@ def __init__(
         self.cascade_attn_enabled = False
 
     def _init_device_properties(self) -> None:
-        pass
+        self.num_sms = None
 
     def _sync_device(self) -> None:
         torch.xpu.synchronize()

From cb0b3dcdd6310d3514c664f104b76a9d7f2e7401 Mon Sep 17 00:00:00 2001
From: "Ma, Liangliang"
Date: Wed, 9 Jul 2025 00:18:58 -0700
Subject: [PATCH 2/5] remove fork change

Signed-off-by: Ma, Liangliang
---
 tests/utils.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/tests/utils.py b/tests/utils.py
index 82455b840f29..1653338ffdd2 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -730,11 +730,6 @@ def fork_new_process_for_each_test(
 
     @functools.wraps(f)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
-        # To use XPU with multiprocessing, must use the 'spawn' start method via 'VLLM_WORKER_MULTIPROC_METHOD=spawn'
-        if current_platform.is_xpu():
-            f(*args, **kwargs)
-            return
-
         # Make the process the leader of its own process group
         # to avoid sending SIGTERM to the parent process
         os.setpgrp()

From 6dba0f34b9e6413fc7a9c3d6144761dba3dcb978 Mon Sep 17 00:00:00 2001
From: "Ma, Liangliang"
Date: Wed, 9 Jul 2025 00:29:47 -0700
Subject: [PATCH 3/5] use torch.xpu.is_available instead of hasattr

Signed-off-by: Ma, Liangliang
---
 tests/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index d976dee1b464..54cb10899cc3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -778,7 +778,7 @@ def __init__(
         dtype: str = "auto",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
-        block_size: int = 16 if not hasattr(torch, 'xpu') else 64,
+        block_size: int = 16 if not torch.xpu.is_available() else 64,
         enable_chunked_prefill: Optional[bool] = False,
         swap_space: int = 4,
         enforce_eager: Optional[bool] = False,

From 200dba945f31a86d1ceb30c52da013f15d28a3d4 Mon Sep 17 00:00:00 2001
From: "Ma, Liangliang"
Date: Wed, 9 Jul 2025 01:09:17 -0700
Subject: [PATCH 4/5] fix format

Signed-off-by: Ma, Liangliang
---
 vllm/distributed/device_communicators/xpu_communicator.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index 058f3ce3d4e1..dee5ed7a2883 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -54,7 +54,5 @@ def gather(self,
             output_tensor = None
         return output_tensor
 
-    def broadcast(self,
-                  input_: torch.Tensor,
-                  src: int = 0) -> None:
+    def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
         dist.broadcast(input_, src=src, group=self.device_group)

From a84337ee2353d39bc1d11d5efa986da522e58c79 Mon Sep 17 00:00:00 2001
From: "Ma, Liangliang"
Date: Wed, 9 Jul 2025 04:07:57 -0700
Subject: [PATCH 5/5] fix format that exceeds max line length

Signed-off-by: Ma, Liangliang
---
 tests/conftest.py     | 4 ++--
 vllm/platforms/xpu.py | 7 ++++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 54cb10899cc3..c5d7156905bc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -759,8 +759,8 @@ class VllmRunner:
     - `trust_remote_code`: Set to `True` instead of `False` for convenience.
     - `seed`: Set to `0` instead of `None` for test reproducibility.
     - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
-    - `block_size`: To reduce memory usage, set default to `64` if on XPU devices,
-      otherwise default to `16`.
+    - `block_size`: To reduce memory usage, set default to `64` if on XPU
+      devices, otherwise default to `16`.
     - `enable_chunked_prefill`: Set to `False` instead of `None` for test
       reproducibility.
     - `enforce_eager`: Set to `False` to test CUDA graph.
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index c89838ff250c..3196f3059e19 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -129,9 +129,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
             logger.warning(
                 "Please use spawn as start method if you want to use mp.")
-        elif parallel_config.distributed_executor_backend != "ray" and \
-                parallel_config.distributed_executor_backend != "uni" and \
-                parallel_config.distributed_executor_backend != "external_launcher":
+        elif (parallel_config.distributed_executor_backend != "ray"
+              and parallel_config.distributed_executor_backend != "uni"
+              and parallel_config.distributed_executor_backend
+              != "external_launcher"):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
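
Note on the block-size default this series converges on (not part of the patches above): the tests pick a block size of 64 when an XPU device is available and 16 otherwise, evaluated once when tests/conftest.py is imported. Below is a minimal standalone sketch of that selection logic. It assumes only that PyTorch is installed; default_block_size is a hypothetical helper rather than a vLLM function, and the hasattr guard (which PATCH 3/5 replaces with a plain torch.xpu.is_available() call) is kept here only so the sketch also runs on older builds that lack the torch.xpu module.

# Hypothetical helper sketching the XPU-aware default used by the test runner.
import torch

def default_block_size() -> int:
    # vllm/platforms/xpu.py defaults cache_config.block_size to 64 on XPU,
    # so the VllmRunner default mirrors it; all other platforms keep 16.
    xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
    return 64 if xpu_available else 16

print(default_block_size())  # 16 on CPU/CUDA hosts, 64 on XPU hosts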