From fbe8deca0787e0a2caa87795681abb11f2d6d907 Mon Sep 17 00:00:00 2001
From: zejunchen-zejun
Date: Tue, 23 Sep 2025 19:31:33 +0800
Subject: [PATCH] [ROCm][Allreduce] Add dispatch mechanism for choosing
 performant allreduce implementations for AMD platforms

Signed-off-by: zejunchen-zejun
---
 .../device_communicators/cuda_communicator.py | 250 +++++++++++++++---
 1 file changed, 210 insertions(+), 40 deletions(-)

diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index bab372b722db..db84a35266d4 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -21,6 +21,174 @@
 logger = init_logger(__name__)
 
 
+# ROCm allreduce dispatcher that selects the most performant allreduce
+# implementation based on the available implementations, the payload size
+# of the input tensor, and the TP size. It only supports AMD ROCm platforms.
+class ROCmAllreduceDispatcher:
+
+    def __init__(self,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
+                 ca_comm=None,
+                 pynccl_comm=None):
+        self.process_group = group
+        self.device = device
+        self.cur_device_arch = self._get_current_device_arch()
+        self.supported_device_archs = ["gfx94", "gfx95"]
+
+        self.tp_size = torch.distributed.get_world_size(
+            group=self.process_group)
+
+        # dispatch thresholds (unit: KB), keyed by tp_size
+        self.gfx95_thresholds = {
+            2: 512,
+            4: 2048,
+            8: 32768,
+        }
+        self.gfx94_thresholds = {
+            2: 2048,
+            4: 4096,
+            8: 8192,
+        }
+
+        # allreduce name -> (allreduce impl, optional eligibility check impl)
+        self.available_allreduce_impls: dict[str,
+                                             tuple[Callable,
+                                                   Optional[Callable]]] = {}
+
+        self.fallback_impl = None
+        if pynccl_comm is not None:
+            self.available_allreduce_impls["pynccl"] = \
+                (pynccl_comm.all_reduce, None)
+            self.fallback_impl = pynccl_comm.all_reduce
+
+        if ca_comm is not None:
+            self.available_allreduce_impls["vllm_ca"] = \
+                (ca_comm.custom_all_reduce, ca_comm.should_custom_ar)
+
+        # Initialize a custom quick all-reduce implementation for AMD.
+        # Quick reduce is designed as a complement to custom allreduce.
+        # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+        # On ROCm, 'use_custom_allreduce == True' currently implies an
+        # MI300-series GPU.
+        from vllm.distributed.device_communicators.quick_all_reduce import (
+            QuickAllReduce)
+        self.qr_comm = QuickAllReduce(group=self.process_group,
+                                      device=self.device)
+        if self.qr_comm is not None:
+            self.available_allreduce_impls["vllm_qr"] = (
+                self.qr_comm.quick_all_reduce,
+                self.qr_comm.should_quick_allreduce)
+
+        # Initialize a custom all-reduce implementation from aiter.
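+        # Note: the aiter backend is opt-in. It is registered only when
+        # VLLM_ROCM_USE_AITER is set and the aiter package is importable,
+        # as checked in _is_aiter_custom_allreduce_available() below.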
+        if self._is_aiter_custom_allreduce_available():
+            from aiter.dist.custom_all_reduce import CustomAllreduce
+            self.aiter_ca_comm = CustomAllreduce(
+                group=self.process_group,
+                device=self.device,
+            )
+            if self.aiter_ca_comm is not None:
+                self.available_allreduce_impls["aiter_ca"] = \
+                    (self.aiter_ca_comm.custom_all_reduce,
+                     self.aiter_ca_comm.should_custom_ar)
+
+    def _is_aiter_custom_allreduce_available(self) -> bool:
+        """Check if aiter is enabled for the ROCm platform."""
+        if not envs.VLLM_ROCM_USE_AITER:
+            return False
+
+        try:
+            from aiter.dist.custom_all_reduce import (  # noqa: F401
+                CustomAllreduce)
+            return True
+        except ImportError:
+            return False
+
+    def _get_current_device_arch(self) -> str:
+        """Get the device microarchitecture name of the current device."""
+        # TODO(zejun): Add more device architectures
+        device_arch = torch.cuda.get_device_properties("cuda").gcnArchName
+        if "gfx95" in device_arch:
+            return "gfx95"
+        elif "gfx94" in device_arch:
+            return "gfx94"
+        else:
+            return device_arch
+
+    def _should_allreduce(self, input_: torch.Tensor, impl_name: str) -> bool:
+        if impl_name not in self.available_allreduce_impls:
+            return False
+        check_fn = self.available_allreduce_impls[impl_name][1]
+        if check_fn is None:
+            return False
+        return check_fn(input_)
+
+    def _dispatch_gfx94(self, input_: torch.Tensor, payload_size_KB: int,
+                        tp_size: int):
+        """Dispatch implementation for the gfx94 architecture."""
+        if tp_size not in self.gfx94_thresholds:
+            return self.fallback_impl
+
+        threshold = self.gfx94_thresholds[tp_size]
+
+        # for small payload sizes, prefer vLLM custom allreduce;
+        # for large payload sizes, use quick allreduce
+        if payload_size_KB <= threshold:
+            if self._should_allreduce(input_, "vllm_ca"):
+                return self.available_allreduce_impls["vllm_ca"][0]
+            else:
+                return self.fallback_impl
+        elif self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+        else:
+            return self.fallback_impl
+
+    def _dispatch_gfx95(self, input_: torch.Tensor, payload_size_KB: int,
+                        tp_size: int):
+        """Dispatch implementation for the gfx95 architecture."""
+        if tp_size not in self.gfx95_thresholds:
+            return self.fallback_impl
+
+        threshold = self.gfx95_thresholds[tp_size]
+
+        # for small payload sizes, prefer aiter custom allreduce;
+        # for large payload sizes, use quick allreduce
+        if payload_size_KB <= threshold:
+            if self._should_allreduce(input_, "aiter_ca"):
+                return self.available_allreduce_impls["aiter_ca"][0]
+            elif self._should_allreduce(input_, "vllm_ca"):
+                return self.available_allreduce_impls["vllm_ca"][0]
+            else:
+                return self.fallback_impl
+        elif self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+        else:
+            return self.fallback_impl
+
+    def _dispatch_impl(self, input_: torch.Tensor, payload_size_KB: int,
+                       device_arch: str, tp_size: int):
+        if device_arch not in self.supported_device_archs:
+            logger.debug(
+                "Device architecture %s is not supported, using pynccl",
+                device_arch)
+            return self.fallback_impl
+
+        if device_arch == "gfx95":
+            return self._dispatch_gfx95(input_, payload_size_KB, tp_size)
+        elif device_arch == "gfx94":
+            return self._dispatch_gfx94(input_, payload_size_KB, tp_size)
+        else:
+            return self.fallback_impl
+
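+    # Public entry point: computes the payload size in KB and routes it
+    # through the per-architecture dispatch tables above.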
+    def dispatch(self, input_: torch.Tensor) -> Optional[Callable[..., Any]]:
+        """Select and return the allreduce implementation for this input."""
+        payload_size_KB = int(input_.numel() * input_.element_size() / 1024.0)
+        op = self._dispatch_impl(input_, payload_size_KB,
+                                 self.cur_device_arch, self.tp_size)
+        return op
+
+
 class CudaCommunicator(DeviceCommunicatorBase):
 
     def __init__(self,
@@ -51,8 +219,6 @@ def __init__(self,
             CustomAllreduce)
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
-        from vllm.distributed.device_communicators.quick_all_reduce import (
-            QuickAllReduce)
         from vllm.distributed.device_communicators.symm_mem import (
             SymmMemCommunicator)
 
@@ -66,8 +232,12 @@ def __init__(self,
                 register_nccl_symmetric_ops(self.pynccl_comm)
 
         self.ca_comm: Optional[CustomAllreduce] = None
-        self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+
+        # Initialize a custom all-reduce dispatcher for the ROCm platform
+        self.rocm_allreduce_dispatcher: Optional[
+            ROCmAllreduceDispatcher] = None
+
         if use_torch_symm_mem and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
                 group=self.cpu_group,
@@ -84,13 +254,12 @@ def __init__(self,
             )
 
         if current_platform.is_rocm():
-            # Initialize a custom quick all-reduce implementation for AMD.
-            # Quick reduce is designed as a complement to custom allreduce.
-            # Based on quickreduce (https://github.com/mk1-project/quickreduce).
-            # If it's a rocm, 'use_custom_allreduce==True' means it must
-            # currently be an MI300 series.
-            self.qr_comm = QuickAllReduce(group=self.cpu_group,
-                                          device=self.device)
+            self.rocm_allreduce_dispatcher = \
+                ROCmAllreduceDispatcher(group=self.cpu_group,
+                                        device=self.device,
+                                        ca_comm=self.ca_comm,
+                                        pynccl_comm=self.pynccl_comm)
+            logger.info("Initialized ROCm allreduce dispatcher.")
 
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
@@ -123,36 +292,37 @@ def __init__(self,
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")
 
     def all_reduce(self, input_):
-        # since currently we perform copy input -> symm_input -> out-of-place AR
-        # return symm_output, we don't need to check if input is symmetric
-        if self.pynccl_comm is not None and \
-            should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
-            out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
-            if out is not None:
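+        # On ROCm, backend selection is delegated to the dispatcher below;
+        # the non-ROCm path keeps the original backend priority order.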
+        if current_platform.is_rocm() and \
+            self.rocm_allreduce_dispatcher is not None:
+            op = self.rocm_allreduce_dispatcher.dispatch(input_)
+            logger.debug("ROCm allreduce dispatcher dispatched: %s", op)
+            out = None if op is None else op(input_)
+        else:
+            # since currently we perform:
+            # copy input -> symm_input -> out-of-place AR
+            # return symm_output, we don't need to check if input is symmetric
+            if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce(
+                    self.pynccl_comm.world_size, input_):
+                out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
+                if out is not None:
+                    return out
+            ca_comm = self.ca_comm
+            if ca_comm is not None and not ca_comm.disabled and \
+                ca_comm.should_custom_ar(input_):
+                out = ca_comm.custom_all_reduce(input_)
+                assert out is not None
                 return out
-        # always try quick reduce first, then custom allreduce,
-        # and then pynccl. (quick reduce just for ROCM MI3*)
-        qr_comm = self.qr_comm
-        if qr_comm is not None and not qr_comm.disabled and \
-            qr_comm.should_quick_allreduce(input_):
-            out = qr_comm.quick_all_reduce(input_)
-            assert out is not None
-            return out
-        ca_comm = self.ca_comm
-        if ca_comm is not None and not ca_comm.disabled and \
-            ca_comm.should_custom_ar(input_):
-            out = ca_comm.custom_all_reduce(input_)
-            assert out is not None
-            return out
-        symm_mem_comm = self.symm_mem_comm
-        if symm_mem_comm is not None and \
-            symm_mem_comm.should_use_symm_mem(input_):
-            out = symm_mem_comm.all_reduce(input_)
-            assert out is not None
-            return out
-        pynccl_comm = self.pynccl_comm
-        assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+            symm_mem_comm = self.symm_mem_comm
+            if symm_mem_comm is not None and \
+                symm_mem_comm.should_use_symm_mem(input_):
+                out = symm_mem_comm.all_reduce(input_)
+                assert out is not None
+                return out
+            pynccl_comm = self.pynccl_comm
+            assert pynccl_comm is not None
+            out = pynccl_comm.all_reduce(input_)
+
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
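
---
Reviewer note: below is a minimal, self-contained sketch (not part of the
patch) of the threshold-based dispatch pattern this change introduces. The
backend names, the 16 MB custom-allreduce cap, and the 64 KB quick-reduce
floor are hypothetical stand-ins; the real eligibility checks
(should_custom_ar / should_quick_allreduce) inspect the tensor itself.

    from typing import Callable, Optional

    # Hypothetical backends standing in for pynccl / vllm_ca / vllm_qr.
    def pynccl_allreduce(x: str) -> str:
        return f"pynccl({x})"

    def custom_allreduce(x: str) -> str:
        return f"vllm_ca({x})"

    def quick_allreduce(x: str) -> str:
        return f"vllm_qr({x})"

    # Assumed eligibility checks (the real ones take the input tensor).
    def ca_eligible(payload_kb: int) -> bool:
        return payload_kb <= 16 * 1024  # assumed registration cap

    def qr_eligible(payload_kb: int) -> bool:
        return payload_kb >= 64  # assumed minimum profitable size

    # Same shape as gfx94_thresholds in the patch: KB, keyed by TP size.
    GFX94_THRESHOLDS = {2: 2048, 4: 4096, 8: 8192}

    def dispatch(payload_kb: int, tp_size: int) -> Callable[[str], str]:
        """Mirror of _dispatch_gfx94: small payloads prefer the custom
        allreduce, large payloads prefer quick allreduce, and pynccl is
        the universal fallback."""
        threshold: Optional[int] = GFX94_THRESHOLDS.get(tp_size)
        if threshold is None:
            return pynccl_allreduce
        if payload_kb <= threshold:
            return (custom_allreduce
                    if ca_eligible(payload_kb) else pynccl_allreduce)
        return quick_allreduce if qr_eligible(payload_kb) else pynccl_allreduce

    if __name__ == "__main__":
        # 512 KB and 4096 KB land on the custom allreduce at TP=8;
        # 65536 KB crosses the 8192 KB threshold and goes to quick reduce.
        for kb in (512, 4096, 65536):
            print(f"{kb:>6} KB ->", dispatch(kb, tp_size=8)("tensor"))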