diff --git a/examples/qwen/conf/config_qwen2.5_7b_disagg_xpyd.yaml b/examples/qwen/conf/config_qwen2.5_7b_disagg_xpyd.yaml new file mode 100644 index 000000000..2e4c723dd --- /dev/null +++ b/examples/qwen/conf/config_qwen2.5_7b_disagg_xpyd.yaml @@ -0,0 +1,36 @@ +defaults: +- _self_ +- serve: serve_qwen2.5_7b + +experiment: + exp_name: qwen2.5_7b + exp_dir: outputs/${experiment.exp_name} + task: + type: serve + deploy: + port: 10001 + use_fs_serve: false + prefill_decode_disaggregation: true + prefill_num: 2 + prefill_address: x.x.x.x # optional, default "auto" + decode_num: 2 + decode_address: x.x.x.x # optional, default "auto" + runner: + hostfile: examples/qwen/conf/hostfile.txt + docker: fr-v2 + envs: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + VLLM_USE_V1: 0 + FLAGCX_SOCKET_IFNAME: bond0 + FLAGCX_PATH: /path/to/FlagCX/ + FLAGCX_DEBUG: TRACE + FLAGCX_DEBUG_SUBSYS: ALL + USE_FLAGCX: true + cmds: + before_start: source /root/miniconda3/bin/activate flagscale-inference + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/qwen/conf/hostfile.txt b/examples/qwen/conf/hostfile.txt new file mode 100644 index 000000000..0d8b1e05f --- /dev/null +++ b/examples/qwen/conf/hostfile.txt @@ -0,0 +1,5 @@ +# ip slots type=xxx[optional] +# master node +x.x.x.x slots=8 type=gpu +# worker nodes +x.x.x.x slots=8 type=gpu diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml index 387d29b07..b1bd95bfb 100644 --- a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml +++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml @@ -2,6 +2,7 @@ engine: vllm engine_args: model: /models/Qwen2.5-7B-Instruct + host: 0.0.0.0 tensor_parallel_size: 1 pipeline_parallel_size: 1 gpu_memory_utilization: 0.9 diff --git a/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py b/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 000000000..7b490ca56 --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,364 @@ +# Copied from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/device_communicators/pynccl_wrapper.py. +# Below is the original copyright: + +# SPDX-License-Identifier: Apache-2.0 + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. 
We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t 
datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. 
" + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: + """ + Reconstructs an `ncclUniqueId` object from bytes data. + + Args: + data: Must be a 128-byte data block (matching NCCL's unique_id). + + Returns: + ncclUniqueId: The reconstructed NCCL Unique ID object. + + Raises: + ValueError: If the input data length is not 128 bytes. + """ + if len(data) != 128: + raise ValueError( + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + + unique_id = ncclUniqueId() + ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + 
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 000000000..52de6757e --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,66 @@ +# Copied from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_connector/factory.py. +# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import TYPE_CHECKING, Callable, Dict, Type + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class KVConnectorFactory: + _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> Type[KVConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + return connector_cls(rank, local_rank, config) + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
+KVConnectorFactory.register_connector( + "P2pConnector", "vllm.distributed.kv_transfer.kv_connector.p2p_connector", + "P2pConnector") + +KVConnectorFactory.register_connector( + "PyNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnector", + "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", + "LMCacheConnector") + +KVConnectorFactory.register_connector( + "MooncakeStoreConnector", + "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", + "MooncakeStoreConnector") \ No newline at end of file diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py new file mode 100644 index 000000000..c88444370 --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py @@ -0,0 +1,305 @@ +# Mainly adopted from https://github.com/FlagOpen/FlagScale/blob/44ceca57dd6f86b10163968e617497c613e47d6e/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py. +# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +import os +import re +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +if os.getenv("USE_FLAGCX", "false").lower() in ("1", "true"): + from vllm.distributed.kv_transfer.kv_pipe.flagcx_p2p_nccl_pipe import P2pNcclPipe +else: + from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class P2pConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + self.rank = rank + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + + assert self.config.kv_connector == "P2pConnector" + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.p2p_nccl_pipe = P2pNcclPipe( + local_rank=local_rank, + config=self.config, + hostname="", + port_offset=rank, + ) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + # input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + model_config = model_executable.model.config + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = 
model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. + # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading + # to a kv_cache shape of [2, num_blks, blk_size, + # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. + # For more details, see vllm/attention/backends/mla/common.py. + if self.is_deepseek_mla and self.use_mla_opt: + head_size = model_config.kv_lora_rank + \ + model_config.qk_rope_head_dim + num_heads = 1 + elif self.is_deepseek_mla and not self.use_mla_opt: + head_size = model_config.qk_nope_head_dim + \ + model_config.qk_rope_head_dim + else: + head_size = getattr(model_config, "head_dim", + int(hidden_size // num_attention_heads)) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You have some decode requests while using " + "SimpleConnector. Their KVCache won't be sent.") + break + + # current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + request_id = request_ids[idx] + ip, port = self.parse_request_id(request_id, True) + remote_address = ip + ":" + str(port + self.rank) + + self.p2p_nccl_pipe.send_tensor(request_id + "keys", keys, + remote_address) + self.p2p_nccl_pipe.send_tensor(request_id + "values", values, + remote_address) + self.p2p_nccl_pipe.send_tensor( + request_id + "hidden", + hidden_or_intermediate_states[start_pos:end_pos], + remote_address) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. 
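+        # Start optimistically: assume every request's KV cache and hidden
+        # states can be fetched from the prefill side, and flip this flag as
+        # soon as any transfer turns out to be incomplete.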
+ bypass_model_exec = True + + model_config = model_executable.model.config + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to --max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + request_id = request_ids[idx] + ip, port = self.parse_request_id(request_id, False) + remote_address = ip + ":" + str(port + self.rank) + + keys = self.p2p_nccl_pipe.recv_tensor(request_id + "keys", + remote_address) + values = self.p2p_nccl_pipe.recv_tensor(request_id + "values", + remote_address) + hidden = self.p2p_nccl_pipe.recv_tensor(request_id + "hidden", + remote_address) + + num_computed_tokens = current_tokens.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), keys is not None, + values is not None, hidden is not None]): + bypass_model_exec = False + break + + # update the end position based on how many tokens are cached. 
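+            # end_pos is used below to slice slot_mapping when writing the
+            # received KV into the paged cache.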
+ end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys[ + i - model_executable.model.start_layer].to( + kv_cache.device).squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed, + k_pe, + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def parse_request_id(request_id: str, is_prefill=True) -> Tuple[str, int]: + logger.debug("parse_request_id, request_id: %s, is_prefill: %s", + request_id, is_prefill) + # Regular expression to match the string hostname and integer port + if is_prefill: + pattern = r"___decode_addr_(.*):(\d+)" + else: + pattern = r"___prefill_addr_(.*):(\d+)___" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + ip = match.group(1) + port = int(match.group(2)) + + logger.debug("parse_request_id, request_id: %s, ip: %s, port: %s", + request_id, ip, str(port)) + return ip, port + raise ValueError( + f"Request id {request_id} does not contain hostname and port") + + def close(self): + self.p2p_nccl_pipe.close() + diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py new file mode 100644 index 000000000..3ceec7936 --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py @@ -0,0 +1,458 @@ +# Mainly adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py. 
+# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +import os +import logging +import threading +import time +import typing +from collections import deque +from typing import Any, Deque, Dict, List, Optional + +import msgpack +import torch +import zmq +import ctypes +import sys +sys.path.append(os.getenv('FLAGCX_PATH')) +from plugin.interservice.flagcx_wrapper import ( + FLAGCXLibrary, + buffer_type, + cudaStream_t, + flagcxComm_t, + flagcxDataTypeEnum, +) +from vllm.config import KVTransferConfig +from vllm.utils import current_stream, get_ip + +logger = logging.getLogger(__name__) + + +class P2pNcclPipe: + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None) -> None: + self.config = config + self.rank = port_offset + self.local_rank = local_rank + self.device = torch.device(f"cuda:{self.local_rank}") + flagcx_path = os.getenv('FLAGCX_PATH') + library_path=os.path.join(flagcx_path, "build/lib/libflagcx.so") + self.flagcx = FLAGCXLibrary(library_path) + + if not hostname: + hostname = get_ip() + port = self.config.kv_port + port_offset + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + + # Each card corresponds to a ZMQ address. + self.zmq_address = f"{self._hostname}:{self._port}" + + # The `http_port` must be consistent with the port of OpenAI. + self.http_address = ( + f"{self._hostname}:" + f"{self.config.kv_connector_extra_config['http_port']}") + + # If `proxy_ip` or `proxy_port` is `""`, + # then the ping thread will not be enabled. + proxy_ip = self.config.get_from_extra_config("proxy_ip", "") + proxy_port = self.config.get_from_extra_config("proxy_port", "") + if proxy_ip == "" or proxy_port == "": + self.proxy_address = "" + else: + self.proxy_address = proxy_ip + ":" + proxy_port + + self.context = zmq.Context() + self.router_socket = self.context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://{self.zmq_address}") + + self.poller = zmq.Poller() + self.poller.register(self.router_socket, zmq.POLLIN) + + self.send_store_cv = threading.Condition() + self.send_queue_cv = threading.Condition() + self.recv_store_cv = threading.Condition() + self.comm_cv = threading.Condition() + + # The sending type includes tree mutually exclusive options: + # PUT, GET, PUT_ASYNC. 
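+        # PUT pushes each tensor synchronously from the calling thread,
+        # PUT_ASYNC pushes from a dedicated background thread, and GET keeps
+        # tensors in a local send_store until the peer requests them over ZMQ.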
+ self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + self.send_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + else: + # PUT or PUT_ASYNC + self.send_queue: Deque[ + List[Any]] = deque() # tensor_id: torch.Tensor + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + self.recv_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + self.socks: Dict[str, Any] = {} # remote_address: client socket + self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = self.config.kv_buffer_size + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + unique_id = self.flagcx.flagcxGetUniqueId().contents + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.info( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + 
remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + tensor = self.recv_store[tensor_id] + self.recv_store[tensor_id] = None + while len(self.recv_store) > 10000: + self.recv_store.pop(next(iter(self.recv_store))) + + duration = time.time() - start_time + if tensor is not None: + self.buffer_size -= (tensor.element_size() * tensor.numel()) + logger.info( + "🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, " + "duration:%.3fms, size:%.3fGB, rank:%d", remote_address, + tensor_id, tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, + self.rank) + else: + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + start_time = time.time() + self._recv(comm, tensor, rank ^ 1) + duration = time.time() - start_time + logger.info( + "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " + "size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape, + duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, self.rank) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + logger.debug("Received message from %s, data:%s", + remote_address.decode(), data) + if data["cmd"] == "NEW": + unique_id = self.flagcx.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + # comm: ncclComm_t = self.nccl.ncclCommInitRank( + # 2, unique_id, rank) + comm = self.flagcx.flagcxCommInitRank( + 2, ctypes.byref(unique_id), rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) + + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + self.router_socket.send_multipart( + [remote_address, b"2"]) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s", self.zmq_address, + remote_address.decode(), data) + tensor = None + else: + self.buffer_size += tensor_size + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1) + logger.info( + "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, " + "data:%s, shape:%s", self.zmq_address, + remote_address.decode(), rank, data, + tensor.shape) + + except 
torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + tensor = None + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + self._send(comm, tensor.to(self.device), rank ^ 1) + + logger.info( + "🔵[GET]Send Tensor, %s👉%s, " + "MyRank:%s, data:%s", self.zmq_address, + remote_address.decode(), rank, data) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + # Asynchronous sending may cause conflicts between P2P NCCL and + # NCCL used in TP/PP, which can lead to deadlock issues. + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() + self._send_sync(tensor_id, tensor, remote_address) + + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.info( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + # with self.send_queue_cv: + # self.send_queue.append([tensor_id, remote_address, tensor]) + # self.send_queue_cv.notify() + logger.warning( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1) + logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s", + self.zmq_address, remote_address, rank, data, tensor.shape) + return True + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + sock.send(msgpack.dumps(data)) + time.sleep(3) + + def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator 
is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with self.comm_cv: + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), dst, + comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with self.comm_cv: + flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + flagcxDataTypeEnum.from_torch(tensor.dtype), src, + comm, flagcx_stream) + self.flagcx.adaptor_stream_free(flagcx_stream) + + def close(self) -> None: + self._listener_thread.join() + if self.send_type == "PUT_ASYNC": + self._send_thread.join() + if self._ping_thread is not None: + self._ping_thread.join() \ No newline at end of file diff --git a/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py new file mode 100644 index 000000000..2451cd317 --- /dev/null +++ b/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py @@ -0,0 +1,445 @@ +# Copied adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py. +# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +import logging +import threading +import time +import typing +from collections import deque +from typing import Any, Deque, Dict, List, Optional + +import msgpack +import torch +import zmq + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) +from vllm.utils import current_stream, get_ip + +logger = logging.getLogger(__name__) + + +class P2pNcclPipe: + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None) -> None: + self.config = config + self.rank = port_offset + self.local_rank = local_rank + self.device = torch.device(f"cuda:{self.local_rank}") + self.nccl = NCCLLibrary(library_path) + + if not hostname: + hostname = get_ip() + port = self.config.kv_port + port_offset + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + + # Each card corresponds to a ZMQ address. + self.zmq_address = f"{self._hostname}:{self._port}" + + # The `http_port` must be consistent with the port of OpenAI. + self.http_address = ( + f"{self._hostname}:" + f"{self.config.kv_connector_extra_config['http_port']}") + + # If `proxy_ip` or `proxy_port` is `""`, + # then the ping thread will not be enabled. 
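+        # The proxy is only used by the ping thread to announce this
+        # instance's HTTP and ZMQ endpoints to the xPyD router; KV tensors
+        # themselves travel over point-to-point NCCL.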
+ proxy_ip = self.config.get_from_extra_config("proxy_ip", "") + proxy_port = self.config.get_from_extra_config("proxy_port", "") + if proxy_ip == "" or proxy_port == "": + self.proxy_address = "" + else: + self.proxy_address = proxy_ip + ":" + proxy_port + + self.context = zmq.Context() + self.router_socket = self.context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://{self.zmq_address}") + + self.poller = zmq.Poller() + self.poller.register(self.router_socket, zmq.POLLIN) + + self.send_store_cv = threading.Condition() + self.send_queue_cv = threading.Condition() + self.recv_store_cv = threading.Condition() + + self.send_stream = torch.cuda.Stream() + self.recv_stream = torch.cuda.Stream() + + # The sending type includes tree mutually exclusive options: + # PUT, GET, PUT_ASYNC. + self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + self.send_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + else: + # PUT or PUT_ASYNC + self.send_queue: Deque[ + List[Any]] = deque() # tensor_id: torch.Tensor + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + self.recv_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + self.socks: Dict[str, Any] = {} # remote_address: client socket + self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = self.config.kv_buffer_size + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + unique_id = self.nccl.ncclGetUniqueId() + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + 
oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.info( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + tensor = self.recv_store[tensor_id] + self.recv_store[tensor_id] = None + while len(self.recv_store) > 10000: + self.recv_store.pop(next(iter(self.recv_store))) + + duration = time.time() - start_time + if tensor is not None: + self.buffer_size -= (tensor.element_size() * tensor.numel()) + logger.info( + "🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, " + "duration:%.3fms, size:%.3fGB, rank:%d", remote_address, + tensor_id, tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, + self.rank) + else: + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + start_time = time.time() + self._recv(comm, tensor, rank ^ 1, self.recv_stream) + duration = time.time() - start_time + logger.info( + "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " + "size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape, + duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, self.rank) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + logger.debug("Received message from %s, data:%s", + remote_address.decode(), data) + if data["cmd"] == "NEW": + unique_id = self.nccl.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) 
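+                        # The buffer is allocated up front; the payload itself
+                        # arrives via ncclRecv only after the sender receives
+                        # the b"0" acknowledgement below.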
+ + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + self.router_socket.send_multipart( + [remote_address, b"2"]) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s", self.zmq_address, + remote_address.decode(), data) + tensor = None + else: + self.buffer_size += tensor_size + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1, + self.recv_stream) + logger.info( + "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, " + "data:%s, shape:%s", self.zmq_address, + remote_address.decode(), rank, data, + tensor.shape) + + except torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + tensor = None + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + comm, rank = self.comms[remote_address.decode()] + self._send(comm, tensor.to(self.device), rank ^ 1, + self.send_stream) + + logger.info( + "🔵[GET]Send Tensor, %s👉%s, " + "MyRank:%s, data:%s", self.zmq_address, + remote_address.decode(), rank, data) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() + self._send_sync(tensor_id, tensor, remote_address) + + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.info( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + # with self.send_queue_cv: + # self.send_queue.append([tensor_id, remote_address, tensor]) + # self.send_queue_cv.notify() + logger.warning( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) + logger.info("🔵Send Tensor, %s👉%s, 
MyRank:%s, data:%s, tensor:%s", + self.zmq_address, remote_address, rank, data, tensor.shape) + return True + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + sock.send(msgpack.dumps(data)) + time.sleep(3) + + def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with torch.cuda.stream(stream): + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + comm, cudaStream_t(stream.cuda_stream)) + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with torch.cuda.stream(stream): + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + comm, cudaStream_t(stream.cuda_stream)) + + def close(self) -> None: + self._listener_thread.join() + if self.send_type == "PUT_ASYNC": + self._send_thread.join() + if self._ping_thread is not None: + self._ping_thread.join() \ No newline at end of file diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py index 2d2308a41..8fce37cba 100644 --- a/flagscale/runner/runner_serve.py +++ b/flagscale/runner/runner_serve.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import copy import json import os import shlex @@ -12,8 +13,10 @@ from flagscale.runner.runner_base import JobStatus, RunnerBase from flagscale.runner.utils import ( + ResourceManager, benchmark, dummy_random_input, + flatten_dict_to_args, get_free_port, get_nproc_per_node, logger, @@ -22,6 +25,16 @@ ) +def _get_multiple_free_ports(num=1, exclude_ports=[]): + allocated_ports = [] + for i in range(num): + port = get_free_port() + while port in allocated_ports or port in exclude_ports: + port = get_free_port() + allocated_ports.append(port) + return allocated_ports + + def _get_args_vllm(config: DictConfig): # see the following link for more details # https://github.com/facebookresearch/hydra/discussions/2750 @@ -97,6 +110,8 @@ def _get_engine_args(config, model="vllm_model"): def _update_config_serve(config: DictConfig): + deploy_config = config.experiment.get("deploy", {}) + exp_dir = os.path.abspath(config.experiment.exp_dir) if not os.path.isdir(exp_dir): os.makedirs(exp_dir) @@ -104,6 +119,9 @@ def _update_config_serve(config: DictConfig): OmegaConf.set_struct(config, False) + if deploy_config.get("prefill_decode_disaggregation", False): + deploy_config["pd_proxy_port"] = get_free_port() + if config.get("logging", None) is None: config.logging = DictConfig({}) @@ -150,6 +168,8 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi vllm_path = os.path.dirname(vllm.__path__[0]) except Exception as e: vllm_path = f"{root_dir}/vllm" + deploy_config = config.experiment.get("deploy", {}) + envs = config.experiment.get("envs", {}) with 
open(host_run_script_file, "w") as f: f.write("#!/bin/bash\n\n") f.write("set -x\n") @@ -163,15 +183,191 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f' export PYTHONPATH="$PYTHONPATH:{vllm_path}:{root_dir}"\n') f.write(f"fi\n") f.write(f"\n") + envs_str = " && ".join(f"export {key}={value}" for key, value in envs.items()) if nodes: - f.write(f"ray_path=$(realpath $(which ray))\n") - master_ip = nodes[0][0] - target_port = nodes[0][1].get("port") + if deploy_config.get("prefill_decode_disaggregation", False): + resource_manager = ResourceManager(nodes) + master_ip = nodes[0][0] + target_port = nodes[0][1].get("port") + p_num = deploy_config.get("prefill_num", 1) + d_num = deploy_config.get("decode_num", 1) + ports_num = (p_num + d_num) * 2 + kv_related_ports = _get_multiple_free_ports(ports_num) + pd_proxy_port = deploy_config.get("pd_proxy_port", None) + if not pd_proxy_port: + raise ValueError(f"PD disaggregation requires a proxy port to be set.") + + engine_args = _get_engine_args(config) + command_items = ["vllm", "serve"] + command_items.append(engine_args["model"]) + other_args = flatten_dict_to_args(engine_args, ["model", "port"]) + command_items.extend(other_args) + vllm_command = " ".join(command_items) + if before_start_cmd: + vllm_command = f"{before_start_cmd} && " + vllm_command + if envs_str: + vllm_command = f"{envs_str} && " + vllm_command + p_address = deploy_config.get("prefill_address", "auto") + d_address = deploy_config.get("decode_address", "auto") + tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) + pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) + each_instance_card_num = tensor_parallel_size * pipeline_parallel_size + default_log_dir = deploy_config.get( + "prefill_decode_log_dir", logging_config.log_dir + ) + + f.write(f"# clean nodes \n") + if len(nodes) > 1: + for ip, node in nodes[1:]: + if not node.get("type", None): + raise ValueError( + f"Node type must be specified for node {node}. Available types are 'cpu', 'gpu', or a custom resource name." + ) + if not node.get("slots", None): + raise ValueError( + f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." 
+ ) + node_cmd = f"mkdir -p {default_log_dir} && pkill -f vllm" + + ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' + + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + f.write(f"{ssh_cmd}\n") + + f.write("pkill -f 'run_inference_engine'\n") + f.write("pkill -f 'run_fs_serve_vllm'\n") + f.write("pkill -f 'vllm serve'\n") + f.write("pkill -f 'run_disagg_xpyd_router'\n") + f.write(f"mkdir -p {default_log_dir}\n") + f.write(f"\n") + + f.write("echo '=========== launch prefill instance ==========='\n") + + for i in range(p_num): + kv_port = kv_related_ports.pop() + http_port = kv_related_ports.pop() + p_kv_config = { + "kv_connector": "P2pConnector", + "kv_role": "kv_producer", + "kv_port": str(kv_port), + "kv_connector_extra_config": { + "proxy_ip": master_ip, + "proxy_port": str(pd_proxy_port), + "http_port": str(http_port), + }, + } + logger.info( + f"============= prefill instance {i}, p_kv_config: {p_kv_config} =============" + ) + card_ids = resource_manager.get_available_card_ids( + address=p_address, num=each_instance_card_num + ) + card_ids_str = ",".join(map(str, card_ids)) + ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" + + p_kv_config_json = json.dumps(p_kv_config) + p_instance_log_path = os.path.join(default_log_dir, f"prefill_{i}.log") + + if p_address != master_ip: + p_kv_config_formate_json = p_kv_config_json.replace('"', '\\"') + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_formate_json}'\\''" + if docker_name: + ssh_cmd = f"ssh -f -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {p_instance_log_path} 2>&1 &'\"" + else: + ssh_cmd = f'ssh -f -n -p {ssh_port} {d_address} "{node_cmd} > {p_instance_log_path} 2>&1 &"' + f.write(f"{ssh_cmd}\n\n") + else: + p_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{p_kv_config_json}'\\''" + f.write(f"p_{i}_cmd='{p_cmd}'\n") + f.write(f"\n") + f.write( + f'nohup bash -c "$p_{i}_cmd; sync" >> {p_instance_log_path} 2>&1 &\n\n' + ) - f.write(f"# clean nodes \n") - if len(nodes) > 1: - for ip, node in nodes[1:]: + f.write("echo '=========== launch decode instance ==========='\n") + + for j in range(d_num): + kv_port = kv_related_ports.pop() + http_port = kv_related_ports.pop() + d_kv_config = { + "kv_connector": "P2pConnector", + "kv_role": "kv_consumer", + "kv_port": str(kv_port), + "kv_connector_extra_config": { + "proxy_ip": master_ip, + "proxy_port": str(pd_proxy_port), + "http_port": str(http_port), + }, + } + logger.info( + f"============= decode instance {i}, d_kv_config: {d_kv_config} =============" + ) + card_ids = resource_manager.get_available_card_ids( + address=d_address, num=each_instance_card_num + ) + card_ids_str = ",".join(map(str, card_ids)) + ids_env = f"export CUDA_VISIBLE_DEVICES={card_ids_str}" + + d_kv_config_json = json.dumps(d_kv_config) + d_instance_log_path = os.path.join(default_log_dir, f"decode_{j}.log") + + if d_address != master_ip: + d_kv_config_formate_json = d_kv_config_json.replace('"', '\\"') + node_cmd = f"{ids_env} && {vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_formate_json}'\\''" + if docker_name: + ssh_cmd = f"ssh -f -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd} > {d_instance_log_path} 2>&1 &'\"" + else: + ssh_cmd = f'ssh -f -n -p {ssh_port} {d_address} "{node_cmd} > {d_instance_log_path} 2>&1 &"' + f.write(f"{ssh_cmd}\n\n") + else: + d_cmd = f"{ids_env} && 
{vllm_command} --port {http_port} --kv-transfer-config '\\''{d_kv_config_json}'\\''" + f.write(f"d_{j}_cmd='{d_cmd}'\n") + f.write(f"\n") + f.write( + f'nohup bash -c "$d_{j}_cmd; sync" >> {d_instance_log_path} 2>&1 &\n\n' + ) + + else: + f.write(f"ray_path=$(realpath $(which ray))\n") + master_ip = nodes[0][0] + target_port = nodes[0][1].get("port") + + f.write(f"# clean nodes \n") + if len(nodes) > 1: + for ip, node in nodes[1:]: + if not node.get("type", None): + raise ValueError( + f"Node type must be specified for node {node}. Available types are 'cpu', 'gpu', or a custom resource name." + ) + if not node.get("slots", None): + raise ValueError( + f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." + ) + node_cmd = f"${{ray_path}} stop" + + if before_start_cmd: + node_cmd = f"{before_start_cmd} && " + node_cmd + + ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' + + if docker_name: + ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" + f.write(f"{ssh_cmd}\n") + if before_start_cmd: + f.write(f"{before_start_cmd} && ${{ray_path}} stop\n") + else: + f.write(f"${{ray_path}} stop\n") + f.write("pkill -f 'run_inference_engine'\n") + f.write("pkill -f 'run_fs_serve_vllm'\n") + f.write("pkill -f 'vllm serve'\n") + f.write(f"\n") + + master_port = target_port if target_port else get_free_port() + + address = f"{master_ip}:{master_port}" + for index, (ip, node) in enumerate(nodes): if not node.get("type", None): raise ValueError( f"Node type must be specified for node {node}. Available types are 'cpu', 'gpu', or a custom resource name." @@ -180,45 +376,21 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi raise ValueError( f"Number of slots must be specified for node {node}. This can be done by setting the 'slots' attribute." 
- ) - if index == 0: - # master node - f.write(f"# start cluster\n") - f.write(f"# master node\n") - if node.type == "gpu": - node_cmd = f"${{ray_path}} start --head --port={master_port} --num-gpus={node.slots}" - elif node.type == "cpu": - node_cmd = f"${{ray_path}} start --head --port={master_port} --num-cpus={node.slots}" else: resource = json.dumps({node.type: node.slots}).replace('"', '\\"') node_cmd = f"${{ray_path}} start --head --port={master_port} --resources='{resource}'" @@ -247,9 +419,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi ) if before_start_cmd: node_cmd = f"{before_start_cmd} && " + node_cmd - ssh_cmd = f'ssh -n -p {ssh_port} {ip} "{node_cmd}"' - if docker_name: ssh_cmd = f"ssh -n -p {ssh_port} {ip} \"docker exec {docker_name} /bin/bash -c '{node_cmd}'\"" f.write(f"{ssh_cmd}\n") @@ -267,7 +437,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f"nproc_per_node must be specified when device_type {device_type} is specified." ) node_cmd = None - deploy_config = config.experiment.get("deploy", {}) + if deploy_config.get("use_fs_serve", True) and config.serve[0].get("engine", None): f.write(f"ray_path=$(realpath $(which ray))\n") if not device_type: @@ -293,6 +463,7 @@ def _generate_run_script_serve(config, host, node_rank, cmd, background=True, wi f.write(f"\n") # TODO: need a option to control whether to append or overwrite the output file # Now, it always appends to the output file + f.write("echo '=========== launch task ==========='\n") if background: f.write( f'nohup bash -c "$cmd; sync" >> {host_output_file} 2>&1 & echo $! > {host_pid_file}\n' @@ -375,7 +546,9 @@ def _prepare(self): self.user_envs = self.config.experiment.get("envs", {}) entrypoint = self.config.experiment.task.get("entrypoint", None) if self.inference_engine: - if not self.use_fs_serve: + if self.config.experiment.get("deploy", {}).get("prefill_decode_disaggregation", False): + self.user_script = "flagscale/serve/run_disagg_xpyd_router.py" + elif not self.use_fs_serve: self.user_script = "flagscale/serve/run_inference_engine.py" else: self.user_script = "flagscale/serve/run_fs_serve_vllm.py" @@ -471,7 +644,6 @@ def _stop_each(self, host, node_rank): kill_process_tree(pid) ray_executable = shutil.which("ray") - print(ray_executable) if ray_executable: ray_path = os.path.realpath(ray_executable) os.system(f"{ray_path} stop") diff --git a/flagscale/runner/utils.py b/flagscale/runner/utils.py index b03dc8ffa..ae249e3e9 100644 --- a/flagscale/runner/utils.py +++ b/flagscale/runner/utils.py @@ -568,3 +568,128 @@ def process_one_metric( print("=" * 50) return result + + +class ResourceManager: + def __init__(self, nodes): + """ + Initialize the ResourceManager with a list of nodes. + Each element in the list should be a two-item list: + - The first item is the node address (a string). + - The second item is a dictionary containing at least the key "slots". + If "type" is not provided, it defaults to "gpu" with a warning. + The first node is treated as the master node, and the rest are worker nodes. + """ + self.nodes = self._initialize_nodes(nodes) + + def _initialize_nodes(self, nodes): + """ + Convert the input nodes list into the internal nodes representation. + Each node is converted into a dictionary with keys: + "address", "slots", "type", and "used" (initialized to 0). + If the "type" is not provided in a node, default it to "gpu" and issue a warning. 
+ """ + initialized_nodes = [] + for node in nodes: + if len(node) != 2: + raise ValueError("Each node must include an address and node data") + address, info = node + if "slots" not in info: + raise ValueError("Node data must contain 'slots'") + if "type" not in info: + logger.warning( + f"Node {address} does not provide a resource type. Defaulting to 'gpu'." + ) + resource_type = info.get("type", "gpu") + initialized_nodes.append( + { + "address": address, + "slots": info["slots"], + "type": resource_type, + "used": 0, # Initialize used slot count to 0 + } + ) + return initialized_nodes + + def get_whole_card_num(self, resource_type="gpu"): + """ + Return the total number of slots across all nodes with the specified resource type. + The return type is int. + """ + total = 0 + for node in self.nodes: + if node["type"] == resource_type: + total += node["slots"] + return total + + def get_available_card_num(self, resource_type="gpu"): + """ + Return the total number of available slots (slots minus used) across all nodes with the specified resource type. + The return type is int. + """ + total = 0 + for node in self.nodes: + if node["type"] == resource_type: + total += node["slots"] - node["used"] + return total + + def get_available_card_ids(self, resource_type="gpu", address="auto", num=1): + """ + Allocate 'num' resource cards from a node and return a list of card indices. + + For the default case (address="auto"), traverse nodes in order: master node first, then worker nodes. + - If a node's available slots (slots - used) are >= num, allocate num consecutive indices (based on the current used value) + and update the node's used count, returning the allocated indices (0-indexed) as a list. + - If the available slots are insufficient at a particular node and address is "auto", continue searching through other nodes. + - If an explicit address is provided, check only that node; if it doesn't exist or lacks sufficient available slots, raise an error. + - If none of the nodes can satisfy the request, raise an error indicating insufficient resources. + """ + # Check the specified node if address is not "auto" + if address != "auto": + node_found = None + for node in self.nodes: + if node["address"] == address and node["type"] == resource_type: + node_found = node + break + if node_found is None: + raise ValueError(f"Node {address} does not exist or resource type mismatch") + free = node_found["slots"] - node_found["used"] + if free < num: + raise ValueError("Insufficient resources") + allocated_ids = list(range(node_found["used"], node_found["used"] + num)) + node_found["used"] += num + return allocated_ids + + # For address == "auto", traverse all nodes (master node first, then worker nodes) + for node in self.nodes: + if node["type"] == resource_type: + free = node["slots"] - node["used"] + if free >= num: + allocated_ids = list(range(node["used"], node["used"] + num)) + node["used"] += num + return allocated_ids + + # If no node satisfies the allocation request, raise an error. + resource_status = self.get_status() + raise ValueError( + f"Require number {num} of resource_type {resource_type} But there is insufficient resources: \n{resource_status}" + ) + + def get_status(self): + """ + Return the status of all nodes as a dictionary. + Each key in the returned dictionary is the node's address, and its value is a dictionary with: + - type: the resource type. + - slots: the total number of slots. + - used: the number of allocated slots. 
+ - available: the number of available slots (slots - used). + """ + status = {} + for node in self.nodes: + status[node["address"]] = { + "type": node["type"], + "slots": node["slots"], + "used": node["used"], + "available": node["slots"] - node["used"], + } + return status diff --git a/flagscale/serve/run_disagg_xpyd_router.py b/flagscale/serve/run_disagg_xpyd_router.py new file mode 100644 index 000000000..3dd35d911 --- /dev/null +++ b/flagscale/serve/run_disagg_xpyd_router.py @@ -0,0 +1,239 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py. Below is the original copyright: +# +# SPDX-License-Identifier: Apache-2.0 +# + + +import os +import random +import socket +import threading +import uuid + +import aiohttp +import msgpack +import zmq + +from quart import Quart, make_response, request + +try: + import flag_scale +except Exception as e: + pass + +from flagscale import serve +from flagscale.logger import logger +from flagscale.utils import flatten_dict_to_args + +# Refer to https://github.com/vllm-project/vllm/pull/15806 + + +# ----------------------------------------------------------------------------- +# LoadManager: unified management of P/D instances and their load +# ----------------------------------------------------------------------------- +class LoadManager: + def __init__(self): + self._lock = threading.Lock() + # Each resource type 'P' or 'D' maps to {http_addr: {'zmq': zmq_addr, 'load': int}} + self._instances: dict[str, dict[str, dict[str, object]]] = {"P": {}, "D": {}} + + def register(self, rtype: str, http_addr: str, zmq_addr: str): + with self._lock: + if http_addr not in self._instances[rtype]: + self._instances[rtype][http_addr] = {"zmq": zmq_addr, "load": 0} + logger.info(f"Registered new {rtype}-instance {http_addr} (zmq={zmq_addr})") + else: + # If zmq address changed, synchronize it + self._instances[rtype][http_addr]["zmq"] = zmq_addr + + def increment_load(self, rtype: str, http_addr: str): + with self._lock: + self._instances[rtype][http_addr]["load"] += 1 + logger.debug( + f"[{rtype}] +1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}" + ) + + def decrement_load(self, rtype: str, http_addr: str): + with self._lock: + self._instances[rtype][http_addr]["load"] -= 1 + logger.debug( + f"[{rtype}] -1 load on {http_addr}, now={self._instances[rtype][http_addr]['load']}" + ) + + def get_random(self, rtype: str) -> tuple[str, str]: + with self._lock: + items = list(self._instances[rtype].items()) + http_addr, info = random.choice(items) + return http_addr, info["zmq"] + + def get_robin_loaded(self, rtype: str) -> tuple[str, str]: + with self._lock: + http_addr, info = min(self._instances[rtype].items(), key=lambda kv: kv[1]["load"]) + print(f"========== whole instance status {self._instances}==========", flush=True) + return http_addr, info["zmq"] + + +# ----------------------------------------------------------------------------- +# Globals & configuration +# ----------------------------------------------------------------------------- +lm = LoadManager() + +# Legacy registration dicts & Conditions retained for external waiting +prefill_instances: dict[str, str] = {} +decode_instances: dict[str, str] = {} +prefill_cv = threading.Condition() +decode_cv = threading.Condition() + +# Scheduling strategy: 'random' or 'robin' (robin load) +SCHEDULING_STRATEGY = 
os.environ.get("SCHEDULING_STRATEGY", "robin").lower() + + +# ----------------------------------------------------------------------------- +# Service discovery: receive instance registrations +# ----------------------------------------------------------------------------- +def _listen_for_register(poller, router_socket): + while True: + socks = dict(poller.poll()) + if router_socket in socks: + remote_addr, message = router_socket.recv_multipart() + data = msgpack.loads(message) + typ = data.get("type") + http_addr = data.get("http_address") + zmq_addr = data.get("zmq_address") + if typ == "P": + with prefill_cv: + prefill_instances[http_addr] = zmq_addr + lm.register("P", http_addr, zmq_addr) + elif typ == "D": + with decode_cv: + decode_instances[http_addr] = zmq_addr + lm.register("D", http_addr, zmq_addr) + else: + logger.warning(f"Unexpected registration message: {data}") + + +def start_service_discovery(hostname, port): + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + + context = zmq.Context() + router_socket = context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://{hostname}:{port}") + + poller = zmq.Poller() + poller.register(router_socket, zmq.POLLIN) + + listener = threading.Thread( + target=_listen_for_register, args=[poller, router_socket], daemon=True + ) + listener.start() + return listener + + +# ----------------------------------------------------------------------------- +# HTTP proxy & request forwarding +# ----------------------------------------------------------------------------- +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +app = Quart(__name__) + + +def random_uuid() -> str: + return uuid.uuid4().hex + + +async def forward_request(url, data, request_id): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + async with session.post(url=url, json=data, headers=headers) as resp: + if resp.status == 200: + async for chunk in resp.content.iter_chunked(1024): + yield chunk + else: + content = await resp.read() + yield content + + +# support both /v1/completions and /v1/chat/completions +@app.route("/v1/completions", methods=["POST"]) +@app.route("/v1/chat/completions", methods=["POST"]) +async def handle_request(): + try: + original_data = await request.get_json() + endpoint = request.path # this will be '/v1/completions' or '/v1/chat/completions' + + # Prefill request: max_tokens=1 + prefill_request = original_data.copy() + prefill_request["max_tokens"] = 1 + + # Select Prefill instance + if SCHEDULING_STRATEGY == "robin": + prefill_addr, prefill_zmq = lm.get_robin_loaded("P") + else: + prefill_addr, prefill_zmq = lm.get_random("P") + logger.info(f"Selected P-instance {prefill_addr} via '{SCHEDULING_STRATEGY}'") + + # Select Decode instance + if SCHEDULING_STRATEGY == "robin": + decode_addr, decode_zmq = lm.get_robin_loaded("D") + else: + decode_addr, decode_zmq = lm.get_random("D") + logger.info(f"Selected D-instance {decode_addr} via '{SCHEDULING_STRATEGY}'") + + # Keep original request_id composition format + request_id = f"___prefill_addr_{prefill_zmq}___decode_addr_{decode_zmq}_{random_uuid()}" + + # Execute Prefill and update load + lm.increment_load("P", prefill_addr) + try: + async for _ in forward_request( + f"http://{prefill_addr}{endpoint}", prefill_request, request_id + ): + pass + finally: + lm.decrement_load("P", prefill_addr) + + # Execute Decode 
and update load + async def tracked_decode(): + lm.increment_load("D", decode_addr) + try: + async for chunk in forward_request( + f"http://{decode_addr}{endpoint}", original_data, request_id + ): + yield chunk + finally: + lm.decrement_load("D", decode_addr) + + resp = await make_response(tracked_decode()) + resp.timeout = None + return resp + + except Exception as e: + logger.error("Error in proxy server", exc_info=e) + return {"error": str(e)}, 500 + + +def main(): + serve.load_args() + deploy_config = serve.task_config.experiment.get("deploy", {}) + serve_port = deploy_config.get("port", None) + # Port that prefill/decode instances use to register with the proxy's service discovery + pd_proxy_port = deploy_config.get("pd_proxy_port", None) + if not serve_port: + raise ValueError("No port specified in deploy config") + if not pd_proxy_port: + raise ValueError("No pd_proxy_port specified in deploy config") + logger.info(f"Starting proxy server with pd_proxy_port {pd_proxy_port} and serve_port {serve_port}") + listener = start_service_discovery("0.0.0.0", pd_proxy_port) + app.run(host="0.0.0.0", port=serve_port) + listener.join() + + +if __name__ == "__main__": + main()
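For reference, below is a minimal client-side sketch (not part of the patch) of how the disaggregated deployment is exercised once the generated run script has started the proxy from run_disagg_xpyd_router.py and the prefill/decode vLLM instances. The proxy listens on the port configured under experiment.task.deploy.port; for each request it selects one prefill instance (forwarding a copy of the request with max_tokens set to 1) and one decode instance (whose response is streamed back to the caller), while the P2pConnector moves the KV cache between them. The host, port, and model name in this sketch are placeholders, not values taken from the patch, and must match the actual deployment.

    # Hypothetical smoke test for the xPyD proxy. PROXY_HOST/PROXY_PORT must point at
    # the master node and the configured deploy.port; MODEL must match the model that
    # was passed to `vllm serve`.
    import os

    import requests

    PROXY_HOST = os.environ.get("PROXY_HOST", "127.0.0.1")
    PROXY_PORT = os.environ.get("PROXY_PORT", "8000")
    MODEL = os.environ.get("MODEL", "your-model-name")

    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 32,
        "stream": False,
    }

    # The proxy runs the prefill pass (max_tokens=1) on a P instance to produce the
    # KV cache, then the D instance generates and returns the full completion.
    resp = requests.post(
        f"http://{PROXY_HOST}:{PROXY_PORT}/v1/chat/completions", json=payload, timeout=600
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])

If the client instead sends "stream": true, the same endpoint should still work: the proxy simply relays the decode instance's response chunks as it receives them, so streaming clients are forwarded the stream unchanged.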