@@ -62,7 +62,6 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    tp_size: int
     block_len: int
     attn_backend_name: str
 
@@ -73,7 +72,8 @@ class ReqMeta:
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
-    remote_engine_id: EngineId
+    remote_engine_id: str
+    tp_size: int
 
 
 class NixlConnectorMetadata(KVConnectorMetadata):
@@ -93,6 +93,8 @@ def add_new_req(
             remote_engine_id=kv_transfer_params["remote_engine_id"],
             remote_host=kv_transfer_params["remote_host"],
             remote_port=kv_transfer_params["remote_port"],
+            # P workers don't need to receive tp_size from proxy here.
+            tp_size=kv_transfer_params.get("tp_size", 1),
         )
 
 
@@ -330,7 +332,7 @@ def request_finished(
             remote_engine_id=self.engine_id,
             remote_host=self.side_channel_host,
             remote_port=self.side_channel_port,
-        )
+            tp_size=self.vllm_config.parallel_config.tensor_parallel_size)
 
 
 class NixlConnectorWorker:
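
With this change the prefill worker's request_finished() advertises its tensor-parallel size in the returned kv_transfer_params, and the decode side reads it back in add_new_req(), falling back to 1 when the proxy does not forward it. A minimal sketch of the round trip, with purely illustrative values (host, port, block ids and engine id are made up):

    # Hypothetical params as a P instance with TP=2 would return them.
    kv_transfer_params = {
        "remote_block_ids": [0, 1, 2],
        "remote_engine_id": "prefill-engine-0",
        "remote_host": "10.0.0.5",
        "remote_port": 5600,
        "tp_size": 2,  # new field added by this commit
    }

    # Decode-side consumption mirrors add_new_req(): absent tp_size -> 1.
    remote_tp_size = kv_transfer_params.get("tp_size", 1)
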
@@ -473,7 +475,8 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
473475 "Connection listener got unexpected message %s" , msg )
474476 sock .send_multipart ((identity , b"" , encoded_data ))
475477
476- def _nixl_handshake (self , host : str , port : int ) -> dict [int , str ]:
478+ def _nixl_handshake (self , host : str , port : int ,
479+ remote_tp_size : int ) -> dict [int , str ]:
477480 """Do a NIXL handshake with a remote instance."""
478481
479482 start_time = time .perf_counter ()
@@ -482,7 +485,7 @@ def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
+        def handshake(path: str, rank: int) -> str:
             # Send query for the request.
             with zmq_ctx(zmq.REQ, path) as sock:
                 sock.send(GET_META_MSG)
@@ -492,33 +495,25 @@ def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
                 got_metadata_time = time.perf_counter()
 
                 # Register Remote agent.
-                remote_agent_name = self.add_remote_agent(metadata, rank)
+                remote_agent_name = self.add_remote_agent(
+                    metadata, rank, remote_tp_size)
                 setup_agent_time = time.perf_counter()
 
                 logger.debug("NIXL handshake: get metadata took: %s",
                              got_metadata_time - start_time)
                 logger.debug("NIXL handshake: add agent took: %s",
                              setup_agent_time - got_metadata_time)
-                return metadata, remote_agent_name
+                return remote_agent_name
 
-        # Handshake with remote agent-rank0 first to get the tp_size of remote
-        path = make_zmq_path("tcp", host, port)
-        logger.debug("Querying master rank metadata on path: %s", path)
-        rank_to_agent_name: dict[int, str] = {}
-        metadata, rank_to_agent_name[0] = handshake(path, 0)
-
-        # Handshake only with the other TP remote the current local rank will
+        # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
+        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
         p_remote_rank = self.tp_rank // tp_ratio
-        if p_remote_rank > 0:
-            path = make_zmq_path("tcp", host, port + p_remote_rank)
-            logger.debug("Querying metadata on path: %s at remote rank %s",
-                         path, p_remote_rank)
-            _, rank_to_agent_name[p_remote_rank] = handshake(
-                path, p_remote_rank)
-
-        return rank_to_agent_name
+        path = make_zmq_path("tcp", host, port + p_remote_rank)
+        logger.debug("Querying metadata on path: %s at remote rank %s", path,
+                     p_remote_rank)
+        # Remote rank -> agent name.
+        return {p_remote_rank: handshake(path, p_remote_rank)}
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -645,7 +640,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            tp_size=self.world_size,
             block_len=self.block_len,
             attn_backend_name=self.backend_name)
         ready_event = threading.Event()
@@ -659,7 +653,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 
     def add_remote_agent(self,
                          nixl_agent_meta: NixlAgentMetadata,
-                         remote_tp_rank: int = 0) -> str:
+                         remote_tp_rank: int = 0,
+                         remote_tp_size: int = 1) -> str:
         """
         Add the remote NIXL agent and prepare the descriptors for reading cache
         blocks from remote.
@@ -704,9 +699,9 @@ def add_remote_agent(self,
             return self._remote_agents[engine_id][remote_tp_rank]
 
         if engine_id in self._tp_size:
-            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
+            assert self._tp_size[engine_id] == remote_tp_size
         else:
-            self._tp_size[engine_id] = nixl_agent_meta.tp_size
+            self._tp_size[engine_id] = remote_tp_size
         # We may eventually enable this after asserting equality in cache
         # layout and close outputs.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
@@ -756,33 +751,31 @@ def add_remote_agent(self,
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        p_remote_tp_rank = self.tp_rank // tp_ratio
         # Only register the remote's descriptors if current rank pulls from it.
-        if p_remote_tp_rank == remote_tp_rank:
-            self.kv_caches_base_addr[
-                engine_id] = nixl_agent_meta.kv_caches_base_addr
-            rank_offset = self.tp_rank % tp_ratio * self.block_len \
-                if not (self.use_mla or is_kv_replicated) else 0
-            # Register all remote blocks, but only the corresponding kv heads.
-            for base_addr in nixl_agent_meta.kv_caches_base_addr:
-                for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
-                    # For each block, grab the heads chunk belonging to rank_i
-                    # of size remote_nheads // tp_ratio, which correspond to
-                    # self.block_len == remote_block_len//tp_ratio bytes.
-                    addr = base_addr + block_offset + rank_offset
-                    # (addr, len, device id)
-                    blocks_data.append((addr, self.block_len, remote_tp_rank))
-            logger.debug(
-                "Created %s blocks for dst engine %s with remote rank %s and "
-                "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
-                self.tp_rank)
+        self.kv_caches_base_addr[
+            engine_id] = nixl_agent_meta.kv_caches_base_addr
+        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+            if not (self.use_mla or is_kv_replicated) else 0
+        # Register all remote blocks, but only the corresponding kv heads.
+        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+            for block_id in range(nixl_agent_meta.num_blocks):
+                block_offset = block_id * nixl_agent_meta.block_len
+                # For each block, grab the heads chunk belonging to rank_i
+                # of size remote_nheads // tp_ratio, which correspond to
+                # self.block_len == remote_block_len//tp_ratio bytes.
+                addr = base_addr + block_offset + rank_offset
+                # (addr, len, device id)
+                blocks_data.append((addr, self.block_len, remote_tp_rank))
+        logger.debug(
+            "Created %s blocks for dst engine %s with remote rank %s and "
+            "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
+            self.tp_rank)
 
-            # Register with NIXL.
-            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-            self.dst_xfer_side_handles[
-                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                    remote_agent_name, descs)
+        # Register with NIXL.
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        self.dst_xfer_side_handles[
+            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                remote_agent_name, descs)
 
         return remote_agent_name
 
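
Because each decode rank now only ever calls add_remote_agent() for the prefill rank it actually pulls from, the old guard on p_remote_tp_rank is gone and the per-rank head slice is selected purely by rank_offset. A numeric sketch of the address math under assumed sizes (remote block_len of 4096 bytes, tp_ratio = 2, so the local block_len is 2048; the base address is hypothetical):

    remote_block_len = 4096                          # bytes per P-side block (assumed)
    tp_ratio = 2
    local_block_len = remote_block_len // tp_ratio   # 2048, plays the role of self.block_len
    base_addr = 0x7F0000000000                       # made-up remote base address

    for tp_rank in (2, 3):                           # two D ranks sharing P rank 1
        rank_offset = tp_rank % tp_ratio * local_block_len
        for block_id in range(2):
            block_offset = block_id * remote_block_len
            addr = base_addr + block_offset + rank_offset
            print(tp_rank, block_id, hex(addr), local_block_len)
    # Rank 2 reads the first half of each remote block, rank 3 the second half.
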
@@ -917,7 +910,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
                         if fut is None:
                             fut = self._handshake_initiation_executor.submit(
                                 self._nixl_handshake, meta.remote_host,
-                                meta.remote_port)
+                                meta.remote_port, meta.tp_size)
                             self._handshake_futures[remote_engine_id] = fut
 
                             def done_callback(f: Future[dict[int, str]],
@@ -957,13 +950,9 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
             remote_block_ids=meta.remote_block_ids,
         )
 
-    def _read_blocks(
-        self,
-        local_block_ids: list[int],
-        remote_block_ids: list[int],
-        dst_engine_id: str,
-        request_id: str,
-    ):
+    def _read_blocks(self, local_block_ids: list[int],
+                     remote_block_ids: list[int], dst_engine_id: str,
+                     request_id: str):
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
         # after we detect the txn is complete (which means we cannot make the