diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index c17784e0a263..b48655d80eef 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -8,7 +8,9 @@ MODELS=(
 # Number of prefill and decode instances to create
 NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
-NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
+NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
+PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
+DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
 
 # Find the git repository root directory
 GIT_ROOT=$(git rev-parse --show-toplevel)
@@ -74,9 +76,10 @@ run_tests_for_model() {
   for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
     # Calculate GPU ID - we'll distribute across available GPUs
     GPU_ID=$((i % $(get_num_gpus)))
+
     # Calculate port number (base port + instance number)
     PORT=$((8100 + i))
-    # Calculate side channel port
+    # Calculate side channel port. Avoid clash with TP workers.
     SIDE_CHANNEL_PORT=$((5559 + i))
 
     echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
@@ -87,6 +90,7 @@ run_tests_for_model() {
     --enforce-eager \
     --disable-log-requests \
     --gpu-memory-utilization 0.2 \
+    --tensor-parallel-size $PREFILLER_TP_SIZE \
     --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
 
     if [ -n "$model_args" ]; then
@@ -109,7 +113,7 @@ run_tests_for_model() {
     # Calculate port number (base port + instance number)
     PORT=$((8200 + i))
     # Calculate side channel port
-    SIDE_CHANNEL_PORT=$((5659 + i))
+    SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
 
     echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
 
@@ -119,6 +123,7 @@ run_tests_for_model() {
     --enforce-eager \
     --disable-log-requests \
     --gpu-memory-utilization 0.2 \
+    --tensor-parallel-size $DECODER_TP_SIZE \
     --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
 
     if [ -n "$model_args" ]; then
diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py
index 2b2b147ce3e1..e5d66ffeeeb2 100644
--- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py
+++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py
@@ -14,6 +14,7 @@
 # Model-specific expected values
 EXPECTED_VALUES = {
     "Qwen/Qwen3-0.6B": 0.41,
+    "deepseek-ai/deepseek-vl2-small": 0.59,
 }
 
 SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means",  # noqa: E501
 
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index c62444e756cf..b9bed06d791c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -3,11 +3,12 @@
 """
 KV cache helper for store.
""" + import torch import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger logger = init_logger(__name__) @@ -90,3 +91,18 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, layer.self_attn.attn._k_scale, layer.self_attn.attn._v_scale, ) + + +def get_kv_connector_cache_layout(): + vllm_config = get_current_vllm_config() + kv_config = vllm_config.kv_transfer_config + if vllm_config.model_config is None: + logger.warning("Unable to detect current VLLM config. " \ + "Defaulting to NHD kv cache layout.") + else: + use_mla = vllm_config.model_config.use_mla + if not use_mla and kv_config.kv_connector == "NixlConnector": + logger.info("NixlConnector detected. Setting KV cache " \ + "layout to HND for better xfer performance.") + return "HND" + return "NHD" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fd22280126d6..400fa29c90f0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -32,6 +32,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request +Transfer = tuple[int, float] # (xfer_handle, start_time) GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -54,6 +55,8 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int + tp_size: int + block_len: int @dataclass @@ -331,10 +334,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.info("Initializing NIXL wrapper") logger.info("Initializing NIXL worker %s", engine_id) + # Config. + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) - # Map of engine_id -> agent_name. - self._remote_agents: dict[str, str] = {} + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. + self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) # NIXL handshake port. # NOTE(rob): Within a DP group, each DP rank gets its own @@ -354,7 +361,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} - # Map of engine_id -> kv_caches_base_addr + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. self.kv_caches_base_addr: dict[str, list[int]] = {} # Number of NIXL regions. Currently one region per cache @@ -362,19 +370,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.num_regions = 0 self.num_layers = 0 - # nixl_prepped_dlist_handle (int). + # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. self.dst_xfer_side_handles: dict[str, int] = {} - # Map of engine_id -> num_blocks. + # Map of engine_id -> num_blocks. All ranks in the same deployment will + # have the same number of blocks. self.dst_num_blocks: dict[str, int] = {} self._registered_descs: list[Any] = [] # In progress transfers. # [req_id -> list[handle]] - self._recving_transfers: defaultdict[str, list[Any]] = defaultdict( - list[Any]) + self._recving_transfers = defaultdict[str, list[Transfer]](list) # Complete transfer tracker. 
         # transactions on ranks 1 to N-1.
@@ -395,6 +403,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # List of block window sizes for each layer for local attention
         self.block_window_per_layer: list[Optional[int]] = []
 
+        self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
+        # With heterogeneous TP, P must wait for all assigned D TP workers to
+        # finish reading before safely freeing the blocks.
+        self.consumer_notification_counts_by_req = defaultdict[str, int](int)
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, base_port: int,
@@ -426,27 +439,44 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
-        # NOTE(rob): we need each tp_rank to have a unique port.
-        # This is a hack to keep us moving. We will switch when
-        # we switch to HTTP-based NIXL metadata exchange.
-        path = make_zmq_path("tcp", host, port + self.tp_rank)
-        logger.debug("Querying metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
-            # Send query for the request.
-            sock.send(GET_META_MSG)
-            metadata_bytes = sock.recv()
-            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-            metadata = decoder.decode(metadata_bytes)
-            got_metadata_time = time.perf_counter()
-
-            # Register Remote agent.
-            self.add_remote_agent(metadata)
-            setup_agent_time = time.perf_counter()
+        # NOTE(rob): we need each rank to have a unique port. This is
+        # a hack to keep us moving. We will switch when moving to etcd
+        # or where we have a single ZMQ socket in the scheduler.
 
-            logger.debug("NIXL handshake: get metadata took: %s",
-                         got_metadata_time - start_time)
-            logger.debug("NIXL handshake: add agent took: %s",
-                         setup_agent_time - got_metadata_time)
+        def handshake(path: str, rank: int) -> NixlAgentMetadata:
+            # Send query for the request.
+            with zmq_ctx(zmq.REQ, path) as sock:
+                sock.send(GET_META_MSG)
+                metadata_bytes = sock.recv()
+                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+                metadata = decoder.decode(metadata_bytes)
+                got_metadata_time = time.perf_counter()
+
+            # Register Remote agent.
+            self.add_remote_agent(metadata, rank)
+            setup_agent_time = time.perf_counter()
+
+            logger.debug("NIXL handshake: get metadata took: %s",
+                         got_metadata_time - start_time)
+            logger.debug("NIXL handshake: add agent took: %s",
+                         setup_agent_time - got_metadata_time)
+            return metadata
+
+        # Handshake with remote agent-rank0 first to get the tp_size of remote
+        path = make_zmq_path("tcp", host, port)
+        logger.debug("Querying master rank metadata on path: %s", path)
+        metadata = handshake(path, 0)
+
+        # Handshake only with the other TP remote the current local rank will
+        # pull from. With homogeneous TP it happens to be the same rank_i.
+        tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
+        p_remote_rank = self.tp_rank // tp_ratio
+        if p_remote_rank > 0:
+            path = make_zmq_path("tcp", host, port + p_remote_rank)
+            logger.debug("Querying metadata on path: %s at remote rank %s",
+                         path, p_remote_rank)
+            _ = handshake(path, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -455,24 +485,34 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         kv_elem_size = first_kv_cache.element_size()
 
         # TODO(tms): Find a more robust way to detect and handle MLA
-        use_mla = len(first_kv_cache.shape) == 3
-        if use_mla:
+        self.use_mla = len(first_kv_cache.shape) == 3
+        # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
+        # KV memory layout is HND, as opposed to the default NHD. Note that it
+        # will only affect the strides. For MLA instead, we require no such
+        # thing and resort to the standard layout.
+        if self.use_mla:
             # MLA case.
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            block_size, kv_latent_dim = block_shape
+            self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
-            # [2 (k and v), num_blocks, ...]
+            # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-
+            block_size, n_kv_heads, head_dim = block_shape
+            # slot size (one token's KV across all heads) in bytes.
+            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
+        assert block_size == self.block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
+        # block size in bytes
         self.block_len = kv_elem_size * math.prod(block_shape)
-        logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
-                     first_kv_cache.shape)
+        logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
+                     self.use_mla, first_kv_cache.shape)
         logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
                      block_shape)
         logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
@@ -489,7 +529,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # (roughly 8KB vs 5KB).
         for cache_or_caches in kv_caches.values():
             # Normalize to always be a list of caches
-            cache_list = [cache_or_caches] if use_mla else cache_or_caches
+            cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len
@@ -524,16 +564,37 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Registering descs: %s", caches_data)
         self.nixl_wrapper.register_memory(descs)
         logger.debug("Done registering descs")
-
         self._registered_descs.append(descs)
 
+        # Register local/src descr for NIXL xfer.
+        blocks_data = []
+        for base_addr in self.kv_caches_base_addr[self.engine_id]:
+            # NOTE With heter-TP, more blocks are prepared than what are
+            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
+            # could create fewer, but then _get_block_descs_ids needs to
+            # select agent_meta.num_blocks instead of self.num_blocks for
+            # local descr, and that makes handling regular flow less clean.
+            for block_id in range(self.num_blocks):
+                block_offset = block_id * self.block_len
+                addr = base_addr + block_offset
+                # (addr, len, device id)
+                blocks_data.append((addr, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for src engine %s and rank %s",
+                     len(blocks_data), self.engine_id, self.tp_rank)
+
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        # NIXL_INIT_AGENT to be used for preparations of local descs.
+        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
+            "NIXL_INIT_AGENT", descs)
+
         # After KV Caches registered, listen for new connections.
         metadata = NixlAgentMetadata(
             engine_id=self.engine_id,
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-        )
+            tp_size=self.world_size,
+            block_len=self.block_len)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -543,50 +604,123 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self._nixl_handshake_listener_t.start()
         ready_event.wait()
 
-    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
+    def add_remote_agent(self,
+                         nixl_agent_meta: NixlAgentMetadata,
+                         remote_tp_rank: int = 0):
+        """
+        Add the remote NIXL agent and prepare the descriptors for reading cache
+        blocks from remote.
+
+        In particular, handle both homogeneous and heterogeneous TP. The former
+        requires local rank_i to read from remote rank_i.
+        The latter, assuming D.world_size > P.world_size, requires that two or
+        more local TP workers share the xfer from a single remote TP worker.
+
+        Here's an example:
+
+        rank_offset     p_remote_tp_rank
+        (kv split no)
+        --------------------------------
+            0                 0      Worker0  ---- 1st half of KV ----> Worker0  [ KV Cache ]
+                                                                        /
+            1                 0      Worker1  ---- 2nd half of KV -----/
+
+            0                 1      Worker2  ---- 1st half of KV ----> Worker1  [ KV Cache ]
+                                                                        /
+            1                 1      Worker3  ---- 2nd half of KV -----/
+
+
+                             Decoder TP workers                       Prefill TP workers
+                              (world_size=4)                            (world_size=2)
+                                                tp_ratio = 4 // 2 = 2
+
+        Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
+        then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
+        Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
+        first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
+        along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
+
+        Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1.
+
+        In the MLA case, the cache is replicated across TP workers, so rank_offset is always 0 and the whole
+        cache is read by "tp_ratio" D TP workers.
+        """  # noqa: E501
         engine_id = nixl_agent_meta.engine_id
-        assert engine_id != self.engine_id, "Conflict engine id found!"
-        if engine_id in self._remote_agents:
+        # TODO re-evaluate refreshing for scaling/recovery
+        if remote_tp_rank in self._remote_agents.get(engine_id, ()):
             return
 
-        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
-            nixl_agent_meta.agent_metadata)
-        self.kv_caches_base_addr[
-            engine_id] = nixl_agent_meta.kv_caches_base_addr
+        if engine_id in self._tp_size:
+            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
+        else:
+            self._tp_size[engine_id] = nixl_agent_meta.tp_size
+        self._remote_agents[engine_id][
+            remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
+                nixl_agent_meta.agent_metadata)
+
+        # Number of D TP workers reading from a single P TP worker. This is
+        # 1 when P and D `--tensor-parallel-size` match.
+        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
+            "Local TP size must be divisible by remote TP size."
+        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
+        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
+
+        if self.use_mla:
+            # With MLA the only difference is in the number of blocks.
+            remote_block_size = nixl_agent_meta.block_len // (
+                self.slot_size_bytes)
+            assert self.block_len == nixl_agent_meta.block_len
+        else:
+            remote_block_size = nixl_agent_meta.block_len // (
+                self.slot_size_bytes * tp_ratio)
+
+            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+                "Remote P worker KV layer cache must be of shape [2, N, "
+                "local_kv_heads*tp_ratio, block_size, head_dim] and same "
+                "dtype.")
 
-        # Create src descs and xfer side handles.
-        blocks_data = []
-        for base_addr in self.kv_caches_base_addr[self.engine_id]:
-            for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.tp_rank))
-        logger.debug("Created %s blocks for src engine %s and tp_rank %s",
-                     len(blocks_data), self.engine_id, self.tp_rank)
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different block size is not supported")
 
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
-            "NIXL_INIT_AGENT", descs)
+        assert self.num_blocks >= nixl_agent_meta.num_blocks
+
+        # Create dst descs and xfer side handles. TP workers have same #blocks.
+        if engine_id in self.dst_num_blocks:
+            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
+        else:
+            self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
 
-        # Create dst descs and xfer side handles.
-        self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
         blocks_data = []
-        for base_addr in self.kv_caches_base_addr[engine_id]:
-            for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.tp_rank))
-        logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
-                     len(blocks_data), engine_id, self.tp_rank)
+        # With homogeneous TP, D pulls the whole kv cache from the
+        # corresponding rank. With heterogeneous TP, prepare the descriptors
+        # by splitting the P KV cache along the kv_head dim into chunks of the
+        # D worker's kv_head size (D>P).
+        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+        p_remote_tp_rank = self.tp_rank // tp_ratio
+        # Only register the remote's descriptors if current rank pulls from it.
+        if p_remote_tp_rank == remote_tp_rank:
+            self.kv_caches_base_addr[
+                engine_id] = nixl_agent_meta.kv_caches_base_addr
+            rank_offset = self.tp_rank % tp_ratio * self.block_len \
+                if not self.use_mla else 0
+            # Register all remote blocks, but only the corresponding kv heads.
+            for base_addr in nixl_agent_meta.kv_caches_base_addr:
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    # For each block, grab the heads chunk belonging to rank_i
+                    # of size remote_nheads // tp_ratio, which corresponds to
+                    # self.block_len == remote_block_len//tp_ratio bytes.
+                    addr = base_addr + block_offset + rank_offset
+                    # (addr, len, device id)
+                    blocks_data.append((addr, self.block_len, remote_tp_rank))
+            logger.debug(
+                "Created %s blocks for dst engine %s with remote rank %s and "
+                "local rank %s",
+                len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
 
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.dst_xfer_side_handles[
-            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id], descs)
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            self.dst_xfer_side_handles[
+                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                    self._remote_agents[engine_id][remote_tp_rank], descs)
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
@@ -654,16 +788,25 @@ def get_finished(self) -> tuple[set[str], set[str]]:
         return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
-        """Get req_ids which got a remote xfer message."""
-
+        """
+        Get req_ids which got a remote xfer message. When multiple consumers
+        are reading from the same producer (heterogeneous TP scenario), wait
+        for all consumers to be done pulling.
+        """
         notified_req_ids: set[str] = set()
-        for req_ids in self.nixl_wrapper.get_new_notifs().values():
-            for req_id in req_ids:
-                assert req_id not in notified_req_ids
-                notified_req_ids.add(req_id.decode("utf-8"))
+        for notifs in self.nixl_wrapper.get_new_notifs().values():
+            for notif in notifs:
+                req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
+                self.consumer_notification_counts_by_req[req_id] += 1
+                # Wait for all consumers (D) to be done reading before freeing.
+                if self.consumer_notification_counts_by_req[req_id] == int(
+                        tp_ratio):
+                    notified_req_ids.add(req_id)
+                    del self.consumer_notification_counts_by_req[req_id]
         return notified_req_ids
 
-    def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
+    def _pop_done_transfers(
+            self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
         """
         Pop completed xfers by checking for DONE state.
         Args:
@@ -673,23 +816,17 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
         """
         done_req_ids: set[str] = set()
         for req_id, handles in list(transfers.items()):
-            running_reqs = []
-            for handle in handles:
+            for handle, xfer_stime in handles:
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
-                    # TODO ptarasiewicz: why abort is throwing errors?
-                    # self.nixl_wrapper.release_xfer_handle(handle)
+                    self.nixl_wrapper.release_xfer_handle(handle)
+                    done_req_ids.add(req_id)
+                    del transfers[req_id]
+                elif xfer_state == "PROC":
                     continue
-                if xfer_state == "PROC":
-                    running_reqs.append(handle)
                 else:
                     raise RuntimeError("Transfer failed with state %s",
                                        xfer_state)
-            if len(running_reqs) == 0:
-                done_req_ids.add(req_id)
-                del transfers[req_id]
-            else:
-                transfers[req_id] = running_reqs
         return done_req_ids
 
     def start_load_kv(self, metadata: NixlConnectorMetadata):
@@ -735,13 +872,19 @@ def _read_blocks(
         # saturate IB with heterogeneous TP sizes. We should remove the staging
         # blocks until we are ready.
 
+        # Number of D TP workers that will read from dst P. Propagate tp_ratio
+        # on notification so that dst worker can wait before freeing blocks.
+        tp_ratio = self._tp_size[
+            self.engine_id] // self._tp_size[dst_engine_id]
+        notif_id = f"{request_id}:{tp_ratio}".encode()
+
         # Full prefix cache hit: do not need to read remote blocks,
         # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            agent_name = self._remote_agents[dst_engine_id]
-            self.nixl_wrapper.send_notif(agent_name,
-                                         notif_msg=request_id.encode("utf-8"))
+            remote_rank = self.tp_rank // tp_ratio
+            agent_name = self._remote_agents[dst_engine_id][remote_rank]
+            self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
             return
 
         # Partial prefix cache hit: just read uncomputed blocks.
@@ -754,6 +897,10 @@ def _read_blocks(
         local_xfer_side_handle = self.src_xfer_side_handle
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
+        # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
+        # the corresponding rank. With heterogeneous TP, fixing D>P, the D TP
+        # workers will issue xfers to parts of the P worker remote kv caches.
+
         # Get descs ids.
         local_block_descs_ids: list[int] = []
         remote_block_descs_ids: list[int] = []
@@ -797,14 +944,16 @@ def _read_blocks(
             local_block_descs_ids,
             remote_xfer_side_handle,
             remote_block_descs_ids,
-            notif_msg=request_id.encode("utf-8"),
+            notif_msg=notif_id,
         )
 
         # Begin async xfer.
         self.nixl_wrapper.transfer(handle)
 
         # Use handle to check completion in future step().
-        self._recving_transfers[request_id].append(handle)
+        # TODO (NickLucche) surface xfer elapsed time
+        self._recving_transfers[request_id].append(
+            (handle, time.perf_counter()))
 
     def _get_block_descs_ids(self,
                              engine_id: str,
@@ -815,7 +964,6 @@ def _get_block_descs_ids(self,
         If layer_idx is provided, we use the region_ids for the given layer.
         Otherwise, we use all regions.
""" - if layer_idx is None: region_ids = range(self.num_regions) else: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index a9f748d026f4..91a7c43cd8d8 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -16,6 +16,8 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_kv_connector_cache_layout) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv @@ -70,6 +72,20 @@ def get_kv_cache_shape( raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + # NOTE When running disaggregated PD with NIXL, HND layout is used for + # faster transfer. `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. + cache_layout = get_kv_connector_cache_layout() + if cache_layout == "NHD": + stride_order = (0, 1, 2, 3, 4) + elif cache_layout == "HND": + stride_order = (0, 1, 3, 2, 4) + else: + raise ValueError("Unknown cache layout format %s.", cache_layout) + return stride_order + @dataclass class FlashAttentionMetadata: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index db1ca2d8ff30..0b37caa71669 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -597,7 +597,8 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: kv_cache_config = kv_cache_configs[self.rpc_rank] - self.worker.initialize_from_config(kv_cache_config) # type: ignore + with set_current_vllm_config(self.vllm_config): + self.worker.initialize_from_config(kv_cache_config) # type: ignore def init_device(self): with set_current_vllm_config(self.vllm_config):