From 42a832ced63ff95ba32d900170fbfdf63b14bd79 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 09:54:07 +0000 Subject: [PATCH 1/7] split kv_cache along head dim address race condition optimize req state checking loop release_xfer_handle on status DONE send notif to agent_name fix req abort Signed-off-by: nicklucche --- .../nixl_integration/run_accuracy_test.sh | 11 +- .../kv_connector/v1/nixl_connector.py | 319 +++++++++++++----- vllm/v1/core/sched/scheduler.py | 5 + 3 files changed, 240 insertions(+), 95 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index c17784e0a263..b48655d80eef 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -8,7 +8,9 @@ MODELS=( # Number of prefill and decode instances to create NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 -NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2 +NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 +PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1} +DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} # Find the git repository root directory GIT_ROOT=$(git rev-parse --show-toplevel) @@ -74,9 +76,10 @@ run_tests_for_model() { for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do # Calculate GPU ID - we'll distribute across available GPUs GPU_ID=$((i % $(get_num_gpus))) + # Calculate port number (base port + instance number) PORT=$((8100 + i)) - # Calculate side channel port + # Calculate side channel port. Avoid clash with with TP workers. SIDE_CHANNEL_PORT=$((5559 + i)) echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" @@ -87,6 +90,7 @@ run_tests_for_model() { --enforce-eager \ --disable-log-requests \ --gpu-memory-utilization 0.2 \ + --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" if [ -n "$model_args" ]; then @@ -109,7 +113,7 @@ run_tests_for_model() { # Calculate port number (base port + instance number) PORT=$((8200 + i)) # Calculate side channel port - SIDE_CHANNEL_PORT=$((5659 + i)) + SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE)) echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" @@ -119,6 +123,7 @@ run_tests_for_model() { --enforce-eager \ --disable-log-requests \ --gpu-memory-utilization 0.2 \ + --tensor-parallel-size $DECODER_TP_SIZE \ --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" if [ -n "$model_args" ]; then diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fd22280126d6..ee36cbd1f33a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -54,6 +54,8 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int + tp_size: int + block_len: int @dataclass @@ -331,10 +333,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.info("Initializing NIXL wrapper") logger.info("Initializing NIXL worker %s", engine_id) + # Config. + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) - # Map of engine_id -> agent_name. - self._remote_agents: dict[str, str] = {} + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
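+        # Populated by add_remote_agent() for each remote rank handshaked
+        # with (rank 0 plus the rank this worker pulls from).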
+ self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) # NIXL handshake port. # NOTE(rob): Within a DP group, each DP rank gets its own @@ -354,27 +360,31 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} - # Map of engine_id -> kv_caches_base_addr - self.kv_caches_base_addr: dict[str, list[int]] = {} + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[str, list[int]] = dict() # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 self.num_layers = 0 - # nixl_prepped_dlist_handle (int). + # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[str, int] = {} + self.dst_xfer_side_handles: dict[str, int] = dict() - # Map of engine_id -> num_blocks. - self.dst_num_blocks: dict[str, int] = {} + # Map of engine_id -> num_blocks. All ranks in the same deployment will + # have the same number of blocks. + self.dst_num_blocks: dict[str, int] = dict() self._registered_descs: list[Any] = [] # In progress transfers. # [req_id -> list[handle]] - self._recving_transfers: defaultdict[str, list[Any]] = defaultdict( - list[Any]) + self._recving_transfers: defaultdict[str, + list[tuple[int, + float]]] = defaultdict( + list[Any]) # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. @@ -395,6 +405,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # List of block window sizes for each layer for local attention self.block_window_per_layer: list[Optional[int]] = [] + self._tp_size: dict[str, int] = {self.engine_id: self.world_size} + # With heterogeneous TP, P must wait for all assigned D TP workers to + # finish reading before safely freeing the blocks. + self.consumer_notification_counts_by_req: dict[str, + int] = defaultdict(int) + @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, ready_event: threading.Event, base_port: int, @@ -426,27 +442,44 @@ def _nixl_handshake(self, host: str, port: int): """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() - # NOTE(rob): we need each tp_rank to have a unique port. - # This is a hack to keep us moving. We will switch when - # we switch to HTTP-based NIXL metadata exchange. - path = make_zmq_path("tcp", host, port + self.tp_rank) - logger.debug("Querying metadata on path: %s", path) - with zmq_ctx(zmq.REQ, path) as sock: - # Send query for the request. - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - # Register Remote agent. - self.add_remote_agent(metadata) - setup_agent_time = time.perf_counter() + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + def handshake(path: str, rank: int) -> NixlAgentMetadata: + # Send query for the request. 
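+            # The remote listener replies with its msgpack-encoded
+            # NixlAgentMetadata, decoded below.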
+ with zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. + self.add_remote_agent(metadata, rank) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + return metadata + + # Handshake with remote agent-rank0 first to get the tp_size of remote + path = f"tcp://{host}:{port}" + logger.debug("Querying master rank metadata on path: %s", path) + metadata = handshake(path, 0) + + # Handshake only with the other TP remote the current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. + tp_rate = self._tp_size[self.engine_id] // metadata.tp_size + p_remote_rank = self.tp_rank // tp_rate + if p_remote_rank > 0: + path = f"tcp://{host}:{port + p_remote_rank}" + logger.debug("Querying metadata on path: %s at remote rank %s", + path, p_remote_rank) + _ = handshake(path, p_remote_rank) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -461,14 +494,20 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.num_blocks = first_kv_cache.shape[0] block_rank = 2 # [block_size, latent_dim] block_shape = first_kv_cache.shape[-block_rank:] + block_size, kv_latent_dim = block_shape + self.slot_size_bytes = kv_elem_size * kv_latent_dim else: - # [2 (k and v), num_blocks, ...] + # [2 (k and v), num_blocks, block_size, kv_heads, head_dim] self.num_blocks = first_kv_cache.shape[1] block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - + block_size, n_kv_heads, head_dim = block_shape + # head size in bytes. + self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim + assert block_size == self.block_size # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc + # block size in bytes self.block_len = kv_elem_size * math.prod(block_shape) logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, @@ -524,16 +563,39 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs) logger.debug("Done registering descs") - self._registered_descs.append(descs) + # Register local/src descr for NIXL xfer. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + for slot_idx in range(self.block_size): + slot_offset = slot_idx * self.slot_size_bytes + addr = base_addr + block_offset + slot_offset + # (addr, len, device id) + blocks_data.append((addr, self.slot_size_bytes, self.tp_rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.tp_rank) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + # NIXL_INIT_AGENT to be used for preparations of local descs. 
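+        # This single handle is reused as the local side of every read
+        # issued in _read_blocks().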
+ self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + # After KV Caches registered, listen for new connections. metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, - ) + tp_size=self.world_size, + block_len=self.block_len) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, @@ -543,50 +605,108 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._nixl_handshake_listener_t.start() ready_event.wait() - def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + def add_remote_agent(self, + nixl_agent_meta: NixlAgentMetadata, + remote_rank: int = 0): + """ + Add the remote NIXL agent and prepare the descriptors for reading cache + blocks from remote. + + In particular, handle both homogeneous and heterogeneous TP. The latter + requires local rank_i to read from remote rank_i. + The former, assuming D.world_size > P.world_size, requires that two or + more local TP worker share the xfer from a single TP worker. + + Here's an example: + + rank_offset p_remote_rank + (kv split no) + -------------------------------- + 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] + / + 1 0 Worker1 ---- 2nd half of KV -----/ + + 0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ] + / + 1 1 Worker3 ---- 2nd half of KV -----/ + + + Decoder TP workers Prefix TP workers + (world_size=4) (world_size=2) + tp_ratio = 4 // 2 = 2 + + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, block_size, kv_heads, head_dim] + then D-Worker_j has [2, num_blocksD, block_size, kv_heads//tp_ratio, head_dim]. + Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio + first heads from all the slots of all the blocks in the case. + D-Worker1 will do the same, but reading the second split along the kv_heads dimension. + + Note that the above will also hold true for the homogeneous TP case. + """ # noqa: E501 + engine_id = nixl_agent_meta.engine_id - assert engine_id != self.engine_id, "Conflict engine id found!" - if engine_id in self._remote_agents: + # TODO re-evaluate refreshing for scaling/recovery + if (engine_id in self._remote_agents and \ + remote_rank in self._remote_agents[engine_id]): return - self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) - self.kv_caches_base_addr[ - engine_id] = nixl_agent_meta.kv_caches_base_addr - - # Create src descs and xfer side handles. - blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.tp_rank)) - logger.debug("Created %s blocks for src engine %s and tp_rank %s", - len(blocks_data), self.engine_id, self.tp_rank) + if engine_id in self._tp_size: + assert self._tp_size[engine_id] == nixl_agent_meta.tp_size + self._tp_size[engine_id] = nixl_agent_meta.tp_size + self._remote_agents[engine_id][ + remote_rank] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. 
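+        # E.g. D with TP=4 pulling from P with TP=2 gives tp_ratio=2: local
+        # ranks 0-1 read (and split) remote rank 0, ranks 2-3 remote rank 1.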
+ tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] + assert tp_ratio > 0, "Decode TP cannot be smaller than" + " prefill TP" + + # TODO we should also check hidden_dim and kv precision, they must match + remote_block_size = nixl_agent_meta.block_len / (self.slot_size_bytes * + tp_ratio) + assert self.block_size == remote_block_size, "Remote P worker with " + "different block size is not supported" + + # Create dst descs and xfer side handles. TP workers have same #blocks. + if engine_id in self.dst_num_blocks: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) - - # Create dst descs and xfer side handles. self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + blocks_data = [] - for base_addr in self.kv_caches_base_addr[engine_id]: - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.tp_rank)) - logger.debug("Created %s blocks for dst engine %s and tp_rank %s", - len(blocks_data), engine_id, self.tp_rank) + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + p_remote_rank = self.tp_rank // tp_ratio + # Only register the remote's descriptors if current rank pulls from it. + if p_remote_rank == remote_rank: + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + rank_offset = self.tp_rank % tp_ratio * self.slot_size_bytes + # Register all remote blocks, but only the corresponding kv heads. + for base_addr in nixl_agent_meta.kv_caches_base_addr: + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_len + for slot_idx in range(self.block_size): + # Remote has `tp_ratio` times the kv_heads of local. + slot_offset = slot_idx * self.slot_size_bytes * tp_ratio + addr = base_addr + block_offset + slot_offset + # (addr, len, device id) + blocks_data.append((addr + rank_offset, + self.slot_size_bytes, remote_rank)) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and " \ + "local rank %s", + len(blocks_data), engine_id, remote_rank, self.tp_rank) - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[ - engine_id] = self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id], descs) + # Register with NIXL. 
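+        # Keyed by engine_id only: each local rank pulls from exactly one
+        # remote rank, whose agent was registered above.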
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id][remote_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -657,13 +777,19 @@ def _get_new_notifs(self) -> set[str]: """Get req_ids which got a remote xfer message.""" notified_req_ids: set[str] = set() - for req_ids in self.nixl_wrapper.get_new_notifs().values(): - for req_id in req_ids: - assert req_id not in notified_req_ids - notified_req_ids.add(req_id.decode("utf-8")) + for notifs in self.nixl_wrapper.get_new_notifs().values(): + for notif in notifs: + req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) + self.consumer_notification_counts_by_req[req_id] += 1 + # Wait all consumers (D) to be done reading before freeing. + if self.consumer_notification_counts_by_req[req_id] == int( + tp_ratio): + notified_req_ids.add(req_id) + del self.consumer_notification_counts_by_req[req_id] return notified_req_ids - def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: + def _pop_done_transfers( + self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]: """ Pop completed xfers by checking for DONE state. Args: @@ -673,23 +799,17 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: """ done_req_ids: set[str] = set() for req_id, handles in list(transfers.items()): - running_reqs = [] - for handle in handles: + for handle, xfer_stime in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": - # TODO ptarasiewicz: why abort is throwing errors? - # self.nixl_wrapper.release_xfer_handle(handle) + self.nixl_wrapper.release_xfer_handle(handle) + done_req_ids.add(req_id) + del transfers[req_id] + elif xfer_state == "PROC": continue - if xfer_state == "PROC": - running_reqs.append(handle) else: raise RuntimeError("Transfer failed with state %s", xfer_state) - if len(running_reqs) == 0: - done_req_ids.add(req_id) - del transfers[req_id] - else: - transfers[req_id] = running_reqs return done_req_ids def start_load_kv(self, metadata: NixlConnectorMetadata): @@ -735,13 +855,19 @@ def _read_blocks( # saturate IB with heterogeneous TP sizes. We should remove the staging # blocks until we are ready. + # Number of D TP workers that will read from dst P. Propagate tp_ratio + # on notification so that dst worker can wait before freeing blocks. + tp_ratio = self._tp_size[ + self.engine_id] // self._tp_size[dst_engine_id] + notif_id = f"{request_id}:{tp_ratio}".encode() + # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need. num_local_blocks = len(local_block_ids) if num_local_blocks == 0: - agent_name = self._remote_agents[dst_engine_id] - self.nixl_wrapper.send_notif(agent_name, - notif_msg=request_id.encode("utf-8")) + remote_rank = self.tp_rank // tp_ratio + agent_name = self._remote_agents[dst_engine_id][remote_rank] + self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) return # Partial prefix cache hit: just read uncomputed blocks. @@ -754,6 +880,10 @@ def _read_blocks( local_xfer_side_handle = self.src_xfer_side_handle remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogeneous TP, fixing D>P, the D tp + # workers will issue xfers to parts of the P worker remote kv caches. + # Get descs ids. 
local_block_descs_ids: list[int] = [] remote_block_descs_ids: list[int] = [] @@ -797,14 +927,16 @@ def _read_blocks( local_block_descs_ids, remote_xfer_side_handle, remote_block_descs_ids, - notif_msg=request_id.encode("utf-8"), + notif_msg=notif_id, ) # Begin async xfer. self.nixl_wrapper.transfer(handle) # Use handle to check completion in future step(). - self._recving_transfers[request_id].append(handle) + # TODO (NickLucche) surface xfer elapsed time + self._recving_transfers[request_id].append( + (handle, time.perf_counter())) def _get_block_descs_ids(self, engine_id: str, @@ -815,6 +947,7 @@ def _get_block_descs_ids(self, If layer_idx is provided, we use the region_ids for the given layer. Otherwise, we use all regions. """ + # TODO TP docs if layer_idx is None: region_ids = range(self.num_regions) @@ -837,7 +970,9 @@ def _get_block_descs_ids(self, descs_ids: list[int] = [] for reg_id in region_ids: for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) + for slot_id in range(self.block_size): + descs_ids.append(reg_id * num_blocks * self.block_size + + block_id * self.block_size + slot_id) return descs_ids diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 32d03b311a4e..bb21aca5fc0e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -986,6 +986,11 @@ def _connector_finished( return False, None assert len(self.kv_cache_config.kv_cache_groups ) == 1, "KV connector only supports one KV cache group now" + if (request.status == RequestStatus.FINISHED_ABORTED and \ + request.request_id not in + self.kv_cache_manager.single_type_manager.req_to_blocks): + return False, None + block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] return self.connector.request_finished(request, block_ids) From ddf969d67814ddc096ffe36b73b1b9471ce1213e Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 27 May 2025 14:58:13 +0000 Subject: [PATCH 2/7] transpose kv cache for faster xfers docs Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 59 ++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 1 + 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ee36cbd1f33a..0a492e841553 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -489,6 +489,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # TODO(tms): Find a more robust way to detect and handle MLA use_mla = len(first_kv_cache.shape) == 3 + # FIXME Actual memory layout is NOT the one you expect from + # dims, needs docs + assert not first_kv_cache.is_contiguous() if use_mla: # MLA case. self.num_blocks = first_kv_cache.shape[0] @@ -575,11 +578,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # local descr, and that makes handling regular flow less clean. 
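+            # One descriptor per block (block_len bytes); the matching
+            # remote descriptors in add_remote_agent() select the per-rank
+            # heads chunk within each remote block.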
for block_id in range(self.num_blocks): block_offset = block_id * self.block_len - for slot_idx in range(self.block_size): - slot_offset = slot_idx * self.slot_size_bytes - addr = base_addr + block_offset + slot_offset - # (addr, len, device id) - blocks_data.append((addr, self.slot_size_bytes, self.tp_rank)) + addr = base_addr + block_offset + # (addr, len, device id) + blocks_data.append((addr, self.block_len, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) @@ -635,15 +636,14 @@ def add_remote_agent(self, (world_size=4) (world_size=2) tp_ratio = 4 // 2 = 2 - Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, block_size, kv_heads, head_dim] - then D-Worker_j has [2, num_blocksD, block_size, kv_heads//tp_ratio, head_dim]. + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] + then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio - first heads from all the slots of all the blocks in the case. - D-Worker1 will do the same, but reading the second split along the kv_heads dimension. + first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split + along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. - Note that the above will also hold true for the homogeneous TP case. + Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. """ # noqa: E501 - engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if (engine_id in self._remote_agents and \ @@ -663,11 +663,14 @@ def add_remote_agent(self, assert tp_ratio > 0, "Decode TP cannot be smaller than" " prefill TP" - # TODO we should also check hidden_dim and kv precision, they must match remote_block_size = nixl_agent_meta.block_len / (self.slot_size_bytes * tp_ratio) - assert self.block_size == remote_block_size, "Remote P worker with " - "different block size is not supported" + assert self.block_size == remote_block_size, "Remote P worker with \ + different block size is not supported" + + assert nixl_agent_meta.block_len == self.block_len * tp_ratio, "Remote\ + P worker KV layer cache must be of shape \ + [2, N, local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -685,18 +688,17 @@ def add_remote_agent(self, if p_remote_rank == remote_rank: self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.tp_rank % tp_ratio * self.slot_size_bytes + rank_offset = self.rank % tp_ratio * self.block_len # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_len - for slot_idx in range(self.block_size): - # Remote has `tp_ratio` times the kv_heads of local. 
- slot_offset = slot_idx * self.slot_size_bytes * tp_ratio - addr = base_addr + block_offset + slot_offset - # (addr, len, device id) - blocks_data.append((addr + rank_offset, - self.slot_size_bytes, remote_rank)) + # For each block, grab the heads chunk belonging to rank_i + # of size remote_nheads // tp_ratio, which correspond to + # self.block_len == remote_block_len//tp_ratio bytes. + addr = base_addr + block_offset + rank_offset + # (addr, len, device id) + blocks_data.append((addr, self.block_len, remote_rank)) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and " \ "local rank %s", @@ -774,8 +776,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: return done_sending, done_recving def _get_new_notifs(self) -> set[str]: - """Get req_ids which got a remote xfer message.""" - + """ + Get req_ids which got a remote xfer message. When multiple consumers + are reading from the same producer (heterogeneous TP scenario), wait + for all consumers to be done pulling. + """ notified_req_ids: set[str] = set() for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: @@ -947,8 +952,6 @@ def _get_block_descs_ids(self, If layer_idx is provided, we use the region_ids for the given layer. Otherwise, we use all regions. """ - # TODO TP docs - if layer_idx is None: region_ids = range(self.num_regions) else: @@ -970,9 +973,7 @@ def _get_block_descs_ids(self, descs_ids: list[int] = [] for reg_id in region_ids: for block_id in block_ids: - for slot_id in range(self.block_size): - descs_ids.append(reg_id * num_blocks * self.block_size + - block_id * self.block_size + slot_id) + descs_ids.append(reg_id * num_blocks + block_id) return descs_ids diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4a67e37781bf..d784d835ae9f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2119,6 +2119,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + # kv_cache_shape=(16, 8, 128)=>( 8, 16, 128) dtype = kv_cache_spec.dtype try: kv_cache_stride_order = self.attn_backends[ From 4120a427cd952370356990810a5699d19bc87023 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Thu, 29 May 2025 12:46:13 +0000 Subject: [PATCH 3/7] postpone req abort change Signed-off-by: nicklucche --- vllm/v1/core/sched/scheduler.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index bb21aca5fc0e..32d03b311a4e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -986,11 +986,6 @@ def _connector_finished( return False, None assert len(self.kv_cache_config.kv_cache_groups ) == 1, "KV connector only supports one KV cache group now" - if (request.status == RequestStatus.FINISHED_ABORTED and \ - request.request_id not in - self.kv_cache_manager.single_type_manager.req_to_blocks): - return False, None - block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] return self.connector.request_finished(request, block_ids) From b84b6e279384e02b365c5b90a53697d2501b1139 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Thu, 29 May 2025 16:02:46 +0000 Subject: [PATCH 4/7] working MLA Signed-off-by: nicklucche --- .../nixl_integration/test_accuracy.py | 1 + .../kv_connector/v1/nixl_connector.py | 41 ++++++++++++------- 
vllm/v1/worker/gpu_model_runner.py | 1 - 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index 2b2b147ce3e1..e5d66ffeeeb2 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -14,6 +14,7 @@ # Model-specific expected values EXPECTED_VALUES = { "Qwen/Qwen3-0.6B": 0.41, + "deepseek-ai/deepseek-vl2-small": 0.59 } SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0a492e841553..60ed8c0bf79c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -488,11 +488,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_elem_size = first_kv_cache.element_size() # TODO(tms): Find a more robust way to detect and handle MLA - use_mla = len(first_kv_cache.shape) == 3 - # FIXME Actual memory layout is NOT the one you expect from - # dims, needs docs - assert not first_kv_cache.is_contiguous() - if use_mla: + self.use_mla = len(first_kv_cache.shape) == 3 + # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected + # KV memory layout is HND, as opposed to the default NHD. Note that it + # will only affects the strides. For MLA instead, we make require no + # such thing and resort to the standard layout. + if self.use_mla: # MLA case. self.num_blocks = first_kv_cache.shape[0] block_rank = 2 # [block_size, latent_dim] @@ -513,8 +514,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # block size in bytes self.block_len = kv_elem_size * math.prod(block_shape) - logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, - first_kv_cache.shape) + logger.debug("Registering KV_Caches. use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) @@ -531,7 +532,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla else cache_or_caches + cache_list = [cache_or_caches] if self.use_mla else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len @@ -643,6 +644,9 @@ def add_remote_agent(self, along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. + + Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 + so that the whole cache is shared by "tp_ratio" D TP workers. 
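+
+        Non-MLA example: with 8 KV heads on the remote P worker and
+        tp_ratio=2, each D worker reads a contiguous 4-head chunk of every
+        remote block, i.e. remote block_len // 2 bytes starting at its
+        rank_offset.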
""" # noqa: E501 engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery @@ -662,15 +666,23 @@ def add_remote_agent(self, tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] assert tp_ratio > 0, "Decode TP cannot be smaller than" " prefill TP" + if self.use_mla: + # With MLA the only difference is in the number of blocks. + remote_block_size = nixl_agent_meta.block_len / ( + self.slot_size_bytes) + assert self.block_len == nixl_agent_meta.block_len + else: + remote_block_size = nixl_agent_meta.block_len / ( + self.slot_size_bytes * tp_ratio) + + assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \ + "Remote P worker KV layer cache must be of shape [2, N, \ + local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." - remote_block_size = nixl_agent_meta.block_len / (self.slot_size_bytes * - tp_ratio) assert self.block_size == remote_block_size, "Remote P worker with \ different block size is not supported" - assert nixl_agent_meta.block_len == self.block_len * tp_ratio, "Remote\ - P worker KV layer cache must be of shape \ - [2, N, local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." + assert self.num_blocks >= nixl_agent_meta.num_blocks # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -688,7 +700,8 @@ def add_remote_agent(self, if p_remote_rank == remote_rank: self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.rank % tp_ratio * self.block_len + rank_offset = self.rank % tp_ratio * self.block_len \ + if not self.use_mla else 0 # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: for block_id in range(nixl_agent_meta.num_blocks): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d784d835ae9f..4a67e37781bf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2119,7 +2119,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - # kv_cache_shape=(16, 8, 128)=>( 8, 16, 128) dtype = kv_cache_spec.dtype try: kv_cache_stride_order = self.attn_backends[ From 909823219ecef7d219adbdfaea6b202f229b2b60 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Fri, 30 May 2025 10:29:55 +0000 Subject: [PATCH 5/7] FA stride order for nixl+rebase cruft Signed-off-by: nicklucche --- .../kv_transfer/kv_connector/utils.py | 20 +++++++++++++++- .../kv_connector/v1/nixl_connector.py | 23 ++++++++++--------- vllm/v1/attention/backends/flash_attn.py | 16 +++++++++++++ vllm/worker/worker_base.py | 3 ++- 4 files changed, 49 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index c62444e756cf..949bbd74a630 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,11 +3,13 @@ """ KV cache helper for store. 
""" +import functools + import torch import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger logger = init_logger(__name__) @@ -90,3 +92,19 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, layer.self_attn.attn._k_scale, layer.self_attn.attn._v_scale, ) + + +@functools.lru_cache +def get_kv_connector_cache_layout(): + vllm_config = get_current_vllm_config() + kv_config = vllm_config.kv_transfer_config + if vllm_config.model_config is None: + logger.warning("Unable to detect current VLLM config. " \ + "Defaulting to NHD kv cache layout.") + else: + use_mla = vllm_config.model_config.use_mla + if not use_mla and kv_config.kv_connector == "NixlConnector": + logger.info("NixlConnector detected. Setting KV cache " \ + "layout to HND for better xfer performance.") + return "HND" + return "NHD" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 60ed8c0bf79c..5b2b3f1beba9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -32,6 +32,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request +Transfer = tuple[int, float] GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -362,7 +363,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. - self.kv_caches_base_addr: dict[str, list[int]] = dict() + self.kv_caches_base_addr: dict[str, list[int]] = {} # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -372,19 +373,17 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[str, int] = dict() + self.dst_xfer_side_handles: dict[str, int] = {} # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. - self.dst_num_blocks: dict[str, int] = dict() + self.dst_num_blocks: dict[str, int] = {} self._registered_descs: list[Any] = [] # In progress transfers. # [req_id -> list[handle]] - self._recving_transfers: defaultdict[str, - list[tuple[int, - float]]] = defaultdict( - list[Any]) + self._recving_transfers: defaultdict[ + str, list[Transfer]] = defaultdict(list) # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. 
@@ -467,7 +466,7 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata: return metadata # Handshake with remote agent-rank0 first to get the tp_size of remote - path = f"tcp://{host}:{port}" + path = make_zmq_path("tcp", host, port) logger.debug("Querying master rank metadata on path: %s", path) metadata = handshake(path, 0) @@ -476,7 +475,7 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata: tp_rate = self._tp_size[self.engine_id] // metadata.tp_size p_remote_rank = self.tp_rank // tp_rate if p_remote_rank > 0: - path = f"tcp://{host}:{port + p_remote_rank}" + path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) _ = handshake(path, p_remote_rank) @@ -581,7 +580,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_offset = block_id * self.block_len addr = base_addr + block_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, self.rank)) + blocks_data.append((addr, self.block_len, self.tp_rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) @@ -663,6 +662,8 @@ def add_remote_agent(self, # Number of D TP workers reading from a single P TP worker. This is # 1 when P and D `--tensor-parallel-size` match. + assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \ + "Local TP size must be divisible by remote TP size." tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] assert tp_ratio > 0, "Decode TP cannot be smaller than" " prefill TP" @@ -700,7 +701,7 @@ def add_remote_agent(self, if p_remote_rank == remote_rank: self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.rank % tp_ratio * self.block_len \ + rank_offset = self.tp_rank % tp_ratio * self.block_len \ if not self.use_mla else 0 # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index a9f748d026f4..91a7c43cd8d8 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -16,6 +16,8 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_kv_connector_cache_layout) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv @@ -70,6 +72,20 @@ def get_kv_cache_shape( raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + # NOTE When running disaggregated PD with NIXL, HND layout is used for + # faster transfer. `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. 
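+        # E.g. the (2, num_blocks, block_size, num_kv_heads, head_size) shape
+        # above becomes (2, num_blocks, num_kv_heads, block_size, head_size)
+        # in memory when the HND order (0, 1, 3, 2, 4) is applied.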
+ cache_layout = get_kv_connector_cache_layout() + if cache_layout == "NHD": + stride_order = (0, 1, 2, 3, 4) + elif cache_layout == "HND": + stride_order = (0, 1, 3, 2, 4) + else: + raise ValueError("Unknown cache layout format %s.", cache_layout) + return stride_order + @dataclass class FlashAttentionMetadata: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index db1ca2d8ff30..0b37caa71669 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -597,7 +597,8 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: kv_cache_config = kv_cache_configs[self.rpc_rank] - self.worker.initialize_from_config(kv_cache_config) # type: ignore + with set_current_vllm_config(self.vllm_config): + self.worker.initialize_from_config(kv_cache_config) # type: ignore def init_device(self): with set_current_vllm_config(self.vllm_config): From 112717d776b525ae06505c0b234b0b370ee73a5f Mon Sep 17 00:00:00 2001 From: nicklucche Date: Wed, 4 Jun 2025 14:41:07 +0000 Subject: [PATCH 6/7] remove get_kv_connector_cache_layout caching Signed-off-by: nicklucche --- vllm/distributed/kv_transfer/kv_connector/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 949bbd74a630..b9bed06d791c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,7 +3,6 @@ """ KV cache helper for store. """ -import functools import torch @@ -94,7 +93,6 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, ) -@functools.lru_cache def get_kv_connector_cache_layout(): vllm_config = get_current_vllm_config() kv_config = vllm_config.kv_transfer_config From 3d83f79e087501266530468727a87a24effa8994 Mon Sep 17 00:00:00 2001 From: nicklucche Date: Wed, 4 Jun 2025 20:51:16 +0000 Subject: [PATCH 7/7] address review Signed-off-by: nicklucche --- .../kv_connector/v1/nixl_connector.py | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 5b2b3f1beba9..400fa29c90f0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -32,7 +32,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request -Transfer = tuple[int, float] +Transfer = tuple[int, float] # (xfer_handle, start_time) GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -382,8 +382,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # In progress transfers. # [req_id -> list[handle]] - self._recving_transfers: defaultdict[ - str, list[Transfer]] = defaultdict(list) + self._recving_transfers = defaultdict[str, list[Transfer]](list) # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. @@ -407,8 +406,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._tp_size: dict[str, int] = {self.engine_id: self.world_size} # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. 
- self.consumer_notification_counts_by_req: dict[str, - int] = defaultdict(int) + self.consumer_notification_counts_by_req = defaultdict[str, int](int) @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, @@ -472,8 +470,8 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata: # Handshake only with the other TP remote the current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. - tp_rate = self._tp_size[self.engine_id] // metadata.tp_size - p_remote_rank = self.tp_rank // tp_rate + tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size + p_remote_rank = self.tp_rank // tp_ratio if p_remote_rank > 0: path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug("Querying metadata on path: %s at remote rank %s", @@ -608,19 +606,19 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, - remote_rank: int = 0): + remote_tp_rank: int = 0): """ Add the remote NIXL agent and prepare the descriptors for reading cache blocks from remote. - In particular, handle both homogeneous and heterogeneous TP. The latter + In particular, handle both homogeneous and heterogeneous TP. The former requires local rank_i to read from remote rank_i. - The former, assuming D.world_size > P.world_size, requires that two or + The latter, assuming D.world_size > P.world_size, requires that two or more local TP worker share the xfer from a single TP worker. Here's an example: - rank_offset p_remote_rank + rank_offset p_remote_tp_rank (kv split no) -------------------------------- 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] @@ -649,15 +647,15 @@ def add_remote_agent(self, """ # noqa: E501 engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery - if (engine_id in self._remote_agents and \ - remote_rank in self._remote_agents[engine_id]): + if remote_tp_rank in self._remote_agents.get(engine_id, ()): return if engine_id in self._tp_size: assert self._tp_size[engine_id] == nixl_agent_meta.tp_size - self._tp_size[engine_id] = nixl_agent_meta.tp_size + else: + self._tp_size[engine_id] = nixl_agent_meta.tp_size self._remote_agents[engine_id][ - remote_rank] = self.nixl_wrapper.add_remote_agent( + remote_tp_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) # Number of D TP workers reading from a single P TP worker. This is @@ -688,17 +686,17 @@ def add_remote_agent(self, # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks - - self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + else: + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] # With homogeneous TP, D pulls the whole kv cache from corresponding # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - p_remote_rank = self.tp_rank // tp_ratio + p_remote_tp_rank = self.tp_rank // tp_ratio # Only register the remote's descriptors if current rank pulls from it. 
- if p_remote_rank == remote_rank: + if p_remote_tp_rank == remote_tp_rank: self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr rank_offset = self.tp_rank % tp_ratio * self.block_len \ @@ -712,17 +710,17 @@ def add_remote_agent(self, # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, remote_rank)) + blocks_data.append((addr, self.block_len, remote_tp_rank)) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and " \ "local rank %s", - len(blocks_data), engine_id, remote_rank, self.tp_rank) + len(blocks_data), engine_id, remote_tp_rank, self.tp_rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") self.dst_xfer_side_handles[ engine_id] = self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id][remote_rank], descs) + self._remote_agents[engine_id][remote_tp_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: """