# SPDX-License-Identifier: Apache-2.0
# reference https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py
"""Thin ctypes bindings for the FlagCX communication library.

The module mirrors the C types and enums that the bindings below need and
exposes them through :class:`FLAGCXLibrary`, which loads the shared object,
declares the prototype of every exported function and turns non-zero return
codes into ``RuntimeError``.
"""

import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from torch.distributed import ReduceOp

from vllm.logger import init_logger

logger = init_logger(__name__)

# === export types and functions from flagcx to Python ===
# for the original flagcx definition, please check
# https://github.com/FlagOpen/FlagCX/blob/main/flagcx/include/flagcx.h

flagcxResult_t = ctypes.c_int
flagcxComm_t = ctypes.c_void_p


class flagcxUniqueId(ctypes.Structure):
    """Opaque 256-byte identifier shared by all ranks of a communicator."""
    _fields_ = [("internal", ctypes.c_byte * 256)]


cudaStream_t = ctypes.c_void_p
flagcxStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

flagcxDataType_t = ctypes.c_int

# Exported names of the two stream-adaptor helpers. They are C++ functions,
# so the library only exposes them under their mangled names.
_STREAM_COPY_SYMBOL = "_Z21cudaAdaptorStreamCopyPP12flagcxStreamPv"
_STREAM_FREE_SYMBOL = "_Z21cudaAdaptorStreamFreeP12flagcxStream"


class flagcxDataTypeEnum:
    """Integer codes of ``flagcxDataType_t`` plus a torch dtype converter."""
    flagcxInt8 = 0
    flagcxChar = 0
    flagcxUint8 = 1
    flagcxInt32 = 2
    flagcxInt = 2
    flagcxUint32 = 3
    flagcxInt64 = 4
    flagcxUint64 = 5
    flagcxFloat16 = 6
    flagcxHalf = 6
    flagcxFloat32 = 7
    flagcxFloat = 7
    flagcxFloat64 = 8
    flagcxDouble = 8
    flagcxBfloat16 = 9
    flagcxNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Return the FlagCX data type code for ``dtype``.

        Raises:
            ValueError: if ``dtype`` has no FlagCX counterpart listed here.
        """
        if dtype == torch.int8:
            return cls.flagcxInt8
        if dtype == torch.uint8:
            return cls.flagcxUint8
        if dtype == torch.int32:
            return cls.flagcxInt32
        if dtype == torch.int64:
            return cls.flagcxInt64
        if dtype == torch.float16:
            return cls.flagcxFloat16
        if dtype == torch.float32:
            return cls.flagcxFloat32
        if dtype == torch.float64:
            return cls.flagcxFloat64
        if dtype == torch.bfloat16:
            return cls.flagcxBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")


flagcxRedOp_t = ctypes.c_int


class flagcxRedOpTypeEnum:
    """Integer codes of ``flagcxRedOp_t`` plus a torch ReduceOp converter."""
    flagcxSum = 0
    flagcxProd = 1
    flagcxMax = 2
    flagcxMin = 3
    flagcxAvg = 4
    flagcxNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Return the FlagCX reduction code for ``op``.

        Raises:
            ValueError: if ``op`` is not SUM, PRODUCT, MAX, MIN or AVG.
        """
        if op == ReduceOp.SUM:
            return cls.flagcxSum
        if op == ReduceOp.PRODUCT:
            return cls.flagcxProd
        if op == ReduceOp.MAX:
            return cls.flagcxMax
        if op == ReduceOp.MIN:
            return cls.flagcxMin
        if op == ReduceOp.AVG:
            return cls.flagcxAvg
        raise ValueError(f"Unsupported op: {op}")


@dataclass
class Function:
    """Prototype of one exported C function (symbol name, restype, argtypes)."""
    name: str
    restype: Any
    argtypes: List[Any]


class FLAGCXLibrary:
    """Loads the FlagCX shared object and exposes its functions as methods.

    Loaded libraries and their bound function tables are cached per path at
    class level, so constructing several instances for the same ``so_file``
    loads and binds the library only once.
    """

    exported_functions = [
        # const char *flagcxGetErrorString(flagcxResult_t result);
        Function("flagcxGetErrorString", ctypes.c_char_p, [flagcxResult_t]),
        # flagcxResult_t flagcxGetVersion(int *version);
        Function("flagcxGetVersion", flagcxResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # flagcxResult_t flagcxGetUniqueId(flagcxUniqueId_t *uniqueId);
        # the binding passes a pointer to a pointer: the library hands back
        # the address of the id through the out-parameter
        Function("flagcxGetUniqueId", flagcxResult_t,
                 [ctypes.POINTER(ctypes.POINTER(flagcxUniqueId))]),
        # flagcxResult_t flagcxCommInitRank(flagcxComm_t *comm, int nranks,
        #                                   flagcxUniqueId_t commId, int rank);
        # note that flagcxComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("flagcxCommInitRank", flagcxResult_t, [
            ctypes.POINTER(flagcxComm_t), ctypes.c_int,
            ctypes.POINTER(flagcxUniqueId), ctypes.c_int
        ]),
        # flagcxResult_t flagcxAllReduce(const void *sendbuff, void *recvbuff,
        #                                size_t count, flagcxDataType_t datatype,
        #                                flagcxRedOp_t op, flagcxComm_t comm,
        #                                flagcxStream_t stream);
        # note that flagcxStream_t is a pointer type, so the last argument
        # is a pointer
        Function("flagcxAllReduce", flagcxResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t,
            flagcxRedOp_t, flagcxComm_t, flagcxStream_t
        ]),

        # flagcxResult_t flagcxAllGather(const void *sendbuff, void *recvbuff,
        #                                size_t sendcount, flagcxDataType_t datatype,
        #                                flagcxComm_t comm, flagcxStream_t stream);
        # note that flagcxStream_t is a pointer type, so the last argument
        # is a pointer
        Function("flagcxAllGather", flagcxResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t,
            flagcxComm_t, flagcxStream_t
        ]),

        # flagcxResult_t flagcxReduceScatter(const void *sendbuff, void *recvbuff,
        #                                    size_t recvcount, flagcxDataType_t datatype,
        #                                    flagcxRedOp_t op, flagcxComm_t comm,
        #                                    flagcxStream_t stream);
        # note that flagcxStream_t is a pointer type, so the last argument
        # is a pointer
        Function("flagcxReduceScatter", flagcxResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t,
            flagcxRedOp_t, flagcxComm_t, flagcxStream_t
        ]),

        # flagcxResult_t flagcxSend(const void *sendbuff, size_t count,
        #                           flagcxDataType_t datatype, int peer,
        #                           flagcxComm_t comm, flagcxStream_t stream);
        Function("flagcxSend", flagcxResult_t, [
            buffer_type, ctypes.c_size_t, flagcxDataType_t, ctypes.c_int,
            flagcxComm_t, flagcxStream_t
        ]),

        # flagcxResult_t flagcxRecv(void *recvbuff, size_t count,
        #                           flagcxDataType_t datatype, int peer,
        #                           flagcxComm_t comm, flagcxStream_t stream);
        Function("flagcxRecv", flagcxResult_t, [
            buffer_type, ctypes.c_size_t, flagcxDataType_t, ctypes.c_int,
            flagcxComm_t, flagcxStream_t
        ]),

        # flagcxResult_t flagcxBroadcast(const void *sendbuff, void *recvbuff,
        #                                size_t count, flagcxDataType_t datatype,
        #                                int root, flagcxComm_t comm,
        #                                flagcxStream_t stream);
        Function("flagcxBroadcast", flagcxResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, flagcxDataType_t,
            ctypes.c_int, flagcxComm_t, flagcxStream_t
        ]),

        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # flagcxResult_t flagcxCommDestroy(flagcxComm_t comm);
        Function("flagcxCommDestroy", flagcxResult_t, [flagcxComm_t]),

        # flagcxResult_t cudaAdaptorStreamCopy(flagcxStream_t *newStream,
        #                                      void *oldStream)
        Function(_STREAM_COPY_SYMBOL, flagcxResult_t,
                 [ctypes.POINTER(flagcxStream_t), flagcxStream_t]),

        # flagcxResult_t cudaAdaptorStreamFree(flagcxStream_t stream)
        Function(_STREAM_FREE_SYMBOL, flagcxResult_t, [flagcxStream_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load ``so_file`` (or reuse the cached handle) and bind its symbols.

        Args:
            so_file: path of the FlagCX shared object. No default lookup is
                implemented; the value is handed to ``ctypes.CDLL`` as is.

        Raises:
            Exception: whatever ``ctypes.CDLL`` or the symbol lookup raised
                (typically ``OSError`` / ``AttributeError``); it is logged
                and re-raised unchanged.
        """
        cache = FLAGCXLibrary.path_to_library_cache
        bound = FLAGCXLibrary.path_to_dict_mapping
        try:
            # Each cache is consulted for its own content, so a library that
            # loaded but failed to bind is not loaded a second time.
            if so_file not in cache:
                cache[so_file] = ctypes.CDLL(so_file)
            self.lib = cache[so_file]
            if so_file not in bound:
                bound[so_file] = self._bind_functions(self.lib)
            self._funcs = bound[so_file]
        except Exception:
            logger.error(
                "Failed to load flagCX library from %s. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the flagcx library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please make sure the "
                "environment variable FLAGCX_PATH points to the FlagCX root "
                "directory (the one containing build/lib/libflagcx.so).",
                so_file, platform.platform())
            raise

    @staticmethod
    def _bind_functions(lib: Any) -> Dict[str, Any]:
        """Look up every exported function in ``lib`` and set its prototype."""
        funcs: Dict[str, Any] = {}
        for spec in FLAGCXLibrary.exported_functions:
            c_func = getattr(lib, spec.name)
            c_func.restype = spec.restype
            c_func.argtypes = spec.argtypes
            funcs[spec.name] = c_func
        return funcs

    def flagcxGetErrorString(self, result: flagcxResult_t) -> str:
        """Return the library's description of the error code ``result``."""
        raw = self._funcs["flagcxGetErrorString"](result)
        if raw is None:
            # A NULL char* comes back as None; keep the error path usable.
            return f"unknown error code {result}"
        return raw.decode("utf-8")

    def FLAGCX_CHECK(self, result: flagcxResult_t) -> None:
        """Raise ``RuntimeError`` when ``result`` is not 0 (success)."""
        if result != 0:
            error_str = self.flagcxGetErrorString(result)
            raise RuntimeError(f"FLAGCX error: {error_str}")

    def flagcxGetVersion(self) -> str:
        """Return the library version formatted as ``major.minor.patch``."""
        version = ctypes.c_int()
        self.FLAGCX_CHECK(self._funcs["flagcxGetVersion"](
            ctypes.byref(version)))
        # something like 21903 --> "2.19.3"
        # Decoded arithmetically so that zero components are kept
        # (20003 --> "2.0.3" rather than "2..3").
        # NOTE(review): assumes major * 10000 + minor * 100 + patch, as the
        # example above suggests -- confirm against the FlagCX headers.
        major, remainder = divmod(version.value, 10000)
        minor, patch = divmod(remainder, 100)
        return f"{major}.{minor}.{patch}"

    def flagcxGetUniqueId(self) -> Any:
        """Ask the library for a new unique id.

        Returns:
            A ctypes pointer to a :class:`flagcxUniqueId`; read ``.contents``
            to get the structure itself.
        """
        unique_id = ctypes.POINTER(flagcxUniqueId)()
        self.FLAGCX_CHECK(self._funcs["flagcxGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def flagcxCommInitRank(self, world_size: int, unique_id: flagcxUniqueId,
                           rank: int) -> flagcxComm_t:
        """Create the communicator handle of ``rank`` among ``world_size`` ranks.

        ``unique_id`` is forwarded to an argument declared as
        ``POINTER(flagcxUniqueId)``, so pass a pointer or ``ctypes.byref``.
        """
        comm = flagcxComm_t()
        self.FLAGCX_CHECK(self._funcs["flagcxCommInitRank"](
            ctypes.byref(comm), world_size, unique_id, rank))
        return comm

    def flagcxAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                        count: int, datatype: int, op: int, comm: flagcxComm_t,
                        stream: flagcxStream_t) -> None:
        """Call ``flagcxAllReduce`` and raise on a non-zero return code."""
        # `datatype` actually should be `flagcxDataType_t`
        # and `op` should be `flagcxRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.FLAGCX_CHECK(self._funcs["flagcxAllReduce"](
            sendbuff, recvbuff, count, datatype, op, comm, stream))

    def flagcxReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                            count: int, datatype: int, op: int,
                            comm: flagcxComm_t,
                            stream: flagcxStream_t) -> None:
        """Call ``flagcxReduceScatter`` and raise on a non-zero return code."""
        # `datatype` and `op` are plain ints, see flagcxAllReduce
        self.FLAGCX_CHECK(self._funcs["flagcxReduceScatter"](
            sendbuff, recvbuff, count, datatype, op, comm, stream))

    def flagcxAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                        count: int, datatype: int, comm: flagcxComm_t,
                        stream: flagcxStream_t) -> None:
        """Call ``flagcxAllGather`` and raise on a non-zero return code."""
        # `datatype` is a plain int, see flagcxAllReduce
        self.FLAGCX_CHECK(self._funcs["flagcxAllGather"](
            sendbuff, recvbuff, count, datatype, comm, stream))

    def flagcxSend(self, sendbuff: buffer_type, count: int, datatype: int,
                   dest: int, comm: flagcxComm_t,
                   stream: flagcxStream_t) -> None:
        """Call ``flagcxSend`` towards peer ``dest``; raise on failure."""
        self.FLAGCX_CHECK(self._funcs["flagcxSend"](
            sendbuff, count, datatype, dest, comm, stream))

    def flagcxRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                   src: int, comm: flagcxComm_t,
                   stream: flagcxStream_t) -> None:
        """Call ``flagcxRecv`` from peer ``src``; raise on failure."""
        self.FLAGCX_CHECK(self._funcs["flagcxRecv"](
            recvbuff, count, datatype, src, comm, stream))

    def flagcxBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                        count: int, datatype: int, root: int,
                        comm: flagcxComm_t, stream: flagcxStream_t) -> None:
        """Call ``flagcxBroadcast`` with root ``root``; raise on failure."""
        self.FLAGCX_CHECK(self._funcs["flagcxBroadcast"](
            sendbuff, recvbuff, count, datatype, root, comm, stream))

    def flagcxCommDestroy(self, comm: flagcxComm_t) -> None:
        """Call ``flagcxCommDestroy`` (collective, see the note above)."""
        self.FLAGCX_CHECK(self._funcs["flagcxCommDestroy"](comm))

    def adaptor_stream_copy(self, old_stream: cudaStream_t) -> flagcxStream_t:
        """Create a FlagCX stream handle from the CUDA stream ``old_stream``.

        ``old_stream`` must be a ctypes object (e.g. ``cudaStream_t(handle)``)
        because its address, not its value, is handed to the library.
        Pair every call with :meth:`adaptor_stream_free`.
        """
        new_stream = flagcxStream_t()
        self.FLAGCX_CHECK(self._funcs[_STREAM_COPY_SYMBOL](
            ctypes.byref(new_stream), ctypes.byref(old_stream)))
        return new_stream

    def adaptor_stream_free(self, stream: flagcxStream_t) -> None:
        """Release a handle obtained from :meth:`adaptor_stream_copy`."""
        self.FLAGCX_CHECK(self._funcs[_STREAM_FREE_SYMBOL](stream))


__all__ = [
    "FLAGCXLibrary", "flagcxDataTypeEnum", "flagcxRedOpTypeEnum",
    "flagcxUniqueId", "flagcxComm_t", "flagcxStream_t", "buffer_type",
    "cudaStream_t"
]
# SPDX-License-Identifier: Apache-2.0
# reference https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl.py
"""Tensor-level communicator built on the FlagCX ctypes bindings."""

import os
import ctypes
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from vllm.distributed.device_communicators.flagcx_wrapper import (
    FLAGCXLibrary, buffer_type, flagcxComm_t, flagcxDataTypeEnum,
    flagcxRedOpTypeEnum, flagcxUniqueId, cudaStream_t)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import current_stream

logger = init_logger(__name__)


class PyFlagcxCommunicator:
    """Collective and point-to-point operations on CUDA tensors via FlagCX.

    After construction ``available`` tells whether a FlagCX communicator was
    created and ``disabled`` makes every operation a no-op while it is true.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. A torch ``ProcessGroup``
                must not use the NCCL backend.
            device: the device to bind the PyFlagcxCommunicator to, given as
                a CUDA index, a device string or a ``torch.device``.
            library_path: the path to the flagCX library, forwarded to
                ``FLAGCXLibrary``.
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device.
        """
        # FlagCX stream wrappers keyed by raw CUDA stream handle; filled
        # lazily by `_flagcx_stream`. Created first so it exists on every
        # early-return path below.
        self._flagcx_streams = {}

        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyFlagcxCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return
        try:
            self.flagcx = FLAGCXLibrary(library_path)
        except Exception:
            # disable because of missing flagCX library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        logger.info("vLLM is using flagcx==%s", self.flagcx.flagcxGetVersion())

        if self.rank == 0:
            # get the unique id from flagCX
            self.unique_id = self.flagcx.flagcxGetUniqueId().contents
        else:
            # construct an empty unique id
            self.unique_id = flagcxUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            # Ship the id of the first rank to everybody as a byte tensor
            # and copy the received bytes back into the local structure.
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # flagcx communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm = self.flagcx.flagcxCommInitRank(
                self.world_size, ctypes.byref(self.unique_id), self.rank)
            stream = current_stream()

            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)

            stream.synchronize()
            del data

    def _check_device(self, tensor: torch.Tensor) -> None:
        """Assert that ``tensor`` lives on the communicator's device."""
        # flagcx communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this flagcx communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")

    def _flagcx_stream(self, stream=None):
        """Return the FlagCX handle wrapping ``stream`` (default: current).

        ``adaptor_stream_copy`` creates a new handle on every call and this
        class never releases them, so one handle is created per distinct
        CUDA stream and reused by later operations on that stream.
        """
        if stream is None:
            stream = current_stream()
        handle = stream.cuda_stream
        wrapped = self._flagcx_streams.get(handle)
        if wrapped is None:
            # NOTE(review): reuse assumes the FlagCX handle only records the
            # raw CUDA stream it was created from -- confirm with FlagCX.
            wrapped = self.flagcx.adaptor_stream_copy(cudaStream_t(handle))
            self._flagcx_streams[handle] = wrapped
        return wrapped

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> Optional[torch.Tensor]:
        """Reduce ``in_tensor`` across ranks into a newly allocated tensor.

        Returns:
            The output tensor, or ``None`` when the communicator is disabled.
        """
        if self.disabled:
            return None
        self._check_device(in_tensor)

        out_tensor = torch.empty_like(in_tensor)

        flagcx_stream = self._flagcx_stream(stream)
        self.flagcx.flagcxAllReduce(buffer_type(in_tensor.data_ptr()),
                                    buffer_type(out_tensor.data_ptr()),
                                    in_tensor.numel(),
                                    flagcxDataTypeEnum.from_torch(
                                        in_tensor.dtype),
                                    flagcxRedOpTypeEnum.from_torch(op),
                                    self.comm, flagcx_stream)
        return out_tensor

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        """Gather ``input_tensor`` from all ranks into ``output_tensor``.

        The element count handed to FlagCX is ``input_tensor.numel()``.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        flagcx_stream = self._flagcx_stream(stream)
        self.flagcx.flagcxAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            flagcxDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            flagcx_stream)

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        """Reduce ``input_tensor`` across ranks and scatter the result.

        The element count handed to FlagCX is ``output_tensor.numel()``.
        """
        if self.disabled:
            return
        self._check_device(input_tensor)
        flagcx_stream = self._flagcx_stream(stream)
        self.flagcx.flagcxReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            flagcxDataTypeEnum.from_torch(input_tensor.dtype),
            flagcxRedOpTypeEnum.from_torch(op), self.comm,
            flagcx_stream)

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send ``tensor`` to rank ``dst`` of this communicator."""
        if self.disabled:
            return
        self._check_device(tensor)
        flagcx_stream = self._flagcx_stream(stream)
        self.flagcx.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                               flagcxDataTypeEnum.from_torch(tensor.dtype),
                               dst, self.comm, flagcx_stream)

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into ``tensor`` from rank ``src`` of this communicator."""
        if self.disabled:
            return
        self._check_device(tensor)
        flagcx_stream = self._flagcx_stream(stream)
        self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                               flagcxDataTypeEnum.from_torch(tensor.dtype),
                               src, self.comm, flagcx_stream)

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` in place from rank ``src`` to all ranks."""
        if self.disabled:
            return
        self._check_device(tensor)
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # FLAGCX requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            # receivers pass a NULL send buffer
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        flagcx_stream = self._flagcx_stream(stream)
        self.flagcx.flagcxBroadcast(sendbuff, recvbuff, tensor.numel(),
                                    flagcxDataTypeEnum.from_torch(
                                        tensor.dtype), src, self.comm,
                                    flagcx_stream)
# SPDX-License-Identifier: Apache-2.0
"""
    This module implements a PyNccl pipe for sending and receiving
    Optional[torch.Tensor] between distributed ranks with advanced
    communication features.

    Key Features:
    - Supports sending and receiving tensors with metadata
    - Handles both CUDA and CPU device communications
    - Implements a non-blocking tensor transfer mechanism
    - Manages buffer size and provides backpressure control
    - Supports distributed process groups with configurable parameters
    - On CUDA, uses FlagCX for send / recv when the FLAGCX_PATH environment
      variable points to a usable FlagCX build, and PyNccl otherwise
"""

import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple

import torch

from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pyflagcx import PyFlagcxCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

logger = init_logger(__name__)


class BrokenPipeException(Exception):
    """Exception carrying its text in the ``message`` attribute."""

    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


Metadata = Dict[str, Optional[torch.Tensor]]


class PyNcclPipe(KVPipeBase):
    """One-directional ring pipe: sends to rank + 1, receives from rank - 1."""

    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 device: Optional[str] = None,
                 port_offset: int = 0):
        """Connect to the peers described by ``config``.

        Args:
            local_rank: CUDA device index used when the pipe runs on "cuda".
            config: KV transfer settings (rank, size, ip, port, buffer size).
            device: "cuda" or anything else for CPU; defaults to
                ``config.kv_buffer_device``.
            port_offset: added to ``config.kv_port`` for the rendezvous.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # set target rank
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        self.transport_thread: Optional[ThreadPoolExecutor] = None
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _try_create_flagcx_communicator(
            self,
            group: StatelessProcessGroup) -> Optional[PyFlagcxCommunicator]:
        """Return a usable FlagCX communicator, or None to fall back to NCCL."""
        flagcx_path = os.getenv('FLAGCX_PATH')
        if not flagcx_path:
            logger.info(
                "FLAGCX_PATH is not set, using PyNccl for KV transfer.")
            return None

        library_path = os.path.join(flagcx_path, "build/lib/libflagcx.so")
        try:
            comm = PyFlagcxCommunicator(group,
                                        device=self.local_rank,
                                        library_path=library_path)
        except Exception:
            # Best effort by design: any FlagCX start-up failure falls back
            # to NCCL, but it is logged instead of being swallowed silently.
            logger.warning(
                "Failed to initialize FlagCX from %s, "
                "falling back to PyNccl.",
                library_path,
                exc_info=True)
            return None

        if not comm.available:
            # The communicator disabled itself (library could not be loaded
            # or world size is 1); force-enabling it would fail on first use.
            logger.warning(
                "FlagCX communicator is not available (library path %s), "
                "falling back to PyNccl.", library_path)
            return None
        return comm

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
        [torch.Tensor, int], None]]:
        """Pick the ``(send, recv)`` callables matching ``self.device``."""

        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            comm = self._try_create_flagcx_communicator(group)
            if comm is None:
                # use PyNCCL for send / recv
                comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str):
        """Map "cuda" to this rank's CUDA device and anything else to CPU."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Parameters:
            - tensor: The input tensor or None if no tensor is provided.

        Returns:
            - metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Parameters:
            - metadata: A dictionary with keys "dtype" and "shape", describing
              the tensor's data type and shape.

        Returns:
            - buffer: A tensor of the specified type and shape, allocated on
              self.device.
        """
        return torch.empty(metadata["shape"],
                           dtype=metadata["dtype"],
                           device=self.device)

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Parameters:
            - metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            - metadata: A dictionary with keys "dtype" and "shape" describing
              the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
        """
        The actual implementation of sending the tensor and its metadata to the
        target rank.

        Parameters:
            - tensor: The input tensor to be sent, or None if no tensor is
              being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            self.device_send_func(tensor.to(self.device),
                                  self.target_rank_for_send)

    def _recv_impl(self) -> Optional[torch.Tensor]:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            - buffer: The received tensor, or None if no tensor is received.
        """
        metadata = self._recv_metadata()
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
                            tensor_size: int) -> None:
        """
        Wrapper for _send_impl to handle exceptions and update buffer size.

        Exceptions are logged, not propagated. The bytes reserved by
        send_tensor are released whether or not the send succeeded.
        """
        try:
            self._send_impl(tensor)
        except Exception as e:
            # The default torch process group may not exist; do not let the
            # rank lookup raise inside the error handler.
            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
            else:
                rank = self.kv_rank
            logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
                         rank, str(tensor), str(e))
            import traceback
            traceback.print_exc()
        finally:
            # A failed send no longer occupies the pipe either; without this
            # block_if_full() could wait forever after an error.
            with self.buffer_size_lock:
                self.buffer_size -= tensor_size

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Parameters:
            - tensor: The tensor to send, or None if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor,
                                     tensor_size)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        Returns:
            - tensor: The received tensor, or None if no tensor is received.

        Raises:
            - Exception: whatever the receiving task raised, after logging.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception as e:
            logger.error("Encountering exception in KV receiving thread")
            logger.error("%s", e)
            logger.error("My device: %s", self.device)
            import traceback
            traceback.print_exc()
            raise

        return tensor

    def close(self):
        """
        Close the pipe and release associated resources.
        """
        if hasattr(self,
                   "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()