@@ -66,15 +66,17 @@ class RequestBlockInfo:
6666 block_operations : list [BlockOperation ] = field (default_factory = list )
6767 # Next block position to process
6868 start_position : int = 0
 69+ # vLLM block IDs for this request that are resident in HBM (device KV cache)
70+ vllm_block_ids : list [int ] = field (default_factory = list )
6971
7072
7173@dataclass
7274class ReqMeta :
7375 request_id : str
7476 # list[(block_hash, vllm_block_id)]
75- load_blocks : list [tuple [str , int ]] = field (default_factory = list )
77+ load_blocks : list [tuple [str , torch . Tensor ]] = field (default_factory = list )
7678 # list[(block_hash, vllm_block_id)]
77- dump_blocks : list [tuple [str , int ]] = field (default_factory = list )
79+ dump_blocks : list [tuple [str , torch . Tensor ]] = field (default_factory = list )
7880 # Whether use load_async
7981 load_async : bool = False
8082
@@ -158,6 +160,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
158160 "use_layerwise"
159161 ]
160162 )
163+ self .chunk_size = 256
164+ self .blocks_per_chunk = self .chunk_size // self .block_size
161165
162166 def _init_kv_caches_from_forward_context (self , forward_context : "ForwardContext" ):
163167 for layer_name in forward_context .no_compile_layers :
@@ -204,24 +208,24 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
204208 )
205209
206210 def get_tensor_and_offset_layerwise (
207- self , vllm_block_ids : List [int ], kv_layer : torch .Tensor , layer_name : str
211+ self , vllm_block_ids_tensors : List [torch . Tensor ], kv_layer : torch .Tensor , layer_name : str
208212 ) -> tuple [List [torch .Tensor ], List [int ]]:
209213 k_tensors = []
210214 k_offsets = []
211215 v_tensors = []
212216 v_offsets = []
213217 layer_id = self ._extract_layer_index (layer_name )
214218
215- for blk_id in vllm_block_ids :
219+ for vllm_block_ids_tensor in vllm_block_ids_tensors :
216220 k_data_offset = self .DataOffset (kv_layer , self .rank , layer_id , False )
217221 if self .is_mla :
218- k_tensors .append (kv_layer [blk_id ])
222+ k_tensors .append (kv_layer [vllm_block_ids_tensor ])
219223 else :
220- k_tensors .append (kv_layer [0 ][blk_id ])
224+ k_tensors .append (kv_layer [0 ][vllm_block_ids_tensor ])
221225 k_offsets .append (k_data_offset )
222226 if not self .is_mla :
223227 v_data_offset = self .DataOffset (kv_layer , self .rank , layer_id , True )
224- v_tensors .append (kv_layer [1 ][blk_id ])
228+ v_tensors .append (kv_layer [1 ][vllm_block_ids_tensor ])
225229 v_offsets .append (v_data_offset )
226230 return k_tensors + v_tensors , k_offsets + v_offsets
227231
@@ -266,14 +270,15 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
266270 continue
267271
268272 storage_block_ids = [block [0 ] for block in request .load_blocks ]
269- vllm_block_ids = [block [1 ] for block in request .load_blocks ]
273+ vllm_block_ids_tensors = [block [1 ] for block in request .load_blocks ]
270274 blocks_len = len (storage_block_ids )
271- self ._load_req_to_blocks .setdefault (request .request_id , set ()).update (
272- vllm_block_ids
273- )
275+ for vllm_block_ids_tensor in vllm_block_ids_tensors :
276+ self ._load_req_to_blocks .setdefault (request .request_id , set ()).update (
277+ vllm_block_ids_tensor .tolist ()
278+ )
274279 for layer_name , kv_layer in self .kv_caches .items ():
275280 tensors , offsets = self .get_tensor_and_offset_layerwise (
276- vllm_block_ids , kv_layer , layer_name
281+ vllm_block_ids_tensors , kv_layer , layer_name
277282 )
278283 k_task_id = self .connector .load (
279284 storage_block_ids , offsets [:blocks_len ], tensors [:blocks_len ]
@@ -397,10 +402,10 @@ def save_kv_layer(
397402 # Example: [("hash_123", 5), ("hash_456", 8), ("hash_789", 12)]
398403 # ["hash_123", "hash_456", "hash_789"]
399404 storage_block_ids = [block [0 ] for block in request .dump_blocks ]
400- vllm_block_ids = [block [1 ] for block in request .dump_blocks ] # [5, 8, 12]
405+ vllm_block_ids_tensors = [block [1 ] for block in request .dump_blocks ] # [5, 8, 12]
401406 blocks_len = len (storage_block_ids )
402407 tensors , offsets = self .get_tensor_and_offset_layerwise (
403- vllm_block_ids , kv_layer , layer_name
408+ vllm_block_ids_tensors , kv_layer , layer_name
404409 )
405410
406411 if kv_layer [0 ].device .type == "npu" :
@@ -457,11 +462,11 @@ def wait_for_tasks():
457462 continue
458463
459464 storage_block_ids = [block [0 ] for block in request .dump_blocks ]
460- vllm_block_ids = [block [1 ] for block in request .dump_blocks ]
465+ vllm_block_ids_tensors = [block [1 ] for block in request .dump_blocks ]
461466 blocks_len = len (storage_block_ids )
462467 for layer_name , kv_layer in self .kv_caches .items ():
463468 tensors , offsets = self .get_tensor_and_offset_layerwise (
464- vllm_block_ids , kv_layer , layer_name
469+ vllm_block_ids_tensors , kv_layer , layer_name
465470 )
466471 for block_id , offset , tensor in zip (
467472 storage_block_ids , offsets [:blocks_len ], tensors [:blocks_len ]
@@ -580,13 +585,13 @@ def hash_request_tokens(
580585 return ret
581586
582587 assert num_computed_tokens % self .block_size == 0
583- block_hashes = hash_request_tokens (md5 , self .block_size , request )
588+ block_hashes = hash_request_tokens (md5 , self .chunk_size , request )
584589 if not block_hashes :
585590 logger .debug ("Maybe tokens too short to load." )
586591 return 0 , False
587592
588593 # Calculate start position (exclude blocks already in HBM)
589- start_position = num_computed_tokens // self .block_size
594+ start_position = num_computed_tokens // self .chunk_size
590595
591596 block_operations = [BlockOperation .NONE ] * len (block_hashes )
592597
@@ -655,12 +660,14 @@ def update_state_after_alloc(
655660 """
656661 if request .request_id in self ._need_load_reqs :
657662 local_block_ids = (
658- # since we use unhashed blocks, so we don't need to reset start_position
659- blocks .get_unhashed_block_ids ()
663+ blocks .get_block_ids ()
660664 if num_external_tokens > 0
661665 else []
662666 )
663- self ._need_load_reqs [request .request_id ] = local_block_ids
667+ self ._need_load_reqs [request .request_id ] = local_block_ids [0 ]
668+ request_block_info = self .request_block_infos .get (request .request_id , None )
669+ if request_block_info :
670+ request_block_info .start_position = 0
664671 return
665672
666673 request_block_info = self .request_block_infos .get (request .request_id , None )
@@ -699,15 +706,16 @@ def build_connector_meta(
699706 for req_id , block_ids in self ._need_load_reqs .items ():
700707 block_info = self .request_block_infos .get (req_id )
701708 if block_info :
702- load_blocks , dump_blocks = self ._extract_blocks (block_ids , block_info )
703- meta .requests .append (
704- ReqMeta (
705- request_id = req_id ,
706- load_blocks = load_blocks ,
707- dump_blocks = dump_blocks ,
708- load_async = True ,
709+ block_info .vllm_block_ids = block_ids
710+ load_blocks , dump_blocks = self ._extract_blocks (block_info )
711+ meta .requests .append (
712+ ReqMeta (
713+ request_id = req_id ,
714+ load_blocks = load_blocks ,
715+ dump_blocks = dump_blocks ,
716+ load_async = True ,
717+ )
709718 )
710- )
711719 self ._need_load_reqs .clear ()
712720
713721 for new_req in scheduler_output .scheduled_new_reqs :
@@ -716,9 +724,8 @@ def build_connector_meta(
716724
717725 block_info = self .request_block_infos .get (req_id )
718726 if block_info :
719- load_blocks , dump_blocks = self ._extract_blocks (
720- vllm_block_ids , block_info
721- )
727+ block_info .vllm_block_ids = vllm_block_ids
728+ load_blocks , dump_blocks = self ._extract_blocks (block_info )
722729 if load_blocks or dump_blocks :
723730 meta .requests .append (
724731 ReqMeta (
@@ -756,9 +763,8 @@ def get_requests():
756763 for req_id , new_block_ids in get_requests ():
757764 block_info = self .request_block_infos .get (req_id )
758765 if block_info :
759- load_blocks , dump_blocks = self ._extract_blocks (
760- new_block_ids [0 ], block_info
761- )
766+ block_info .vllm_block_ids .extend (new_block_ids [0 ])
767+ load_blocks , dump_blocks = self ._extract_blocks (block_info )
762768 if load_blocks or dump_blocks :
763769 meta .requests .append (
764770 ReqMeta (
@@ -791,8 +797,8 @@ def request_finished(
791797 return False , None
792798
793799 def _extract_blocks (
794- self , vllm_block_ids : list [ int ], block_info : RequestBlockInfo
795- ) -> tuple [list [tuple [str , int ]], list [tuple [str , int ]]]:
800+ self , block_info : RequestBlockInfo
801+ ) -> tuple [list [tuple [str , torch . Tensor ]], list [tuple [str , torch . Tensor ]]]:
796802 """
797803 Extract blocks that need load and dump, block_info.start_position
798804 is the next block position to process, only return blocks that need
@@ -802,23 +808,31 @@ def _extract_blocks(
802808
803809 if start_pos >= len (block_info .block_operations ):
804810 return [], []
805-
806- process_length = min (
807- len (block_info .block_operations ) - start_pos , len (vllm_block_ids )
808- )
809- ops = block_info .block_operations [start_pos : start_pos + process_length ]
810- hashes = block_info .block_hashes [start_pos : start_pos + process_length ]
811- vllm_ids = vllm_block_ids [:process_length ]
812-
811+
813812 load_blocks = []
814813 dump_blocks = []
815- for op , hash , vllm_id in zip (ops , hashes , vllm_ids ):
816- if op == BlockOperation .LOAD :
817- load_blocks .append ((hash , vllm_id ))
818- elif op == BlockOperation .DUMP :
819- dump_blocks .append ((hash , vllm_id ))
820814
821- block_info .start_position += process_length
815+ block_mapping : dict [str , torch .Tensor ] = {}
816+ vllm_block_ids = block_info .vllm_block_ids
817+ for idx , vllm_block_id in enumerate (vllm_block_ids [start_pos * self .blocks_per_chunk :], start_pos * self .blocks_per_chunk ):
818+ chunk_idx = idx // self .blocks_per_chunk
819+ if chunk_idx >= len (block_info .block_hashes ):
820+ break
821+ if idx + self .blocks_per_chunk > len (vllm_block_ids ):
822+ break
823+ chunk_blocks = vllm_block_ids [idx : idx + self .blocks_per_chunk ]
824+ block_mapping [block_info .block_hashes [chunk_idx ]] = torch .tensor (chunk_blocks )
825+
826+ for i in range (start_pos , start_pos + len (block_mapping )):
827+ if block_info .block_operations [i ] == BlockOperation .LOAD :
828+ chunk_hash = block_info .block_hashes [i ]
829+ load_blocks .append ((chunk_hash , block_mapping [chunk_hash ]))
830+ elif block_info .block_operations [i ] == BlockOperation .DUMP :
831+ chunk_hash = block_info .block_hashes [i ]
832+ dump_blocks .append ((chunk_hash , block_mapping [chunk_hash ]))
833+
834+ block_info .start_position += len (block_mapping )
835+
822836 return load_blocks , dump_blocks
823837
824838 def get_block_ids_with_load_errors (self ) -> set [int ]:
0 commit comments