1616from vllm import envs
1717from vllm .config import VllmConfig
1818from vllm .distributed .kv_transfer .kv_connector .v1 .base import (
19- KVConnectorBase_V1 , KVConnectorMetadata , KVConnectorRole , KVTransferParams )
19+ KVConnectorBase_V1 , KVConnectorMetadata , KVConnectorRole )
2020from vllm .distributed .parallel_state import (
2121 get_tensor_model_parallel_rank , get_tensor_model_parallel_world_size ,
2222 get_tp_group )
4444 NixlWrapper = None
4545
4646
@dataclass
class NixlKVTransferParams(KVTransferParams):
    """KV-transfer parameters carried by a request using the NIXL connector.

    The fields are declared as real dataclass fields so that the generated
    ``__init__``, ``__repr__`` and ``__eq__`` all see them. With a
    hand-written ``__init__`` and no annotated fields, ``@dataclass``
    generates an ``__eq__`` that treats every instance as equal and a
    ``__repr__`` that shows nothing.

    NOTE(review): assumes the ``KVTransferParams`` base declares no
    dataclass fields with defaults of its own -- confirm against the base.

    Attributes are documented inline below.
    """

    # Whether the prompt KV blocks should be pulled from a remote engine.
    do_remote_prefill: bool
    # Whether decoding is expected to happen on a remote engine.
    do_remote_decode: bool
    # Block ids on the remote engine holding the KV data (if any).
    remote_block_ids: Optional[list[int]] = None
    # Host / port of the remote side channel and the remote engine id.
    remote_host: Optional[str] = None
    remote_port: Optional[int] = None
    remote_engine_id: Optional[str] = None

    @staticmethod
    def from_raw_dict(
        raw_dict: Optional[dict[str, Any]]
    ) -> Optional["NixlKVTransferParams"]:
        """Build params from a raw dict, or return ``None`` if unusable.

        Args:
            raw_dict: Raw transfer parameters, or ``None`` when the request
                carries none.

        Returns:
            A ``NixlKVTransferParams`` when every required key is present;
            ``None`` when ``raw_dict`` is ``None`` or any key is missing
            (a warning is logged in the latter case).
        """
        # If no raw transfer params passed, return None.
        if raw_dict is None:
            return None

        required_keys = (
            "do_remote_prefill",
            "do_remote_decode",
            "remote_block_ids",
            "remote_host",
            "remote_port",
            "remote_engine_id",
        )

        # Validate the request is formatted properly.
        if any(key not in raw_dict for key in required_keys):
            logger.warning(
                "Got invalid KVTransferParams: %s. This "
                "request will not utilize KVTransfer", raw_dict)
            return None

        return NixlKVTransferParams(
            do_remote_prefill=raw_dict["do_remote_prefill"],
            do_remote_decode=raw_dict["do_remote_decode"],
            remote_block_ids=raw_dict["remote_block_ids"],
            remote_host=raw_dict["remote_host"],
            remote_port=raw_dict["remote_port"],
            remote_engine_id=raw_dict["remote_engine_id"],
        )
95-
96-
9747class NixlAgentMetadata (
9848 msgspec .Struct ,
9949 omit_defaults = True , # type: ignore[call-arg]
@@ -123,25 +73,18 @@ def add_new_req(
12373 self ,
12474 request_id : str ,
12575 local_block_ids : list [int ],
126- kv_transfer_params : NixlKVTransferParams ,
76+ kv_transfer_params : dict [ str , Any ] ,
12777 ):
128- assert request_id not in self .requests
129- assert kv_transfer_params .remote_block_ids is not None
130- assert kv_transfer_params .remote_engine_id is not None
131- assert kv_transfer_params .remote_host is not None
132- assert kv_transfer_params .remote_port is not None
133-
13478 self .requests [request_id ] = ReqMeta (
13579 local_block_ids = local_block_ids ,
136- remote_block_ids = kv_transfer_params . remote_block_ids ,
137- remote_engine_id = kv_transfer_params . remote_engine_id ,
138- remote_host = kv_transfer_params . remote_host ,
139- remote_port = kv_transfer_params . remote_port ,
80+ remote_block_ids = kv_transfer_params [ " remote_block_ids" ] ,
81+ remote_engine_id = kv_transfer_params [ " remote_engine_id" ] ,
82+ remote_host = kv_transfer_params [ " remote_host" ] ,
83+ remote_port = kv_transfer_params [ " remote_port" ] ,
14084 )
14185
14286
14387class NixlConnector (KVConnectorBase_V1 ):
144- _KVTransferParams : type [NixlKVTransferParams ] = NixlKVTransferParams
14588
14689 def __init__ (self , vllm_config : VllmConfig , role : KVConnectorRole ):
14790 assert vllm_config .kv_transfer_config is not None
@@ -253,52 +196,52 @@ def get_num_new_matched_tokens(
253196 asynchronously (between scheduler steps).
254197 """
255198
199+ params = request .kv_transfer_params
256200 logger .debug (
257201 "NIXLConnector get_num_new_matched_tokens: "
258202 "num_computed_tokens=%s, kv_transfer_params=%s" ,
259- num_computed_tokens , request .kv_transfer_params )
260-
261- # No KVTransfer for this request.
262- if request .kv_transfer_params is None :
263- return 0 , False
264- assert isinstance (request .kv_transfer_params , NixlKVTransferParams )
203+ num_computed_tokens , params )
265204
266- # Remote prefill: get all prompt blocks from remote.
267- if request . kv_transfer_params . do_remote_prefill :
205+ if params is not None and params . get ( "do_remote_prefill" ):
206+ # Remote prefill: get all prompt blocks from remote.
268207 assert num_computed_tokens % self .block_size == 0
269208 rounded_num_prompt_tokens = round_down (
270209 len (request .prompt_token_ids ), self .block_size )
271210 count = max (rounded_num_prompt_tokens - num_computed_tokens , 0 )
272211 return count , count > 0
273212
213+ # No remote prefill for this request.
274214 return 0 , False
275215
276216 def update_state_after_alloc (self , request : "Request" ,
277217 blocks : "KVCacheBlocks" ,
278218 num_external_tokens : int ):
279219
220+ params = request .kv_transfer_params
280221 logger .debug (
281222 "NIXLConnector update_state_after_alloc: "
282223 "num_external_tokens=%s, kv_transfer_params=%s" ,
283- num_external_tokens , request . kv_transfer_params )
224+ num_external_tokens , params )
284225
285- if request .kv_transfer_params is None :
286- return
287-
288- assert isinstance (request .kv_transfer_params , NixlKVTransferParams )
289- if request .kv_transfer_params .do_remote_prefill :
226+ if params is not None and params .get ("do_remote_prefill" ):
290227 # NOTE(rob): if prompt < block_size, no remote blocks
291228 # since the remote only sends fully computed blocks, so
292229 # skip recving for this request. num_external_tokens
293230 # should be 0 if there are no remote blocks.
294- if request .kv_transfer_params .remote_block_ids :
295- # Get unhashed blocks to pull from remote.
296- self ._reqs_need_recv [request .request_id ] = (
297- request , blocks .get_unhashed_block_ids ())
231+ if params .get ("remote_block_ids" ):
232+ if all (p in params for p in ("remote_engine_id" , "remote_host" ,
233+ "remote_port" )):
234+ # Get unhashed blocks to pull from remote.
235+ self ._reqs_need_recv [request .request_id ] = (
236+ request , blocks .get_unhashed_block_ids ())
237+ else :
238+ logger .warning (
239+ "Got invalid KVTransferParams: %s. This "
240+ "request will not utilize KVTransfer" , params )
298241 else :
299242 assert num_external_tokens == 0
300243 # Only trigger 1 KV transfer per request.
301- request . kv_transfer_params . do_remote_prefill = False
244+ params [ " do_remote_prefill" ] = False
302245
303246 def build_connector_meta (
304247 self ,
@@ -308,7 +251,7 @@ def build_connector_meta(
308251
309252 # Loop through scheduled reqs and convert to ReqMeta.
310253 for req_id , (req , block_ids ) in self ._reqs_need_recv .items ():
311- assert isinstance ( req .kv_transfer_params , NixlKVTransferParams )
254+ assert req .kv_transfer_params is not None
312255 meta .add_new_req (
313256 request_id = req_id ,
314257 local_block_ids = block_ids ,
@@ -330,34 +273,30 @@ def request_finished(
330273 should be freed now or will be sent asynchronously and freed later.
331274 """
332275
276+ params = request .kv_transfer_params
333277 logger .debug (
334- "NIXLConnector request_finished, "
335- "request_status=%s, kv_transfer_params=%s" , request .status ,
336- request .kv_transfer_params )
337-
338- if request .kv_transfer_params is None :
339- return False , None
340- assert isinstance (request .kv_transfer_params , NixlKVTransferParams )
278+ "NIXLConnector request_finished, request_status=%s, "
279+ "kv_transfer_params=%s" , request .status , params )
341280
342- if (( not request . kv_transfer_params . do_remote_decode )
343- or ( request .status != RequestStatus .FINISHED_LENGTH_CAPPED ) ):
281+ if (params is None or not params . get ( " do_remote_decode" )
282+ or request .status != RequestStatus .FINISHED_LENGTH_CAPPED ):
344283 return False , None
345284
346285 # Get computed blocks.
347286 all_full = request .num_computed_tokens % self .block_size == 0
348- computed_block_ids = ( block_ids if all_full else block_ids [:- 1 ])
287+ computed_block_ids = block_ids if all_full else block_ids [:- 1 ]
349288
350289 # If prompt < block_size, no xfer so free blocks immediately.
351290 delay_free_blocks = len (computed_block_ids ) > 0
352291
353- return delay_free_blocks , NixlKVTransferParams (
292+ return delay_free_blocks , dict (
354293 do_remote_prefill = True ,
355294 do_remote_decode = False ,
356295 remote_block_ids = computed_block_ids ,
357296 remote_engine_id = self .engine_id ,
358297 remote_host = envs .VLLM_NIXL_SIDE_CHANNEL_HOST ,
359298 remote_port = envs .VLLM_NIXL_SIDE_CHANNEL_PORT ,
360- ). __dict__
299+ )
361300
362301
363302class NixlConnectorWorker :
0 commit comments