vllm/distributed/device_communicators/cuda_communicator.py (44 additions, 4 deletions)
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from importlib.util import find_spec
+from typing import Optional, Protocol, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -15,6 +16,46 @@
 logger = init_logger(__name__)
 
 
+class CustomAllreduceProtocol(Protocol):
+    """Protocol for custom allreduce implementations.
+    Used only to satisfy mypy."""
+
+    disabled: bool = True
+
+    def __init__(self, group: ProcessGroup,
+                 device: Union[int, str, torch.device]) -> None:
+        ...
+
+    def should_custom_ar(self, inp: torch.Tensor):
+        ...
+
+    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+        ...
+
+
+def is_rocm_aiter_custom_allreduce_enabled() -> bool:
+    """Check whether aiter custom allreduce is enabled on the ROCm platform."""
+    from vllm.platforms.rocm import on_gfx9
+    return current_platform.is_rocm() \
+        and on_gfx9() \
+        and envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE \
+        and find_spec("aiter.dist.custom_all_reduce") is not None
+
+
+def dispatch_custom_allreduce() -> type[CustomAllreduceProtocol]:
+    """Dispatch the custom allreduce implementation based on the platform."""
+    if is_rocm_aiter_custom_allreduce_enabled():
+        from aiter.dist.custom_all_reduce import CustomAllreduce
+        logger.info_once(
+            "Using aiter.dist.custom_all_reduce for ROCm platform")
+    else:
+        from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa: E501
+            CustomAllreduce)
+
+    return CustomAllreduce
+
+
 class CudaCommunicator(DeviceCommunicatorBase):
 
     def __init__(self,
@@ -38,8 +79,7 @@ def __init__(self,
         self.use_custom_allreduce = use_custom_allreduce
 
         # lazy import to avoid documentation build error
-        from vllm.distributed.device_communicators.custom_all_reduce import (
-            CustomAllreduce)
+        CustomAllreduce = dispatch_custom_allreduce()
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
         from vllm.distributed.device_communicators.quick_all_reduce import (
@@ -54,7 +94,7 @@ def __init__(self,
             device=self.device,
         )
 
-        self.ca_comm: Optional[CustomAllreduce] = None
+        self.ca_comm: Optional[CustomAllreduceProtocol] = None
         self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
         if use_custom_allreduce and self.world_size > 1:
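
For review context, here is a minimal sketch of how the new dispatcher and protocol fit together. It is illustrative only and assumes this patch is applied; make_ca_comm and the group/device placeholders are hypothetical and not part of the PR:

# Sketch: callers stay platform-agnostic, because both the aiter class
# and the built-in class satisfy CustomAllreduceProtocol.
import torch
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.cuda_communicator import (
    dispatch_custom_allreduce)


def make_ca_comm(group: ProcessGroup, device: torch.device):
    # dispatch_custom_allreduce() returns a class, not an instance:
    # aiter's CustomAllreduce on gfx9 ROCm with the aiter flags set,
    # otherwise vLLM's built-in CustomAllreduce.
    CustomAllreduce = dispatch_custom_allreduce()
    ca_comm = CustomAllreduce(group=group, device=device)
    # Both implementations expose `disabled` (see the protocol above).
    return None if ca_comm.disabled else ca_comm

Because dispatch happens once, at class-selection time, the hot allreduce path carries no per-call platform branching.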
vllm/envs.py (7 additions, 0 deletions)
@@ -99,6 +99,7 @@
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = True
+    VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE: bool = True
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
@@ -775,6 +776,12 @@ def get_vllm_port() -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),

# Whether to use aiter custom allreduce for ROCm platform.
# By default is disabled, uses vLLM built-in custom allreduce.
"VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE", "True").lower() in
("true", "1")),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM":
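
Note that the new flag composes with the existing aiter switches rather than acting alone: the check in cuda_communicator.py also requires VLLM_ROCM_USE_AITER, a gfx9 ROCm device, and an importable aiter.dist.custom_all_reduce. A hedged sketch of probing that decision, assuming the variables are set before vLLM first reads them:

import os

# Illustrative values; both flags must be truthy for the aiter path.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE"] = "0"  # force built-in

from vllm.distributed.device_communicators.cuda_communicator import (
    is_rocm_aiter_custom_allreduce_enabled)

# Prints False even on a gfx9 ROCm machine, since the second flag is "0";
# dispatch_custom_allreduce() would fall back to vLLM's built-in kernel.
print(is_rocm_aiter_custom_allreduce_enabled())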