5 changes: 5 additions & 0 deletions docs/getting_started/installation/gpu/xpu.inc.md
@@ -81,4 +81,9 @@ python -m vllm.entrypoints.openai.api_server \
By default, a Ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equal to `parallel_config.world_size`. We recommend properly starting a Ray cluster before execution, referring to the <gh-file:examples/online_serving/run_cluster.sh> helper script.

# --8<-- [end:supported-features]
# --8<-- [start:distributed-backend]

The XPU platform uses **torch-ccl** as the distributed backend for torch<2.8 and **xccl** for torch>=2.8, since torch 2.8 adds **xccl** as a built-in backend for XPU.

# --8<-- [end:distributed-backend]
# --8<-- [end:extra-information]
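In other words, the backend choice follows the installed torch version. Below is a minimal sketch of how one might check which backend an XPU environment will end up with; it is illustrative only (not part of this change) and assumes `torch` and `packaging` are importable.

```python
# Sketch only: report which distributed backend this XPU environment would pick.
# oneccl_bindings_for_pytorch (torch-ccl) is only needed for the "ccl"
# fallback on torch < 2.8.
import torch
from packaging.version import Version

if Version(torch.__version__.split("+")[0]) >= Version("2.8.0") \
        and torch.distributed.is_xccl_available():
    print("distributed backend: xccl (built into torch >= 2.8)")
else:
    print("distributed backend: ccl (via oneccl_bindings_for_pytorch)")
```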
13 changes: 11 additions & 2 deletions vllm/platforms/__init__.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Optional

from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname
from vllm.utils import resolve_obj_by_qualname, supports_xccl

from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -139,10 +139,19 @@ def xpu_platform_plugin() -> Optional[str]:
try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if supports_xccl():
dist_backend = "xccl"
else:
dist_backend = "ccl"
import oneccl_bindings_for_pytorch # noqa: F401

if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
from vllm.platforms.xpu import XPUPlatform
XPUPlatform.dist_backend = dist_backend
logger.debug("Confirmed %s backend is available.",
XPUPlatform.dist_backend)
logger.debug("Confirmed XPU platform is available.")
except Exception as e:
logger.debug("XPU platform is not available because: %s", str(e))
1 change: 1 addition & 0 deletions vllm/platforms/cpu.py
@@ -37,6 +37,7 @@ class CpuPlatform(Platform):
device_name: str = "cpu"
device_type: str = "cpu"
dispatch_key: str = "CPU"
dist_backend: str = "gloo"

@property
def supported_dtypes(self) -> list[torch.dtype]:
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -56,6 +56,7 @@ class CudaPlatformBase(Platform):
device_type: str = "cuda"
dispatch_key: str = "CUDA"
ray_device_key: str = "GPU"
dist_backend: str = "nccl"
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

@property
1 change: 1 addition & 0 deletions vllm/platforms/hpu.py
@@ -26,6 +26,7 @@ class HpuPlatform(Platform):
device_type: str = "hpu"
dispatch_key: str = "HPU"
ray_device_key: str = "HPU"
dist_backend: str = "hccl"
device_control_env_var: str = "HABANA_VISIBLE_MODULES"

@classmethod
3 changes: 3 additions & 0 deletions vllm/platforms/interface.py
@@ -129,6 +129,9 @@ class Platform:
# compilation strategy.
simple_compile_backend: str = "inductor"

# The backend used for distributed communication.
dist_backend: str = ""

supported_quantization: list[str] = []

additional_env_vars: list[str] = []
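With this attribute, each platform declares its `torch.distributed` backend exactly once, and workers read it through `current_platform` instead of hard-coding `"gloo"`/`"nccl"`/`"hccl"`. A minimal sketch of the consuming side (illustrative only; the names are taken from this diff):

```python
# Sketch: how worker code picks up the platform's distributed backend.
# current_platform resolves to CudaPlatform, XPUPlatform, CpuPlatform, ...,
# so dist_backend evaluates to "nccl", "xccl"/"ccl", "gloo", and so on.
from vllm.platforms import current_platform

backend = current_platform.dist_backend
print(f"torch.distributed backend for this platform: {backend}")
```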
1 change: 1 addition & 0 deletions vllm/platforms/neuron.py
@@ -30,6 +30,7 @@ class NeuronPlatform(Platform):
device_type: str = "neuron"
ray_device_key: str = "neuron_cores"
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
dist_backend: str = "gloo"
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

@classmethod
1 change: 1 addition & 0 deletions vllm/platforms/rocm.py
@@ -164,6 +164,7 @@ class RocmPlatform(Platform):
device_type: str = "cuda"
dispatch_key: str = "CUDA"
ray_device_key: str = "GPU"
dist_backend: str = "nccl"
# rocm shares the same device control env var as CUDA
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

1 change: 1 addition & 0 deletions vllm/platforms/tpu.py
@@ -31,6 +31,7 @@ class TpuPlatform(Platform):
device_type: str = "tpu"
dispatch_key: str = "XLA"
ray_device_key: str = "TPU"
dist_backend: str = "gloo"
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
simple_compile_backend: str = "openxla"

1 change: 1 addition & 0 deletions vllm/platforms/xpu.py
@@ -29,6 +29,7 @@ class XPUPlatform(Platform):
# Intel XPU's device key is "GPU" for Ray.
# see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
ray_device_key: str = "GPU"
dist_backend: str = "ccl" # ccl | xccl
device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"

@classmethod
6 changes: 6 additions & 0 deletions vllm/utils/__init__.py
@@ -1886,6 +1886,12 @@ def supports_dynamo() -> bool:
return base_torch_version >= Version("2.4.0")


# Supports xccl with PyTorch versions >= 2.8.0 for XPU platform
def supports_xccl() -> bool:
return is_torch_equal_or_newer(
"2.8.0") and torch.distributed.is_xccl_available()


# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
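A side note on `supports_xccl()`: the order of the two checks matters, because on older torch builds `torch.distributed` may not expose `is_xccl_available()` at all, so the version check has to short-circuit first. A hedged sketch of the same selection logic (the helper name below is hypothetical and not part of the PR):

```python
# Hypothetical helper for illustration: pick the XPU distributed backend.
# The version check runs first so that is_xccl_available() is only probed
# on torch builds that are expected to provide it.
import torch

from vllm.utils import is_torch_equal_or_newer


def pick_xpu_dist_backend() -> str:
    if is_torch_equal_or_newer("2.8.0") and torch.distributed.is_xccl_available():
        return "xccl"
    return "ccl"
```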
4 changes: 3 additions & 1 deletion vllm/v1/worker/cpu_worker.py
@@ -11,6 +11,7 @@
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
@@ -58,7 +59,8 @@ def init_device(self):
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank, "gloo")
self.local_rank,
current_platform.dist_backend)
# Set random seed.
set_random_seed(self.model_config.seed)

3 changes: 2 additions & 1 deletion vllm/v1/worker/gpu_worker.py
@@ -157,7 +157,8 @@ def init_device(self):
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
self.local_rank,
current_platform.dist_backend)
# Set random seed.
set_random_seed(self.model_config.seed)

3 changes: 2 additions & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -18,6 +18,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
@@ -300,7 +301,7 @@ def _init_tpu_worker_distributed_environment(
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
backend="gloo",
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
3 changes: 2 additions & 1 deletion vllm/worker/hpu_worker.py
@@ -23,6 +23,7 @@
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
@@ -413,7 +414,7 @@ def init_worker_distributed_environment(
rank,
distributed_init_method,
local_rank,
backend='hccl')
backend=current_platform.dist_backend)

ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
2 changes: 1 addition & 1 deletion vllm/worker/neuron_worker.py
@@ -156,7 +156,7 @@ def init_distributed_environment(self):
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
backend=current_platform.dist_backend,
)

ensure_model_parallel_initialized(
3 changes: 2 additions & 1 deletion vllm/worker/worker.py
@@ -530,7 +530,8 @@ def init_worker_distributed_environment(
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
distributed_init_method, local_rank,
current_platform.dist_backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
