diff --git a/docs/design/v1/p2p_nccl_connector.md b/docs/design/v1/p2p_nccl_connector.md
index 32cdaacf058a..cbef2aea0784 100644
--- a/docs/design/v1/p2p_nccl_connector.md
+++ b/docs/design/v1/p2p_nccl_connector.md
@@ -8,7 +8,7 @@ As shown in Figure 1, the overall process of this **PD disaggregation** solution
 1. The client sends an HTTP request to the Proxy/Router's `/v1/completions` interface.
 2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** either through round-robin or random selection, generates a `request_id` (rules to be introduced later), modifies the `max_tokens` in the HTTP request message to **1**, and then forwards the request to the **P instance**.
 3. Immediately afterward, the Proxy/Router forwards the **original HTTP request** to the **D instance**.
-4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's `zmq_addr` can be resolved through the `request_id`.
+4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT** mode). The D instance's `zmq_addr` can be resolved through the `request_id`.
 5. The **D instance** has a **dedicated thread** for receiving the KV cache (to avoid blocking the main process). The received KV cache is saved into the **GPU memory buffer**, the size of which is determined by the vLLM startup parameter `kv_buffer_size`. When the GPU buffer is full, the KV cache is stored in the **local Tensor memory pool**.
 6. During the **Decode**, the D instance's main process retrieves the KV cache (transmitted by the P instance) from either the **GPU buffer** or the **memory pool**, thereby **skipping Prefill**.
 7. After completing **Decode**, the D instance returns the result to the **Proxy/Router**, which then forwards it to the **client**.
@@ -31,9 +31,9 @@ Each P/D instance periodically sends a heartbeat packet to the Proxy/Router (cur

 ## KV Cache Transfer Methods

-There are three methods for KVcache transfer: PUT, GET, and PUT_ASYNC. These methods can be specified using the `--kv-transfer-config` and `kv_connector_extra_config` parameters, specifically through the `send_type` field. Both PUT and PUT_ASYNC involve the P instance actively sending KVcache to the D instance. The difference is that PUT is a synchronous transfer method that blocks the main process, while PUT_ASYNC is an asynchronous transfer method. PUT_ASYNC uses a dedicated thread for sending KVcache, which means it does not block the main process. In contrast, the GET method involves the P instance saving the KVcache to the memory buffer after computing the prefill. The D instance then actively retrieves the computed KVcache from the P instance once it has allocated space for the KVcache.
+There are two methods for KVcache transfer: PUT and GET. These methods can be specified using the `--kv-transfer-config` and `kv_connector_extra_config` parameters, specifically through the `send_type` field. PUT involves the P instance actively sending KVcache to the D instance; it is an asynchronous transfer method that uses a dedicated thread for sending KVcache, so it does not block the main process. In contrast, the GET method involves the P instance saving the KVcache to the memory buffer after computing the prefill. The D instance then actively retrieves the computed KVcache from the P instance once it has allocated space for the KVcache.
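+
+For example, to run a P instance in GET mode instead of the default PUT, set the `send_type` field (a minimal sketch; the addresses and ports are placeholders reused from the examples below, and GET mode needs a real `kv_buffer_size` on the P instance):
+
+```shell
+    --kv-transfer-config \
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"8e9","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20005","send_type":"GET","nccl_num_channels":"16"}}'
+```
+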
-Experimental results have shown that the performance of these methods, from highest to lowest, is as follows: PUT_ASYNC → GET → PUT.
+Experimental results have shown that the performance of these methods, from highest to lowest, is as follows: PUT → GET.

 ## P2P Communication via ZMQ & NCCL
@@ -53,7 +53,7 @@ Each NCCL group occupies a certain amount of GPU memory buffer for communication

 ## GPU Memory Buffer and Tensor Memory Pool

-The trade-off in the size of the memory buffer is as follows: For P instances, the memory buffer is not required in PUT and PUT_ASYNC modes, but it is necessary in GET mode. For D instances, a memory buffer is needed in all three modes. The memory buffer for D instances should not be too large. Similarly, for P instances in GET mode, the memory buffer should also not be too large. The memory buffer of D instances is used to temporarily store KVcache sent by P instances. If it is too large, it will reduce the KVcache space available for normal inference by D instances, thereby decreasing the inference batch size and ultimately leading to a reduction in output throughput. The size of the memory buffer is configured by the parameter `kv_buffer_size`, measured in bytes, and is typically set to 5%~10% of the memory size.
+The trade-off in the size of the memory buffer is as follows: For P instances, the memory buffer is not required in PUT mode, but it is necessary in GET mode. For D instances, a memory buffer is needed in both modes. The memory buffer for D instances should not be too large. Similarly, for P instances in GET mode, the memory buffer should also not be too large. The memory buffer of D instances is used to temporarily store KVcache sent by P instances. If it is too large, it will reduce the KVcache space available for normal inference by D instances, thereby decreasing the inference batch size and ultimately leading to a reduction in output throughput. The size of the memory buffer is configured by the parameter `kv_buffer_size`, measured in bytes, and is typically set to 5%~10% of the memory size.
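+
+For example, on an 80 GB GPU (an illustrative figure), 10% comes to roughly 8 GB, i.e. `"kv_buffer_size":"8e9"`, which is the value used for the D instances in the examples below.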

 If the `--max-num-seqs` parameter for P instances is set to a large value, due to the large batch size, P instances will generate a large amount of KVcache simultaneously. This may exceed the capacity of the memory buffer of D instances, resulting in KVcache loss. Once KVcache is lost, D instances need to recompute Prefill, which is equivalent to performing Prefill twice. Consequently, the time-to-first-token (TTFT) will significantly increase, leading to degraded performance.
@@ -68,7 +68,7 @@ To address the above issues, I have designed and developed a local Tensor memory
     cd /home

     # Download the installation package, and I will update the commit-id in time. You can directly copy the command.
-    wget https://vllm-wheels.s3.us-west-2.amazonaws.com/9112b443a042d8d815880b8780633882ad32b183/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
+    wget https://vllm-wheels.s3.us-west-2.amazonaws.com/0d06b533a0fcca7a62603c868df68235659d6935/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl

     # Download the code repository.
     git clone -b xpyd-v1 https://github.com/Abatom/vllm.git
@@ -88,9 +88,9 @@
 - Pay attention to the setting of the `kv_buffer_size` (in bytes). The empirical value is 10% of the GPU memory size. This is related to the kvcache size. If it is too small, the GPU memory buffer for temporarily storing the received kvcache will overflow, causing the kvcache to be stored in the tensor memory pool, which increases latency. If it is too large, the kvcache available for inference will be reduced, leading to a smaller batch size and decreased throughput.
 - For Prefill instances, when using non-GET mode, the `kv_buffer_size` can be set to 1, as Prefill currently does not need to receive kvcache. However, when using GET mode, a larger `kv_buffer_size` is required because it needs to store the kvcache sent to the D instance.
 - You may need to modify the `kv_buffer_size` and `port` in the following commands (if there is a conflict).
-- `PUT_ASYNC` offers the best performance and should be prioritized.
+- `PUT` offers the best performance and should be prioritized.
 - The `--port` must be consistent with the `http_port` in the `--kv-transfer-config`.
-- The `disagg_prefill_proxy_xpyd.py` script will use port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances).
+- The `disagg_prefill_proxy_xpyd.py` script will use port 10101 (for receiving client requests) and port 30201 (for receiving service discovery from P and D instances).
 - The node running the proxy must have `quart` installed.
 - Supports multiple nodes; you just need to modify the `proxy_ip` and `proxy_port` in `--kv-transfer-config`.
 - In the following examples, it is assumed that **the proxy's IP is 10.0.1.1**.
@@ -123,7 +123,7 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.9 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20005","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 ### Decode1 (e.g. 10.0.1.3 or 10.0.1.1)
@@ -145,7 +145,7 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.7 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20009","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 ### Decode2 (e.g. 10.0.1.4 or 10.0.1.1)
@@ -167,7 +167,7 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.7 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20003","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 ### Decode3 (e.g. 10.0.1.5 or 10.0.1.1)
@@ -189,7 +189,7 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.7 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20008","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 ## Run 3P1D
@@ -220,7 +220,7 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.9 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20005","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 ### Prefill2 (e.g. 10.0.1.3 or 10.0.1.1)
@@ -242,7 +242,7 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.9 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20009","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 ### Prefill3 (e.g. 10.0.1.4 or 10.0.1.1)
@@ -264,7 +264,7 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.9 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20003","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 ### Decode1 (e.g. 10.0.1.5 or 10.0.1.1)
@@ -286,13 +286,13 @@ python3 disagg_prefill_proxy_xpyd.py &
     --gpu-memory-utilization 0.7 \
     --disable-log-request \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30201","http_port":"20008","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
 ```

 # Single request

 ```shell
-curl -X POST -s http://10.0.1.1:10001/v1/completions \
+curl -X POST -s http://10.0.1.1:10101/v1/completions \
     -H "Content-Type: application/json" \
     -d '{
         "model": "base_model",
@@ -313,7 +313,7 @@ curl -X POST -s http://10.0.1.1:10001/v1/completions \
     --tokenizer meta-llama/Llama-3.1-8B-Instruct \
     --dataset-name "random" \
     --host 10.0.1.1 \
-    --port 10001 \
+    --port 10101 \
     --random-input-len 1024 \
     --random-output-len 1024 \
     --ignore-eos \
diff --git a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
index 73f2caaa0dbd..7f1b5e4f6f8f 100644
--- a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
+++ b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
@@ -3,7 +3,9 @@
 import os
 import socket
 import threading
+import time
 import uuid
+from typing import Any

 import aiohttp
 import msgpack
@@ -11,12 +13,25 @@
 from quart import Quart, make_response, request

 count = 0
-prefill_instances: dict[str, str] = {}  # http_address: zmq_address
-decode_instances: dict[str, str] = {}  # http_address: zmq_address
+prefill_instances: dict[str, Any] = {}  # http_address: (zmq_address, stamp)
+decode_instances: dict[str, Any] = {}  # http_address: (zmq_address, stamp)

 prefill_cv = threading.Condition()
 decode_cv = threading.Condition()

+DEFAULT_PING_SECONDS = 5
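+
+
+# A registration/heartbeat refreshes an instance's expiry stamp by
+# DEFAULT_PING_SECONDS; entries whose stamp has already passed are treated as
+# dead and are evicted by _remove_oldest_instances() below.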
+def _remove_oldest_instances(instances: dict[str, Any]) -> None:
+    oldest_key = next(iter(instances), None)
+    while oldest_key is not None:
+        value = instances[oldest_key]
+        if value[1] > time.time():
+            break
+        print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
+        instances.pop(oldest_key, None)
+        oldest_key = next(iter(instances), None)
+

 def _listen_for_register(poller, router_socket):
     while True:
@@ -30,12 +45,23 @@ def _listen_for_register(poller, router_socket):
             global prefill_instances
             global prefill_cv
             with prefill_cv:
-                prefill_instances[data["http_address"]] = data["zmq_address"]
+                node = prefill_instances.pop(data["http_address"], None)
+                prefill_instances[data["http_address"]] = (
+                    data["zmq_address"],
+                    time.time() + DEFAULT_PING_SECONDS,
+                )
+                _remove_oldest_instances(prefill_instances)
+
         elif data["type"] == "D":
             global decode_instances
             global decode_cv
             with decode_cv:
-                decode_instances[data["http_address"]] = data["zmq_address"]
+                node = decode_instances.pop(data["http_address"], None)
+                decode_instances[data["http_address"]] = (
+                    data["zmq_address"],
+                    time.time() + DEFAULT_PING_SECONDS,
+                )
+                _remove_oldest_instances(decode_instances)
         else:
             print(
                 "Unexpected, Received message from %s, data: %s",
@@ -43,6 +69,9 @@ def _listen_for_register(poller, router_socket):
                 data,
             )

+        if node is None:
+            print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
+

 def start_service_discovery(hostname, port):
     if not hostname:
@@ -104,12 +133,14 @@ async def handle_request():
         with prefill_cv:
             prefill_list = list(prefill_instances.items())
             prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
+            prefill_zmq_addr = prefill_zmq_addr[0]

         global decode_instances
         global decode_cv
         with decode_cv:
             decode_list = list(decode_instances.items())
             decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
+            decode_zmq_addr = decode_zmq_addr[0]

         print(
             f"handle_request count: {count}, [HTTP:{prefill_addr}, "
@@ -149,6 +180,6 @@ async def handle_request():

 if __name__ == "__main__":
-    t = start_service_discovery("0.0.0.0", 30001)
-    app.run(host="0.0.0.0", port=10001)
+    t = start_service_discovery("0.0.0.0", 30201)
+    app.run(host="0.0.0.0", port=10101)
     t.join()
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index 2f870971ded7..818d2274b9b1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -89,7 +89,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.p2p_nccl_engine = P2pNcclEngine(
             local_rank=self._local_rank,
-            config=self.config,
+            vllm_config=vllm_config,
             hostname="",
             port_offset=self._rank,
         ) if role == KVConnectorRole.WORKER else None
@@ -202,7 +202,7 @@ def inject_kv_into_layer(
             if kv_cache is None:
                 logger.warning("🚧src_kv_cache is None, %s",
                                request.request_id)
-                continue
+                break

             inject_kv_into_layer(kv_cache_layer, kv_cache,
                                  request.slot_mapping, request.request_id)
@@ -265,9 +265,8 @@ def extract_kv_from_layer(
                 kv_cache, remote_address)

     def wait_for_save(self):
-        if self.is_producer:
-            assert self.p2p_nccl_engine is not None
-            self.p2p_nccl_engine.wait_for_sent()
+        """P2pNcclConnector does not save explicitly."""
+        return

     def get_finished(
         self, finished_req_ids: set[str],
@@ -317,10 +316,10 @@ def get_num_new_matched_tokens(
         num_external_tokens = (len(request.prompt_token_ids) - 1 -
                                num_computed_tokens)

-        if num_external_tokens < 0:
-            num_external_tokens = 0
+        if num_external_tokens <= 0:
+            return 0, False

-        return num_external_tokens, False
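+        # True tells the scheduler that the KV cache for these tokens is
+        # loaded asynchronously from the P instance; the request is held
+        # until the transfer is reported finished via get_finished().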
""" - if not self.is_producer and num_external_tokens > 0: + if not self.is_producer and num_external_tokens == 0: self._requests_need_load[request.request_id] = ( request, blocks.get_block_ids()[0]) @@ -355,6 +354,10 @@ def build_connector_meta( # the request's prompt is chunked prefill if num_tokens < len(new_req.prompt_token_ids): # 'CachedRequestData' has no attribute 'prompt_token_ids' + logger.info( + "🚧%s is chunked prefill, num_tokens:%d, num_prompt:%d", + new_req.req_id, num_tokens, + len(new_req.prompt_token_ids)) self.chunked_prefill[new_req.req_id] = ( new_req.block_ids[0], new_req.prompt_token_ids) continue @@ -388,6 +391,10 @@ def build_connector_meta( prompt_token_ids = self.chunked_prefill[req_id][1] # the request's prompt is chunked prefill again if num_tokens < len(prompt_token_ids): + logger.info( + "🚧%s is chunked prefill again, num_tokens:%d, " + "num_prompt:%d", req_id, num_tokens, + len(prompt_token_ids)) self.chunked_prefill[req_id] = (block_ids, prompt_token_ids) continue @@ -408,6 +415,8 @@ def build_connector_meta( total_tokens = num_computed_tokens + 1 token_ids = request.all_token_ids[:total_tokens] + logger.info("🚧%s is resumed from preemption, total_tokens:%d", + req_id, total_tokens) # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. block_ids = new_block_ids[0] @@ -417,14 +426,6 @@ def build_connector_meta( block_ids=block_ids, block_size=self._block_size) - # Requests loaded asynchronously are not in the scheduler_output. - # for request_id in self._requests_need_load: - # request, block_ids = self._requests_need_load[request_id] - # meta.add_request(request_id=request.request_id, - # token_ids=request.prompt_token_ids, - # block_ids=block_ids, - # block_size=self._block_size) - self._requests_need_load.clear() return meta @@ -446,7 +447,7 @@ def request_finished( self.chunked_prefill.pop(request.request_id, None) - return False, None + return self.is_producer, None # ============================== # Static methods diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 35c26897fe3f..d59121b9211b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -13,7 +13,6 @@ import torch import zmq -from vllm.config import KVTransferConfig from vllm.distributed.device_communicators.pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 @@ -21,6 +20,7 @@ from vllm.utils import current_stream, get_ip if TYPE_CHECKING: + from vllm.config import VllmConfig from vllm.forward_context import ForwardContext logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def set_p2p_nccl_context(num_channels: str): for var in env_vars: original_values[var] = os.environ.get(var) - logger.info("set_p2p_nccl_context, original_values: %s", original_values) + logger.debug("set_p2p_nccl_context, original_values: %s", original_values) try: os.environ['NCCL_MAX_NCHANNELS'] = num_channels @@ -62,11 +62,12 @@ class P2pNcclEngine: def __init__(self, local_rank: int, - config: KVTransferConfig, + vllm_config: "VllmConfig", hostname: str = "", port_offset: int = 0, library_path: Optional[str] = None) -> None: - self.config = config + self.config = vllm_config.kv_transfer_config + self.compilation_config = 
+        return self.is_producer, None

     # ==============================
     # Static methods
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
index 35c26897fe3f..d59121b9211b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
@@ -13,7 +13,6 @@
 import torch
 import zmq

-from vllm.config import KVTransferConfig
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import (  # noqa: E501
@@ -21,6 +20,7 @@
 from vllm.utils import current_stream, get_ip

 if TYPE_CHECKING:
+    from vllm.config import VllmConfig
     from vllm.forward_context import ForwardContext

 logger = logging.getLogger(__name__)
@@ -43,7 +43,7 @@ def set_p2p_nccl_context(num_channels: str):
     for var in env_vars:
         original_values[var] = os.environ.get(var)

-    logger.info("set_p2p_nccl_context, original_values: %s", original_values)
+    logger.debug("set_p2p_nccl_context, original_values: %s", original_values)

     try:
         os.environ['NCCL_MAX_NCHANNELS'] = num_channels
@@ -62,11 +62,12 @@ class P2pNcclEngine:

     def __init__(self,
                  local_rank: int,
-                 config: KVTransferConfig,
+                 vllm_config: "VllmConfig",
                  hostname: str = "",
                  port_offset: int = 0,
                  library_path: Optional[str] = None) -> None:
-        self.config = config
+        self.config = vllm_config.kv_transfer_config
+        self.compilation_config = vllm_config.compilation_config
         self.rank = port_offset
         self.local_rank = local_rank
         self.device = torch.device(f"cuda:{self.local_rank}")
@@ -77,15 +78,15 @@ def __init__(self,
         port = int(self.config.kv_port) + port_offset
         if port == 0:
             raise ValueError("Port cannot be 0")
-        self._hostname = hostname
-        self._port = port
+        self.hostname = hostname
+        self.port = port

         # Each card corresponds to a ZMQ address.
-        self.zmq_address = f"{self._hostname}:{self._port}"
+        self.zmq_address = f"{self.hostname}:{self.port}"

         # The `http_port` must be consistent with the port of OpenAI.
         self.http_address = (
-            f"{self._hostname}:"
+            f"{self.hostname}:"
             f"{self.config.kv_connector_extra_config['http_port']}")

         # If `proxy_ip` or `proxy_port` is `""`,
@@ -117,20 +118,19 @@ def __init__(self,
                                       1024**3)  # GB

-        # The sending type includes three mutually exclusive options:
-        # PUT, GET, PUT_ASYNC.
+        # The sending type includes two mutually exclusive options:
+        # PUT, GET.
         self.send_type = self.config.get_from_extra_config("send_type", "PUT")
         if self.send_type == "GET":
             # tensor_id: torch.Tensor
             self.send_store: dict[str, torch.Tensor] = {}
         else:
-            # PUT or PUT_ASYNC
+            # PUT
             # tensor_id: torch.Tensor
             self.send_queue: deque[list[Any]] = deque()
             self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
-            if self.send_type == "PUT_ASYNC":
-                self._send_thread = threading.Thread(target=self._send_async,
-                                                     daemon=True)
-                self._send_thread.start()
+            self.send_thread = threading.Thread(target=self.send_async,
+                                                daemon=True)
+            self.send_thread.start()

         # tensor_id: torch.Tensor/(addr, dtype, shape)
         self.recv_store: dict[str, Any] = {}
@@ -144,15 +144,18 @@ def __init__(self,
         self.nccl_num_channels = self.config.get_from_extra_config(
             "nccl_num_channels", "8")

-        self._listener_thread = threading.Thread(
-            target=self._listen_for_requests, daemon=True)
-        self._listener_thread.start()
+        self.listener_thread = threading.Thread(
+            target=self.listen_for_requests, daemon=True)
+        self.listener_thread.start()

-        self._ping_thread = None
+        self.ping_thread = None
         if port_offset == 0 and self.proxy_address != "":
-            self._ping_thread = threading.Thread(target=self._ping,
-                                                 daemon=True)
-            self._ping_thread.start()
+            self.ping_thread = threading.Thread(target=self.ping, daemon=True)
+            self.ping_thread.start()
+
+        self.num_layers = 0
+        self.finished_recving: set[str] = set()
+        self.finished_sending: set[str] = set()

         logger.info(
             "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
@@ -161,7 +164,7 @@ def __init__(self,
             self.http_address, self.zmq_address, self.proxy_address,
             self.send_type, self.buffer_size_threshold, self.nccl_num_channels)

-    def _create_connect(self, remote_address: typing.Optional[str] = None):
+    def create_connect(self, remote_address: typing.Optional[str] = None):
         assert remote_address is not None
         if remote_address not in self.socks:
             sock = self.context.socket(zmq.DEALER)
@@ -183,7 +186,7 @@ def _create_connect(self, remote_address: typing.Optional[str] = None):
                     comm: ncclComm_t = self.nccl.ncclCommInitRank(
                         2, unique_id, rank)
                     self.comms[remote_address] = (comm, rank)
-                    logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
+                    logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
                                 self.zmq_address, remote_address, rank)

         return self.socks[remote_address], self.comms[remote_address]
@@ -199,37 +202,37 @@ def send_tensor(
                 self.recv_store[tensor_id] = tensor
                 self.recv_store_cv.notify()
             return True
-        else:
-            if self.send_type == "PUT":
-                return self._send_sync(tensor_id, tensor, remote_address)
-            elif self.send_type == "PUT_ASYNC":
-                with self.send_queue_cv:
-                    self.send_queue.append([tensor_id, remote_address, tensor])
-                    self.send_queue_cv.notify()
-            else:  # GET
-                with self.send_store_cv:
-                    tensor_size = tensor.element_size() * tensor.numel()
-                    while (self.buffer_size + tensor_size
-                           > self.buffer_size_threshold):
-                        oldest_tenser_id = next(iter(self.send_store))
-                        oldest_tenser = self.send_store.pop(oldest_tenser_id)
-                        oldest_tenser_size = oldest_tenser.element_size(
-                        ) * oldest_tenser.numel()
-                        self.buffer_size -= oldest_tenser_size
-                        logger.info(
-                            "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
-                            " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
-                            remote_address, tensor_id, tensor_size,
-                            self.buffer_size, oldest_tenser_size, self.rank)
-
-                    self.send_store[tensor_id] = tensor
-                    self.buffer_size += tensor_size
-                    logger.debug(
-                        "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
-                        "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
-                        remote_address, tensor_id, tensor_size, tensor.shape,
-                        self.rank, self.buffer_size,
-                        self.buffer_size / self.buffer_size_threshold * 100)
+
+        if self.send_type == "PUT":
+            with self.send_queue_cv:
+                self.send_queue.append([tensor_id, remote_address, tensor])
+                self.send_queue_cv.notify()
+            return True
+
+        # GET
+        with self.send_store_cv:
+            tensor_size = tensor.element_size() * tensor.numel()
+            while (self.buffer_size + tensor_size
+                   > self.buffer_size_threshold):
+                oldest_tensor_id = next(iter(self.send_store))
+                oldest_tensor = self.send_store.pop(oldest_tensor_id)
+                oldest_tensor_size = oldest_tensor.element_size(
+                ) * oldest_tensor.numel()
+                self.buffer_size -= oldest_tensor_size
+                logger.info(
+                    "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
+                    " buffer_size:%d, oldest_tensor_size:%d, rank:%d",
+                    remote_address, tensor_id, tensor_size, self.buffer_size,
+                    oldest_tensor_size, self.rank)
+
+            self.send_store[tensor_id] = tensor
+            self.buffer_size += tensor_size
+
+        logger.debug(
+            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
+            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
+            tensor_id, tensor_size, tensor.shape, self.rank, self.buffer_size,
+            self.buffer_size / self.buffer_size_threshold * 100)

         return True
@@ -238,12 +241,13 @@ def recv_tensor(
         self,
         tensor_id: str,
         remote_address: typing.Optional[str] = None,
     ) -> torch.Tensor:
-        if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
-            start_time = time.time()
+        if self.send_type == "PUT":
             with self.recv_store_cv:
-                while tensor_id not in self.recv_store:
-                    self.recv_store_cv.wait()
-                tensor = self.recv_store[tensor_id]
+                if tensor_id not in self.recv_store:
+                    logger.warning(
+                        "🔴[PUT]Recv From %s, tensor_id:%s not exist, rank:%d",
+                        remote_address, tensor_id, self.rank)
+                tensor = self.recv_store.get(tensor_id)

             if tensor is not None:
                 if isinstance(tensor, tuple):
@@ -253,12 +257,11 @@ def recv_tensor(
                 else:
                     self.buffer_size -= (tensor.element_size() *
                                          tensor.numel())
+                logger.debug("🔵[PUT]Recv From %s, tensor_id:%s, rank:%d",
+                             remote_address, tensor_id, self.rank)
             else:
-                duration = time.time() - start_time
-                logger.warning(
-                    "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
-                    "rank:%d", remote_address, tensor_id, duration * 1000,
-                    self.rank)
+                logger.warning("🔴[PUT]Recv From %s, tensor_id:%s, rank:%d",
+                               remote_address, tensor_id, self.rank)

             return tensor

         # GET
@@ -266,7 +269,7 @@ def recv_tensor(
             return None

         if remote_address not in self.socks:
-            self._create_connect(remote_address)
+            self.create_connect(remote_address)

         sock = self.socks[remote_address]
         comm, rank = self.comms[remote_address]
@@ -281,134 +284,134 @@ def recv_tensor(
                         remote_address, tensor_id, data["ret"])
             return None

-        tensor = torch.empty(data["shape"],
-                             dtype=getattr(torch, data["dtype"]),
-                             device=self.device)
-
-        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
+        with torch.cuda.stream(self.recv_stream):
+            tensor = torch.empty(data["shape"],
+                                 dtype=getattr(torch, data["dtype"]),
+                                 device=self.device)

-        return tensor
+        return self.recv(comm, tensor, rank ^ 1, self.recv_stream)

-    def _listen_for_requests(self):
+    def listen_for_requests(self):
         while True:
             socks = dict(self.poller.poll())
-            if self.router_socket in socks:
-                remote_address, message = self.router_socket.recv_multipart()
-                data = msgpack.loads(message)
-                if data["cmd"] == "NEW":
-                    unique_id = self.nccl.unique_id_from_bytes(
-                        bytes(data["unique_id"]))
-                    with torch.cuda.device(self.device):
-                        rank = 1
-                        with set_p2p_nccl_context(self.nccl_num_channels):
-                            comm: ncclComm_t = self.nccl.ncclCommInitRank(
-                                2, unique_id, rank)
-                        self.comms[remote_address.decode()] = (comm, rank)
-                        logger.info(
-                            "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
-                            self.zmq_address, remote_address.decode(), rank)
-                elif data["cmd"] == "PUT":
-                    tensor_id = data["tensor_id"]
-                    try:
-                        with torch.cuda.stream(self.recv_stream):
-                            tensor = torch.empty(data["shape"],
-                                                 dtype=getattr(
-                                                     torch, data["dtype"]),
-                                                 device=self.device)
-                        self.router_socket.send_multipart(
-                            [remote_address, b"0"])
-                        comm, rank = self.comms[remote_address.decode()]
-                        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
-                        tensor_size = tensor.element_size() * tensor.numel()
-                        if (self.buffer_size + tensor_size
-                                > self.buffer_size_threshold):
-                            # Store Tensor in memory pool
-                            addr = self.pool.store_tensor(tensor)
-                            tensor = (addr, tensor.dtype, tensor.shape)
-                            logger.warning(
-                                "🔴[PUT]Recv Tensor, Out Of Threshold, "
-                                "%s👈%s, data:%s, addr:%d", self.zmq_address,
-                                remote_address.decode(), data, addr)
-                        else:
-                            self.buffer_size += tensor_size
-
-                    except torch.cuda.OutOfMemoryError:
-                        self.router_socket.send_multipart(
-                            [remote_address, b"1"])
-                        tensor = None
+            if self.router_socket not in socks:
+                continue
+
+            remote_address, message = self.router_socket.recv_multipart()
+            data = msgpack.loads(message)
+            remote = remote_address.decode()
+            if data["cmd"] == "NEW":
+                unique_id = self.nccl.unique_id_from_bytes(
+                    bytes(data["unique_id"]))
+                with torch.cuda.device(self.device):
+                    rank = 1
+                    with set_p2p_nccl_context(self.nccl_num_channels):
+                        comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                            2, unique_id, rank)
+                    self.comms[remote] = (comm, rank)
+                    logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
+                                self.zmq_address, remote, rank)
+            elif data["cmd"] == "PUT":
+                tensor_id = data["tensor_id"]
+                try:
+                    with torch.cuda.stream(self.recv_stream):
+                        tensor = torch.empty(data["shape"],
+                                             dtype=getattr(
+                                                 torch, data["dtype"]),
+                                             device=self.device)
+                    self.router_socket.send_multipart([remote_address, b"0"])
+                    comm, rank = self.comms[remote]
+                    self.recv(comm, tensor, rank ^ 1, self.recv_stream)
+                    tensor_size = tensor.element_size() * tensor.numel()
+                    if (self.buffer_size + tensor_size
+                            > self.buffer_size_threshold):
+                        # Store Tensor in memory pool
+                        addr = self.pool.store_tensor(tensor)
+                        tensor = (addr, tensor.dtype, tensor.shape)
                         logger.warning(
-                            "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
-                            "data:%s", self.zmq_address,
-                            remote_address.decode(), data)
-
-                    with self.recv_store_cv:
-                        self.recv_store[tensor_id] = tensor
-                        self._have_received_tensor_id(tensor_id)
-                        self.recv_store_cv.notify()
-
-                elif data["cmd"] == "GET":
-                    tensor_id = data["tensor_id"]
-                    with self.send_store_cv:
-                        tensor = self.send_store.pop(tensor_id, None)
-                        if tensor is not None:
-                            data = {
-                                "ret": 0,
-                                "shape": tensor.shape,
-                                "dtype":
-                                str(tensor.dtype).replace("torch.", "")
-                            }
-                            # LRU
-                            self.send_store[tensor_id] = tensor
-                            self._have_sent_tensor_id(tensor_id)
-                        else:
-                            data = {"ret": 1}
-
-                    self.router_socket.send_multipart(
-                        [remote_address, msgpack.dumps(data)])
-
-                    if data["ret"] == 0:
-                        comm, rank = self.comms[remote_address.decode()]
-                        self._send(comm, tensor.to(self.device), rank ^ 1,
-                                   self.send_stream)
-                else:
+                            "🔴[PUT]Recv Tensor, Out Of Threshold, "
+                            "%s👈%s, data:%s, addr:%d", self.zmq_address,
+                            remote, data, addr)
+                    else:
+                        self.buffer_size += tensor_size
+                except torch.cuda.OutOfMemoryError:
+                    self.router_socket.send_multipart([remote_address, b"1"])
+                    tensor = None
                     logger.warning(
-                        "🚧Unexpected, Received message from %s, data:%s",
-                        remote_address, data)
+                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
+                        "data:%s", self.zmq_address, remote, data)
+
+                with self.recv_store_cv:
+                    self.recv_store[tensor_id] = tensor
+                    self.have_received_tensor_id(tensor_id)
+                    self.recv_store_cv.notify()
+
+                logger.debug(
+                    "🔵[PUT]Recv Tensor, %s👈%s, is_success:%s, data:%s",
+                    self.zmq_address, remote, tensor is not None, data)
+
+            elif data["cmd"] == "GET":
+                tensor_id = data["tensor_id"]
+                with self.send_store_cv:
+                    tensor = self.send_store.pop(tensor_id, None)
+                    if tensor is not None:
+                        data = {
+                            "ret": 0,
+                            "shape": tensor.shape,
+                            "dtype": str(tensor.dtype).replace("torch.", "")
+                        }
+                        # LRU
+                        self.send_store[tensor_id] = tensor
+                        self.have_sent_tensor_id(tensor_id)
+                    else:
+                        data = {"ret": 1}
+
+                self.router_socket.send_multipart(
+                    [remote_address, msgpack.dumps(data)])
+
+                if data["ret"] == 0:
+                    comm, rank = self.comms[remote]
+                    self.send(comm, tensor.to(self.device), rank ^ 1,
+                              self.send_stream)
+            else:
+                logger.warning(
+                    "🚧Unexpected, Received message from %s, data:%s",
+                    remote_address, data)

-    def _have_sent_tensor_id(self, tensor_id: str):
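+    # One KV tensor is transferred per model layer, so a request counts as
+    # fully sent/received once it has accumulated one tensor_id per layer.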
+    def get_num_layers(self):
+        if self.num_layers == 0:
+            self.num_layers = len(
+                self.compilation_config.static_forward_context)
+            logger.debug("get_num_layers, num_layers:%d", self.num_layers)
+        return self.num_layers
+
+    def have_sent_tensor_id(self, tensor_id: str):
         request_id = tensor_id.split('#')[0]
         if request_id not in self.send_request_id_to_tensor_ids:
             self.send_request_id_to_tensor_ids[request_id] = set()
         self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
+        if self.get_num_layers() == len(
+                self.send_request_id_to_tensor_ids[request_id]):
+            self.finished_sending.add(request_id)

-    def _have_received_tensor_id(self, tensor_id: str):
+    def have_received_tensor_id(self, tensor_id: str):
         request_id = tensor_id.split('#')[0]
         if request_id not in self.recv_request_id_to_tensor_ids:
             self.recv_request_id_to_tensor_ids[request_id] = set()
         self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
+        if self.get_num_layers() == len(
+                self.recv_request_id_to_tensor_ids[request_id]):
+            self.finished_recving.add(request_id)

-    def _send_async(self):
+    def send_async(self):
         while True:
             with self.send_queue_cv:
                 while not self.send_queue:
                     self.send_queue_cv.wait()
                 tensor_id, remote_address, tensor = self.send_queue.popleft()
-                if not self.send_queue:
-                    self.send_queue_cv.notify()
-            self._send_sync(tensor_id, tensor, remote_address)
-
-    def wait_for_sent(self):
-        if self.send_type == "PUT_ASYNC":
-            start_time = time.time()
-            with self.send_queue_cv:
-                while self.send_queue:
-                    self.send_queue_cv.wait()
-            duration = time.time() - start_time
-            logger.debug(
-                "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
-                " to be empty, rank:%d", duration * 1000, self.rank)
+            self.send_sync(tensor_id, tensor, remote_address)

-    def _send_sync(
+    def send_sync(
         self,
         tensor_id: str,
         tensor: torch.Tensor,
@@ -417,7 +420,7 @@ def _send_sync(
         if remote_address is None:
             return False
         if remote_address not in self.socks:
-            self._create_connect(remote_address)
+            self.create_connect(remote_address)

         sock = self.socks[remote_address]
         comm, rank = self.comms[remote_address]
@@ -439,10 +442,12 @@ def _send_sync(
                            response.decode())
             return False

-        self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
+        self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
+
+        self.have_sent_tensor_id(tensor_id)

-        if self.send_type == "PUT_ASYNC":
-            self._have_sent_tensor_id(tensor_id)
+        logger.debug("🔵[PUT]Send Tensor, %s👉%s, data:%s", self.zmq_address,
+                     remote_address, data)

         return True
@@ -471,20 +476,21 @@ def get_finished(
                                 request_id, None)
                             self.recv_request_id_to_tensor_ids.pop(
                                 request_id, None)
-                            addr = 0
                             if isinstance(tensor, tuple):
                                 addr, _, _ = tensor
                                 self.pool.free(addr)
-
-        # TODO:Retrieve requests that have already sent the KV cache.
-        finished_sending: set[str] = set()
-
-        # TODO:Retrieve requests that have already received the KV cache.
-        finished_recving: set[str] = set()
-
+                            logger.debug("🔵get_finished, request_id:%s",
+                                         request_id)
+
+        # Retrieve requests that have already sent the KV cache.
+        finished_sending = self.finished_sending.copy()
+        # Retrieve requests that have already received the KV cache.
+        finished_recving = self.finished_recving.copy()
+        self.finished_sending.clear()
+        self.finished_recving.clear()
+        # TODO: Add failed requests (e.g., transmission errors)
         return finished_sending or None, finished_recving or None

-    def _ping(self):
+    def ping(self):
         sock = self.context.socket(zmq.DEALER)
         sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
         logger.debug("ping start, zmq_address:%s", self.zmq_address)
@@ -498,35 +504,31 @@ def _ping(self):
             sock.send(msgpack.dumps(data))
             time.sleep(3)

-    def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
-        assert tensor.device == self.device, (
-            f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
-        if stream is None:
-            stream = current_stream()
-
-        with torch.cuda.stream(stream):
-            self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
-                               ncclDataTypeEnum.from_torch(tensor.dtype), dst,
-                               comm, cudaStream_t(stream.cuda_stream))
-        stream.synchronize()
-
-    def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
+    def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
-        if stream is None:
-            stream = current_stream()
-
+        stream = stream if stream is not None else current_stream()
+        event = torch.cuda.Event()
+        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
+                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
+                           comm, cudaStream_t(stream.cuda_stream))
+        event.record(stream)
+        event.synchronize()
+
+    def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
+        stream = stream if stream is not None else current_stream()
+        event = torch.cuda.Event()
         with torch.cuda.stream(stream):
             self.nccl.ncclRecv(buffer_type(tensor.data_ptr()),
                                tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                comm, cudaStream_t(stream.cuda_stream))
-        stream.synchronize()
+        event.record(stream)
+        event.synchronize()
+        return tensor

     def close(self) -> None:
-        self._listener_thread.join()
-        if self.send_type == "PUT_ASYNC":
-            self._send_thread.join()
-        if self._ping_thread is not None:
-            self._ping_thread.join()
+        self.listener_thread.join()
+        self.send_thread.join()
+        if self.ping_thread is not None:
+            self.ping_thread.join()
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 20a40d74f311..a6c379ea3ca7 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -923,6 +923,7 @@ def finish_requests(
         # First pass: collect requests to remove from queues
         for req_id in request_ids:
+            self.finished_recving_kv_req_ids.discard(req_id)
             request = self.requests.get(req_id)
             if request is None:
                 # Invalid request ID.
@@ -1062,8 +1063,6 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         # Update the request state for scheduling.
         request.num_computed_tokens = num_computed_tokens

-        # Return that we are ready.
-        self.finished_recving_kv_req_ids.remove(request.request_id)
         return True

     def _update_from_kv_xfer_finished(self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5bdaf4b969e7..bdb45151191d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1496,6 +1496,16 @@ def execute_model(
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
+
+        # Eliminate global synchronization in `cudaMemcpyAsync`.
+        gpu_event = torch.cuda.Event()
+        gpu_event.record()
+        while not gpu_event.query():
+            # It can achieve a precision of around 50 microseconds.
+            # sched_yield can achieve a precision of around 1.25 microseconds.
+            # However, this can lead to very high CPU utilization.
+            time.sleep(0)
+
         if max_gen_len == 1:
             # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()