250 changes: 210 additions & 40 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union
from typing import Any, Callable, Optional, Union

import torch
from torch.distributed import ProcessGroup
@@ -21,6 +21,174 @@
logger = init_logger(__name__)


# ROCm allreduce dispatcher that selects the most performant
# allreduce op based on the available implementations, the payload
# size of the input tensor, and the TP size. It only supports AMD
# ROCm platforms.
class ROCmAllreduceDispatcher:

def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
ca_comm=None,
pynccl_comm=None):
self.process_group = group
self.device = device
self.cur_device_arch = self._get_current_device_arch()
self.supported_device_archs = ["gfx94", "gfx95"]

self.tp_size = torch.distributed.get_world_size(
group=self.process_group)

        # dispatch thresholds (unit: KB) by tp_size:
self.gfx95_thresholds = {
2: 512,
4: 2048,
8: 32768,
}
self.gfx94_thresholds = {
2: 2048,
4: 4096,
8: 8192,
}
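        # Payloads at or below the threshold prefer a custom allreduce;
        # larger payloads prefer quick allreduce (see the per-arch
        # _dispatch_* methods below).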

        # map: allreduce name -> (allreduce impl, optional availability check)
self.available_allreduce_impls: dict[str,
tuple[Callable,
Optional[Callable]]] = {}

self.fallback_impl = None
if pynccl_comm is not None:
self.available_allreduce_impls["pynccl"] = \
(pynccl_comm.all_reduce, None)
self.fallback_impl = pynccl_comm.all_reduce

if ca_comm is not None:
self.available_allreduce_impls["vllm_ca"] = \
(ca_comm.custom_all_reduce, ca_comm.should_custom_ar)

# Initialize a custom quick all-reduce implementation for AMD.
# Quick reduce is designed as a complement to custom allreduce.
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
        # On ROCm, 'use_custom_allreduce==True' currently implies an
        # MI300-series device.
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
self.qr_comm = QuickAllReduce(group=self.process_group,
device=self.device)
if self.qr_comm is not None:
self.available_allreduce_impls["vllm_qr"] = (
self.qr_comm.quick_all_reduce,
self.qr_comm.should_quick_allreduce)

# Initialize a custom all-reduce implementation from aiter.
if self._is_aiter_custom_allreduce_available():
from aiter.dist.custom_all_reduce import CustomAllreduce
self.aiter_ca_comm = CustomAllreduce(
group=self.process_group,
device=self.device,
)
if self.aiter_ca_comm is not None:
self.available_allreduce_impls["aiter_ca"] = \
(self.aiter_ca_comm.custom_all_reduce, \
self.aiter_ca_comm.should_custom_ar)

def _is_aiter_custom_allreduce_available(self) -> bool:
"""Check if aiter is enabled for ROCm platform."""
if not envs.VLLM_ROCM_USE_AITER:
return False

try:
from aiter.dist.custom_all_reduce import ( # noqa: F401
CustomAllreduce)
return True
except ImportError:
return False

def _get_current_device_arch(self) -> str:
"""Get the device micro architecture number of the current device."""
# TODO(zejun): Add more device architectures
device_arch = torch.cuda.get_device_properties("cuda").gcnArchName
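        # e.g. MI300-series GPUs report gfx94x and MI350-series report gfx95x.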
if "gfx95" in device_arch:
return "gfx95"
elif "gfx94" in device_arch:
return "gfx94"
else:
return device_arch

def _should_allreduce(self, input_: torch.Tensor, impl_name: str) -> bool:
if impl_name not in self.available_allreduce_impls:
return False
check_fn = self.available_allreduce_impls[impl_name][1]
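        # Impls registered without a check function (e.g. pynccl) are never
        # selected here; they are only reachable via the fallback impl.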
if check_fn is None:
return False
return check_fn(input_)

def _dispatch_gfx94(self, input_: torch.Tensor, payload_size_KB: int,
tp_size: int):
"""Dispatch implementation for gfx94 architecture."""
if tp_size not in self.gfx94_thresholds:
return self.fallback_impl

threshold = self.gfx94_thresholds[tp_size]

        # For small payloads, prefer the vLLM custom allreduce;
        # for large payloads, use quick allreduce.
if payload_size_KB <= threshold:
if self._should_allreduce(input_, "vllm_ca"):
return self.available_allreduce_impls["vllm_ca"][0]
else:
return self.fallback_impl
elif self._should_allreduce(input_, "vllm_qr"):
return self.available_allreduce_impls["vllm_qr"][0]
else:
return self.fallback_impl

def _dispatch_gfx95(self, input_: torch.Tensor, payload_size_KB: int,
tp_size: int):
"""Dispatch implementation for gfx95 architecture."""
if tp_size not in self.gfx95_thresholds:
return self.fallback_impl

threshold = self.gfx95_thresholds[tp_size]

        # For small payloads, prefer the aiter custom allreduce;
        # for large payloads, use quick allreduce.
if payload_size_KB <= threshold:
if self._should_allreduce(input_, "aiter_ca"):
return self.available_allreduce_impls["aiter_ca"][0]
elif self._should_allreduce(input_, "vllm_ca"):
return self.available_allreduce_impls["vllm_ca"][0]
else:
return self.fallback_impl
elif self._should_allreduce(input_, "vllm_qr"):
return self.available_allreduce_impls["vllm_qr"][0]
else:
return self.fallback_impl

def _dispatch_impl(self, input_: torch.Tensor, payload_size_KB: int,
device_arch: str, tp_size: int):
if device_arch not in self.supported_device_archs:
            logger.debug(
                "Device architecture %s not supported, using pynccl",
                device_arch)
return self.fallback_impl

if device_arch == "gfx95":
return self._dispatch_gfx95(input_, payload_size_KB, tp_size)
elif device_arch == "gfx94":
return self._dispatch_gfx94(input_, payload_size_KB, tp_size)
else:
return self.fallback_impl

def dispatch(self, input_: torch.Tensor) -> Optional[Callable[..., Any]]:
"""Dispatch the allreduce implementation"""
payload_size_KB = int(input_.numel() * input_.element_size() / 1024.0)
op = self._dispatch_impl(input_, payload_size_KB, self.cur_device_arch,
self.tp_size)
return op
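
    # Usage sketch (mirrors CudaCommunicator.all_reduce below):
    #   op = dispatcher.dispatch(input_)
    #   out = op(input_) if op is not None else None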


class CudaCommunicator(DeviceCommunicatorBase):

def __init__(self,
@@ -51,8 +219,6 @@ def __init__(self,
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)

@@ -66,8 +232,12 @@ def __init__(self,
register_nccl_symmetric_ops(self.pynccl_comm)

self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None

# Initialize a custom all-reduce dispatcher for ROCm platform
self.rocm_allreduce_dispatcher: Optional[
ROCmAllreduceDispatcher] = None

if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
@@ -84,13 +254,12 @@ def __init__(self,
)

if current_platform.is_rocm():
# Initialize a custom quick all-reduce implementation for AMD.
# Quick reduce is designed as a complement to custom allreduce.
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
# If it's a rocm, 'use_custom_allreduce==True' means it must
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
self.rocm_allreduce_dispatcher = \
ROCmAllreduceDispatcher(group=self.cpu_group,
device=self.device,
ca_comm=self.ca_comm,
pynccl_comm=self.pynccl_comm)
logger.info("Initializing ROCm allreduce dispatcher.")

if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
@@ -123,36 +292,37 @@ def __init__(self,
raise ValueError(f"Unknown all2all backend: {all2all_backend}")

def all_reduce(self, input_):
# since currently we perform copy input -> symm_input -> out-of-place AR
# return symm_output, we don't need to check if input is symmetric
if self.pynccl_comm is not None and \
should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
if current_platform.is_rocm() and \
self.rocm_allreduce_dispatcher is not None:
op = self.rocm_allreduce_dispatcher.dispatch(input_)
logger.debug("ROCm allreduce dispatcher dispatched: {op}")
out = None if op is None else op(input_)
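            # A missing op or a None result falls through to the default
            # PyTorch all-reduce fallback at the end of this method.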
else:
# since currently we perform:
# copy input -> symm_input -> out-of-place AR
# return symm_output, we don't need to check if input is symmetric
if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce(
self.pynccl_comm.world_size, input_):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
if qr_comm is not None and not qr_comm.disabled and \
qr_comm.should_quick_allreduce(input_):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)

# fallback to the default all-reduce using PyTorch.
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.