Skip to content

Commit 9a69f5f

Browse files
committed
introduce chunksize
1 parent c4eb386 commit 9a69f5f

File tree

1 file changed

+65
-51
lines changed

1 file changed

+65
-51
lines changed

ucm/integration/vllm/uc_connector.py

Lines changed: 65 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,17 @@ class RequestBlockInfo:
6666
block_operations: list[BlockOperation] = field(default_factory=list)
6767
# Next block position to process
6868
start_position: int = 0
69+
# vllm_block_ids in HBM
70+
vllm_block_ids: list[int] = field(default_factory=list)
6971

7072

7173
@dataclass
7274
class ReqMeta:
7375
request_id: str
7476
# list[(chunk_hash, vllm_block_ids_tensor)] — tensor holds the HBM block ids of one chunk
75-
load_blocks: list[tuple[str, int]] = field(default_factory=list)
77+
load_blocks: list[tuple[str, torch.Tensor]] = field(default_factory=list)
7678
# list[(chunk_hash, vllm_block_ids_tensor)] — tensor holds the HBM block ids of one chunk
77-
dump_blocks: list[tuple[str, int]] = field(default_factory=list)
79+
dump_blocks: list[tuple[str, torch.Tensor]] = field(default_factory=list)
7880
# Whether use load_async
7981
load_async: bool = False
8082

@@ -158,6 +160,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
158160
"use_layerwise"
159161
]
160162
)
163+
self.chunk_size = 256
164+
self.blocks_per_chunk = self.chunk_size // self.block_size
161165

162166
def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
163167
for layer_name in forward_context.no_compile_layers:
@@ -204,24 +208,24 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
204208
)
205209

206210
def get_tensor_and_offset_layerwise(
207-
self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
211+
self, vllm_block_ids_tensors: List[torch.Tensor], kv_layer: torch.Tensor, layer_name: str
208212
) -> tuple[List[torch.Tensor], List[int]]:
209213
k_tensors = []
210214
k_offsets = []
211215
v_tensors = []
212216
v_offsets = []
213217
layer_id = self._extract_layer_index(layer_name)
214218

215-
for blk_id in vllm_block_ids:
219+
for vllm_block_ids_tensor in vllm_block_ids_tensors:
216220
k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
217221
if self.is_mla:
218-
k_tensors.append(kv_layer[blk_id])
222+
k_tensors.append(kv_layer[vllm_block_ids_tensor])
219223
else:
220-
k_tensors.append(kv_layer[0][blk_id])
224+
k_tensors.append(kv_layer[0][vllm_block_ids_tensor])
221225
k_offsets.append(k_data_offset)
222226
if not self.is_mla:
223227
v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
224-
v_tensors.append(kv_layer[1][blk_id])
228+
v_tensors.append(kv_layer[1][vllm_block_ids_tensor])
225229
v_offsets.append(v_data_offset)
226230
return k_tensors + v_tensors, k_offsets + v_offsets
227231

@@ -266,14 +270,15 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
266270
continue
267271

268272
storage_block_ids = [block[0] for block in request.load_blocks]
269-
vllm_block_ids = [block[1] for block in request.load_blocks]
273+
vllm_block_ids_tensors = [block[1] for block in request.load_blocks]
270274
blocks_len = len(storage_block_ids)
271-
self._load_req_to_blocks.setdefault(request.request_id, set()).update(
272-
vllm_block_ids
273-
)
275+
for vllm_block_ids_tensor in vllm_block_ids_tensors:
276+
self._load_req_to_blocks.setdefault(request.request_id, set()).update(
277+
vllm_block_ids_tensor.tolist()
278+
)
274279
for layer_name, kv_layer in self.kv_caches.items():
275280
tensors, offsets = self.get_tensor_and_offset_layerwise(
276-
vllm_block_ids, kv_layer, layer_name
281+
vllm_block_ids_tensors, kv_layer, layer_name
277282
)
278283
k_task_id = self.connector.load(
279284
storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
@@ -397,10 +402,10 @@ def save_kv_layer(
397402
# Example: [("hash_123", 5), ("hash_456", 8), ("hash_789", 12)]
398403
# ["hash_123", "hash_456", "hash_789"]
399404
storage_block_ids = [block[0] for block in request.dump_blocks]
400-
vllm_block_ids = [block[1] for block in request.dump_blocks] # [5, 8, 12]
405+
vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]  # e.g. [tensor([5, 6]), tensor([8, 9])]
401406
blocks_len = len(storage_block_ids)
402407
tensors, offsets = self.get_tensor_and_offset_layerwise(
403-
vllm_block_ids, kv_layer, layer_name
408+
vllm_block_ids_tensors, kv_layer, layer_name
404409
)
405410

406411
if kv_layer[0].device.type == "npu":
@@ -457,11 +462,11 @@ def wait_for_tasks():
457462
continue
458463

459464
storage_block_ids = [block[0] for block in request.dump_blocks]
460-
vllm_block_ids = [block[1] for block in request.dump_blocks]
465+
vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
461466
blocks_len = len(storage_block_ids)
462467
for layer_name, kv_layer in self.kv_caches.items():
463468
tensors, offsets = self.get_tensor_and_offset_layerwise(
464-
vllm_block_ids, kv_layer, layer_name
469+
vllm_block_ids_tensors, kv_layer, layer_name
465470
)
466471
for block_id, offset, tensor in zip(
467472
storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
@@ -580,13 +585,13 @@ def hash_request_tokens(
580585
return ret
581586

582587
assert num_computed_tokens % self.block_size == 0
583-
block_hashes = hash_request_tokens(md5, self.block_size, request)
588+
block_hashes = hash_request_tokens(md5, self.chunk_size, request)
584589
if not block_hashes:
585590
logger.debug("Maybe tokens too short to load.")
586591
return 0, False
587592

588593
# Calculate start position (exclude blocks already in HBM)
589-
start_position = num_computed_tokens // self.block_size
594+
start_position = num_computed_tokens // self.chunk_size
590595

591596
block_operations = [BlockOperation.NONE] * len(block_hashes)
592597

@@ -655,12 +660,14 @@ def update_state_after_alloc(
655660
"""
656661
if request.request_id in self._need_load_reqs:
657662
local_block_ids = (
658-
# since we use unhashed blocks, so we don't need to reset start_position
659-
blocks.get_unhashed_block_ids()
663+
blocks.get_block_ids()
660664
if num_external_tokens > 0
661665
else []
662666
)
663-
self._need_load_reqs[request.request_id] = local_block_ids
667+
self._need_load_reqs[request.request_id] = local_block_ids[0]
668+
request_block_info = self.request_block_infos.get(request.request_id, None)
669+
if request_block_info:
670+
request_block_info.start_position = 0
664671
return
665672

666673
request_block_info = self.request_block_infos.get(request.request_id, None)
@@ -699,15 +706,16 @@ def build_connector_meta(
699706
for req_id, block_ids in self._need_load_reqs.items():
700707
block_info = self.request_block_infos.get(req_id)
701708
if block_info:
702-
load_blocks, dump_blocks = self._extract_blocks(block_ids, block_info)
703-
meta.requests.append(
704-
ReqMeta(
705-
request_id=req_id,
706-
load_blocks=load_blocks,
707-
dump_blocks=dump_blocks,
708-
load_async=True,
709+
block_info.vllm_block_ids = block_ids
710+
load_blocks, dump_blocks = self._extract_blocks(block_info)
711+
meta.requests.append(
712+
ReqMeta(
713+
request_id=req_id,
714+
load_blocks=load_blocks,
715+
dump_blocks=dump_blocks,
716+
load_async=True,
717+
)
709718
)
710-
)
711719
self._need_load_reqs.clear()
712720

713721
for new_req in scheduler_output.scheduled_new_reqs:
@@ -716,9 +724,8 @@ def build_connector_meta(
716724

717725
block_info = self.request_block_infos.get(req_id)
718726
if block_info:
719-
load_blocks, dump_blocks = self._extract_blocks(
720-
vllm_block_ids, block_info
721-
)
727+
block_info.vllm_block_ids = vllm_block_ids
728+
load_blocks, dump_blocks = self._extract_blocks(block_info)
722729
if load_blocks or dump_blocks:
723730
meta.requests.append(
724731
ReqMeta(
@@ -756,9 +763,8 @@ def get_requests():
756763
for req_id, new_block_ids in get_requests():
757764
block_info = self.request_block_infos.get(req_id)
758765
if block_info:
759-
load_blocks, dump_blocks = self._extract_blocks(
760-
new_block_ids[0], block_info
761-
)
766+
block_info.vllm_block_ids.extend(new_block_ids[0])
767+
load_blocks, dump_blocks = self._extract_blocks(block_info)
762768
if load_blocks or dump_blocks:
763769
meta.requests.append(
764770
ReqMeta(
@@ -791,8 +797,8 @@ def request_finished(
791797
return False, None
792798

793799
def _extract_blocks(
794-
self, vllm_block_ids: list[int], block_info: RequestBlockInfo
795-
) -> tuple[list[tuple[str, int]], list[tuple[str, int]]]:
800+
self, block_info: RequestBlockInfo
801+
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
796802
"""
797803
Extract blocks that need load and dump, block_info.start_position
798804
is the next block position to process, only return blocks that need
@@ -802,23 +808,31 @@ def _extract_blocks(
802808

803809
if start_pos >= len(block_info.block_operations):
804810
return [], []
805-
806-
process_length = min(
807-
len(block_info.block_operations) - start_pos, len(vllm_block_ids)
808-
)
809-
ops = block_info.block_operations[start_pos : start_pos + process_length]
810-
hashes = block_info.block_hashes[start_pos : start_pos + process_length]
811-
vllm_ids = vllm_block_ids[:process_length]
812-
811+
813812
load_blocks = []
814813
dump_blocks = []
815-
for op, hash, vllm_id in zip(ops, hashes, vllm_ids):
816-
if op == BlockOperation.LOAD:
817-
load_blocks.append((hash, vllm_id))
818-
elif op == BlockOperation.DUMP:
819-
dump_blocks.append((hash, vllm_id))
820814

821-
block_info.start_position += process_length
815+
block_mapping: dict[str, torch.Tensor] = {}
816+
vllm_block_ids = block_info.vllm_block_ids
817+
for idx, vllm_block_id in enumerate(vllm_block_ids[start_pos * self.blocks_per_chunk :], start_pos * self.blocks_per_chunk):
818+
chunk_idx = idx // self.blocks_per_chunk
819+
if chunk_idx >= len(block_info.block_hashes):
820+
break
821+
if idx + self.blocks_per_chunk > len(vllm_block_ids):
822+
break
823+
chunk_blocks = vllm_block_ids[idx : idx + self.blocks_per_chunk]
824+
block_mapping[block_info.block_hashes[chunk_idx]] = torch.tensor(chunk_blocks)
825+
826+
for i in range(start_pos, start_pos + len(block_mapping)):
827+
if block_info.block_operations[i] == BlockOperation.LOAD:
828+
chunk_hash = block_info.block_hashes[i]
829+
load_blocks.append((chunk_hash, block_mapping[chunk_hash]))
830+
elif block_info.block_operations[i] == BlockOperation.DUMP:
831+
chunk_hash = block_info.block_hashes[i]
832+
dump_blocks.append((chunk_hash, block_mapping[chunk_hash]))
833+
834+
block_info.start_position += len(block_mapping)
835+
822836
return load_blocks, dump_blocks
823837

824838
def get_block_ids_with_load_errors(self) -> set[int]:

0 commit comments

Comments
 (0)