diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index dee5ed7a2883..067315deb773 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -7,8 +7,13 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+import vllm.envs as envs
+from vllm.logger import init_logger
+
 from .base_device_communicator import DeviceCommunicatorBase
 
+logger = init_logger(__name__)
+
 
 class XpuCommunicator(DeviceCommunicatorBase):
 
@@ -18,6 +23,12 @@ def __init__(self,
                  device_group: Optional[ProcessGroup] = None,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
+        if self.use_all2all:
+            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+            if all2all_backend == "naive":
+                from .all2all import NaiveAll2AllManager
+                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+                logger.info("Using naive all2all manager.")
 
     def all_reduce(self, input_) -> torch.Tensor:
         dist.all_reduce(input_, group=self.device_group)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b9de03ddd216..f413715f4ed8 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -655,6 +655,8 @@ def forward_tpu(
         forward_native = forward_tpu
     elif current_platform.is_cpu():
         forward_native = forward_cpu
+    elif current_platform.is_xpu():
+        forward_native = forward_xpu
     else:
         forward_native = forward_cuda
 
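
Note (not part of the diff): a minimal sketch of the backend selection that the first hunk adds to XpuCommunicator.__init__. The import paths, the VLLM_ALL2ALL_BACKEND check, and the NaiveAll2AllManager(cpu_group) call are taken from the diff itself; the standalone helper below is hypothetical and only mirrors that logic for illustration.

# Illustrative sketch only; mirrors the selection logic added in
# XpuCommunicator.__init__ above. The helper name is hypothetical.
from typing import Optional

import vllm.envs as envs
from vllm.distributed.device_communicators.all2all import NaiveAll2AllManager


def select_all2all_manager(cpu_group) -> Optional[NaiveAll2AllManager]:
    """Return the naive all2all manager when the env var selects it."""
    if envs.VLLM_ALL2ALL_BACKEND == "naive":
        # Same construction as in the diff: the manager is built over the
        # CPU process group of the communicator.
        return NaiveAll2AllManager(cpu_group)
    # Any other VLLM_ALL2ALL_BACKEND value is not wired up for XPU in this
    # change, so no manager is created.
    return None

As the sketch shows, on XPU only VLLM_ALL2ALL_BACKEND=naive results in an all2all manager being set; other backend values currently leave all2all_manager untouched in this communicator.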