diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index eef3f9f75f9f..4206a67fea28 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from importlib.util import find_spec
+from typing import Optional, Protocol, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -15,6 +16,46 @@
 logger = init_logger(__name__)
 
 
+class CustomAllreduceProtocol(Protocol):
+    """Protocol for custom allreduce implementations.
+    Used only to satisfy mypy."""
+
+    disabled: bool = True
+
+    def __init__(self, group: ProcessGroup,
+                 device: Union[int, str, torch.device]) -> None:
+        ...
+
+    def should_custom_ar(self, inp: torch.Tensor) -> bool:
+        ...
+
+    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+        ...
+
+
+def is_rocm_aiter_custom_allreduce_enabled() -> bool:
+    """Check if aiter custom allreduce is enabled for the ROCm platform."""
+    from vllm.platforms.rocm import on_gfx9
+    return current_platform.is_rocm() \
+        and on_gfx9() \
+        and envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE \
+        and find_spec("aiter.dist.custom_all_reduce") is not None
+
+
+def dispatch_custom_allreduce() -> type[CustomAllreduceProtocol]:
+    """Dispatch the custom allreduce implementation based on the platform."""
+    if is_rocm_aiter_custom_allreduce_enabled():
+        from aiter.dist.custom_all_reduce import CustomAllreduce
+        logger.info_once(
+            "Using aiter.dist.custom_all_reduce for ROCm platform")
+    else:
+        from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa: E501
+            CustomAllreduce)
+
+    return CustomAllreduce
+
+
 class CudaCommunicator(DeviceCommunicatorBase):
 
     def __init__(self,
@@ -38,8 +79,7 @@ def __init__(self,
         self.use_custom_allreduce = use_custom_allreduce
 
         # lazy import to avoid documentation build error
-        from vllm.distributed.device_communicators.custom_all_reduce import (
-            CustomAllreduce)
+        CustomAllreduce = dispatch_custom_allreduce()
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
         from vllm.distributed.device_communicators.quick_all_reduce import (
@@ -54,7 +94,7 @@ def __init__(self,
             device=self.device,
         )
 
-        self.ca_comm: Optional[CustomAllreduce] = None
+        self.ca_comm: Optional[CustomAllreduceProtocol] = None
         self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
         if use_custom_allreduce and self.world_size > 1:
diff --git a/vllm/envs.py b/vllm/envs.py
index 1232bd7bf963..fde188336747 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -99,6 +99,7 @@
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = True
+    VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE: bool = True
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
@@ -775,6 +776,12 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
              ("true", "1")),
 
+    # Whether to use aiter custom allreduce for the ROCm platform. Enabled by
+    # default, but only takes effect when VLLM_ROCM_USE_AITER is also set.
+    "VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE":
+    lambda:
+    (os.getenv("VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE", "True").lower() in
+     ("true", "1")),
     # Whether to use aiter triton fp8 bmm kernel
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_FP8BMM":