Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,8 @@ class VllmRunner:
- `trust_remote_code`: Set to `True` instead of `False` for convenience.
- `seed`: Set to `0` instead of `None` for test reproducibility.
- `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
- `block_size`: Set to `16` instead of `None` to reduce memory usage.
- `block_size`: To reduce memory usage, set default to `64` if on XPU
devices, otherwise default to `16`.
- `enable_chunked_prefill`: Set to `False` instead of `None` for
test reproducibility.
- `enforce_eager`: Set to `False` to test CUDA graph.
Expand All @@ -777,7 +778,7 @@ def __init__(
dtype: str = "auto",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
block_size: int = 16 if not torch.xpu.is_available() else 64,
enable_chunked_prefill: Optional[bool] = False,
swap_space: int = 4,
enforce_eager: Optional[bool] = False,
Expand Down
3 changes: 3 additions & 0 deletions vllm/distributed/device_communicators/xpu_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@ def gather(self,
else:
output_tensor = None
return output_tensor

def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
    """Broadcast ``input_`` from rank ``src`` over this communicator's group.

    Thin wrapper that forwards to ``dist.broadcast`` using
    ``self.device_group`` as the process group.

    Args:
        input_: Tensor handed to ``dist.broadcast``.
            NOTE(review): assuming ``dist`` is ``torch.distributed``, the
            tensor is sent from ``src`` and overwritten in place on every
            other rank -- confirm against the module's imports.
        src: Rank passed through as the ``src`` argument of
            ``dist.broadcast``. Defaults to ``0``.

    Returns:
        None. Nothing is returned from the underlying call.
    """
    dist.broadcast(input_, src=src, group=self.device_group)
10 changes: 6 additions & 4 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ def __init__(

if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
elif current_platform.is_xpu():
self.device = torch.device(f"xpu:{local_rank}")
Comment on lines +243 to +244
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider passing the device type and index to torch.device as separate arguments instead of building the device string with an f-string. This is easier to read and avoids string-formatting mistakes.

self.device = torch.device("xpu", local_rank)

elif current_platform.is_out_of_tree():
self.device = torch.device(
f"{current_platform.device_name}:{local_rank}")
Expand Down Expand Up @@ -1317,13 +1319,13 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],

def is_global_first_rank() -> bool:
"""
Check if the current process is the first rank globally across all
Check if the current process is the first rank globally across all
parallelism strategies (PP, TP, DP, EP, etc.).

Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
or `get_pp_group().is_first_rank`, this function checks the global rank
across all parallelism dimensions.

Returns:
bool: True if this is the global first rank (rank 0), False otherwise.
Returns True if distributed is not initialized (single process).
Expand Down Expand Up @@ -1352,7 +1354,7 @@ def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:

Args:
pg: The process group to analyze

Returns:
int: The total number of nodes
"""
Expand Down
10 changes: 5 additions & 5 deletions vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

# FIXME: Temporarily forcing eager mode
# remove after t.compile support stabilizes.

if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
and not vllm_config.model_config.enforce_eager):
from vllm.config import CompilationLevel
Expand All @@ -111,9 +112,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"mode.")
model_config.enforce_eager = True

if vllm_config.device_config is not None:
assert vllm_config.device_config.device_type == "xpu"

# check and update parallel config
parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
Expand All @@ -131,8 +129,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
logger.warning(
"Please use spawn as start method if you want to use mp.")
elif parallel_config.distributed_executor_backend != "ray" and \
parallel_config.distributed_executor_backend != "uni":
elif (parallel_config.distributed_executor_backend != "ray"
and parallel_config.distributed_executor_backend != "uni"
and parallel_config.distributed_executor_backend
!= "external_launcher"):
logger.warning(
"%s is not supported on XPU, fallback to ray distributed"
" executor backend.",
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
self.cascade_attn_enabled = False

def _init_device_properties(self) -> None:
pass
self.num_sms = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On top of this change, may I suggest an improvement: move these customizations into gpu_model_runner.py, since they are just backend-specific dispatch logic that can be handled more easily without class inheritance? See the proposal here:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, these functions were originally in gpu_model_runner.py and were moved into the device-specific model runners for better readability. So I think we could follow this design.


def _sync_device(self) -> None:
    """Synchronize with the XPU device by calling ``torch.xpu.synchronize()``.

    No device or stream argument is passed, so the call applies to whatever
    device ``torch.xpu.synchronize`` selects by default.
    """
    torch.xpu.synchronize()