
Commit bf8bb7e

NickLucche authored and heheda12345 committed
[NIXL] Add support for MLA caches with different latent dim (#25902)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent eea2536 commit bf8bb7e

File tree

2 files changed: 66 additions, 42 deletions

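The core of the change: instead of a single `block_len`/`slot_size_bytes` scalar, the connector now tracks one block length and one slot size per registered KV cache region, so MLA caches whose layers have different latent dimensions (and therefore different per-block footprints) can still be registered and transferred. Below is a minimal standalone sketch of that bookkeeping, using made-up cache shapes and layer names rather than the connector's real tensors:

import torch

# Hypothetical per-layer KV cache tensors: two regular layers plus an MLA
# layer with a different latent dim, hence a different per-block footprint.
block_size = 16
kv_caches = {
    "layers.0": torch.zeros(4, block_size, 8, 128, dtype=torch.bfloat16),
    "layers.1": torch.zeros(4, block_size, 8, 128, dtype=torch.bfloat16),
    "layers.2": torch.zeros(4, block_size, 1, 576, dtype=torch.bfloat16),
}

# Mirrors the per-layer bookkeeping added in register_kv_caches(): one block
# length (bytes per block) and one slot size (bytes per token slot) per
# registered cache region, instead of a single scalar.
block_len_per_layer: list[int] = []
slot_size_per_layer: list[int] = []
num_blocks = None
for name, cache in kv_caches.items():
    size_bytes = cache.numel() * cache.element_size()
    if num_blocks is None:
        num_blocks = cache.shape[0]
    assert cache.shape[0] == num_blocks, "same number of blocks per layer"
    block_len_per_layer.append(size_bytes // num_blocks)
    slot_size_per_layer.append(block_len_per_layer[-1] // block_size)

print(block_len_per_layer)  # differing entries are now allowed for MLA

With per-layer lists, only the number of blocks must match across layers; the byte size of each block may differ.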

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 7 additions & 6 deletions
@@ -255,8 +255,9 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
         # gpu_model_runner. Here we just hardcode some dummy values.
-        self.slot_size_bytes = 4096
-        self.block_len = self.slot_size_bytes * self.block_size
+        slot_size_bytes = 4096
+        self.slot_size_per_layer = [slot_size_bytes]
+        self.block_len_per_layer = [slot_size_bytes * self.block_size]
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
@@ -268,7 +269,7 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             # `self.kv_cache_layout` is only forced to HND when vllm engine
             # is started. We mock HND here.
@@ -485,8 +486,8 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
         worker = connector.connector_worker
 
         # Minimal local registration params used by add_remote_agent
-        worker.slot_size_bytes = 4096
-        worker.block_len = worker.slot_size_bytes * worker.block_size
+        worker.slot_size_per_layer = [4096]
+        worker.block_len_per_layer = [4096 * worker.block_size]
         worker.num_blocks = 1
         worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
 
@@ -498,7 +499,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=worker.block_len,
+            block_lens=worker.block_len_per_layer,
             attn_backend_name=worker.backend_name,
             kv_cache_layout=mismatched_layout,
         )
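The test-side change is mechanical: the fake handshake and the layout-mismatch test populate `slot_size_per_layer`/`block_len_per_layer` as one-element lists, and the handshake metadata field becomes `block_lens`. A sketch of the payload shape, written as a plain dict so it does not depend on `NixlAgentMetadata` fields outside the hunk below; the backend name and agent bytes are placeholders:

# Dummy values mirroring what the fake handshake in the test hardcodes.
block_size = 16
slot_size_per_layer = [4096]               # bytes per token slot, per layer
block_len_per_layer = [4096 * block_size]  # bytes per block, per layer

# Fields shown in the NixlAgentMetadata hunk below, rendered as a plain dict.
handshake_payload = {
    "agent_metadata": b"fake-agent-bytes",   # placeholder
    "kv_caches_base_addr": [0],
    "num_blocks": 1,
    "block_lens": block_len_per_layer,       # was: "block_len": <scalar>
    "attn_backend_name": "FLASH_ATTN",       # placeholder backend name
    "kv_cache_layout": "HND",                # tests mock HND
}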

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 59 additions & 36 deletions
@@ -87,7 +87,7 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    block_len: int
+    block_lens: list[int]
     attn_backend_name: str
     kv_cache_layout: str
@@ -772,6 +772,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         split_k_and_v = not (self.use_mla or self._use_pallas
                              or self._use_flashinfer)
         tensor_size_bytes = None
+        # Enable different block lengths for different layers when MLA is used.
+        self.block_len_per_layer = list[int]()
+        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [
                 cache_or_caches
@@ -789,10 +792,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     tensor_size_bytes = curr_tensor_size_bytes
                     self.num_blocks = cache.shape[0]
 
-                assert tensor_size_bytes == curr_tensor_size_bytes, \
-                    "All kv cache tensors must have the same size"
+                assert cache.shape[0] == self.num_blocks, \
+                    "All kv cache tensors must have the same number of blocks"
+
+                self.block_len_per_layer.append(curr_tensor_size_bytes //
+                                                self.num_blocks)
+                self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
+                                                self.block_size)
+
+                if not self.use_mla:
+                    # Different kv cache shape is not supported by HeteroTP
+                    assert tensor_size_bytes == curr_tensor_size_bytes, \
+                        "All kv cache tensors must have the same size"
                 caches_data.append(
-                    (base_addr, tensor_size_bytes, self.tp_rank, ""))
+                    (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
+
+        logger.debug("Different block lengths collected: %s",
+                     set(self.block_len_per_layer))
+        assert len(self.block_len_per_layer) == len(seen_base_addresses)
+        assert self.num_blocks != 0
 
         self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
         self.num_regions = len(caches_data)
@@ -805,16 +823,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)
 
-        assert tensor_size_bytes is not None
-        assert self.num_blocks != 0
-        assert tensor_size_bytes % self.num_blocks == 0
-        self.block_len = tensor_size_bytes // self.num_blocks
-        self.slot_size_bytes = self.block_len // self.block_size
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
-            assert self.slot_size_bytes % 2 == 0
-            self.slot_size_bytes /= 2
+            for i in range(len(self.slot_size_per_layer)):
+                assert self.slot_size_per_layer[i] % 2 == 0
+                self.slot_size_per_layer[i] //= 2
 
             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
@@ -824,17 +838,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             # of 'virtual' regions here and halve `block_len` below.
             self.num_regions *= 2
 
-        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
-        for base_addr in seen_base_addresses:
+        for i, base_addr in enumerate(seen_base_addresses):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
            # NOTE With heter-TP, more blocks are prepared than what are
            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
            # could create fewer, but then _get_block_descs_ids needs to
            # select agent_meta.num_blocks instead of self.num_blocks for
            # local descr, and that makes handling regular flow less clean.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.tp_rank))
@@ -844,7 +858,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             # descs ordering. This is needed for selecting contiguous heads
             # when split across TP ranks.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # Register addresses for V cache (K registered first).
                 v_addr = addr + kv_block_len
@@ -884,7 +898,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
@@ -909,7 +923,7 @@ def add_remote_agent(self,
        The latter, assuming D.world_size > P.world_size, requires that two or
        more local TP worker share the xfer from a single TP worker.
 
-       Here's an example:
+       Here's an example (non-MLA case):
 
        rank_offset     p_remote_tp_rank
        (kv split no)
@@ -965,14 +979,20 @@ def add_remote_agent(self,
         total_num_kv_heads = self.model_config.get_total_num_kv_heads()
         is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
 
+        remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or is_kv_replicated:
-            # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes)
-            assert self.block_len == nixl_agent_meta.block_len
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
+                "KV cache sizes must match between P and D when replicated"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0])
         else:
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes * tp_ratio)
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, \
+                    "All remote layers must have the same block size"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio)
         if self._use_flashinfer:
             # With flashinfer, KV are sent in the same message.
             remote_block_size //= 2
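This hunk keeps the remote page-size derivation but sources it from `block_lens[0]`: for MLA or replicated KV, the local and remote per-layer lengths must match exactly, while in the sharded case every remote layer must advertise the same length and the remote block carries `tp_ratio` times the local heads. A standalone rework of just the size arithmetic (the equality checks against the local lengths are omitted), with made-up sizes:

def remote_block_size(remote_block_lens: list[int],
                      local_slot_size_per_layer: list[int],
                      *, use_mla: bool, is_kv_replicated: bool,
                      tp_ratio: int, use_flashinfer: bool) -> int:
    """Standalone rework of the size arithmetic in the hunk above."""
    remote_block_len = remote_block_lens[0]
    if use_mla or is_kv_replicated:
        # Replicated / MLA caches: same per-slot footprint on both sides.
        size = remote_block_len // local_slot_size_per_layer[0]
    else:
        # Non-MLA: every remote layer advertises the same block length, and
        # each remote block holds tp_ratio times the local heads.
        assert all(b == remote_block_len for b in remote_block_lens)
        size = remote_block_len // (local_slot_size_per_layer[0] * tp_ratio)
    if use_flashinfer:
        size //= 2  # joint KV is sent in one message
    return size

# E.g. 2048 bytes per local slot, 16-token blocks on a prefill worker that
# holds twice the heads (tp_ratio=2):
assert remote_block_size([2048 * 2 * 16], [2048], use_mla=False,
                         is_kv_replicated=False, tp_ratio=2,
                         use_flashinfer=False) == 16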
@@ -983,14 +1003,14 @@ def add_remote_agent(self,
                 raise ValueError(
                     "Heterogeneous TP is not supported on XPU")
 
-            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )
 
         assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}")
 
         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -1005,13 +1025,16 @@ def add_remote_agent(self,
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        kv_block_len = self.get_backend_aware_kv_block_len()
-        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
-            if not (self.use_mla or is_kv_replicated) else 0
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(
+            self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
-        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
+                if not (self.use_mla or is_kv_replicated) else 0
             for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_len
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
                 # For each block, grab the heads chunk belonging to rank_i
                 # of size remote_nheads // tp_ratio, which correspond to
                 # self.block_len == remote_block_len//tp_ratio bytes.
@@ -1022,9 +1045,9 @@ def add_remote_agent(self,
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                     blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
 
         logger.debug(
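With per-layer block lengths, `kv_block_len` and `rank_offset` move inside the loop over remote cache regions, since both depend on that layer's block length. A sketch of the resulting address arithmetic for one layer, with made-up values, showing which byte range a decode rank registers out of each remote block:

# Address arithmetic for one remote layer, mirroring the loop above.
tp_rank, tp_ratio = 3, 2          # D has twice the TP degree of P
num_blocks = 4
remote_base_addr = 0x1000         # hypothetical registered base address
remote_block_len = 65536          # bytes per block on the P worker (this layer)
kv_block_len = remote_block_len // tp_ratio  # local share of each remote block
use_mla = is_kv_replicated = False

rank_offset = (tp_rank % tp_ratio) * kv_block_len \
    if not (use_mla or is_kv_replicated) else 0

chunks = []
for block_id in range(num_blocks):
    addr = remote_base_addr + block_id * remote_block_len + rank_offset
    chunks.append((addr, kv_block_len))

# Rank 3 reads the second half of every remote block for this layer.
assert chunks[0] == (0x1000 + 32768, 32768)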
@@ -1351,7 +1374,7 @@ def _get_block_descs_ids(self,
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()
 
-    def get_backend_aware_kv_block_len(self):
+    def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).
@@ -1362,9 +1385,9 @@ def get_backend_aware_kv_block_len(self):
         """
         if self._use_flashinfer:
             # For indexing only half (either just the K or V part).
-            block_len = self.block_len // 2
+            block_len = self.block_len_per_layer[layer_idx] // 2
         else:
-            block_len = self.block_len
+            block_len = self.block_len_per_layer[layer_idx]
         return block_len
 
     def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
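Callers of `get_backend_aware_kv_block_len` now pass the layer index and get back that layer's block length, halved under FlashInfer because K and V are indexed separately. A free-function restatement of that behavior, with made-up per-layer lengths:

# Free-function sketch of the updated helper's behavior.
def backend_aware_kv_block_len(block_len_per_layer: list[int],
                               layer_idx: int,
                               use_flashinfer: bool) -> int:
    block_len = block_len_per_layer[layer_idx]
    return block_len // 2 if use_flashinfer else block_len

# Layers may now differ (e.g. an MLA layer with a smaller latent dim):
assert backend_aware_kv_block_len([32768, 18432], 1, False) == 18432
assert backend_aware_kv_block_len([32768, 18432], 0, True) == 16384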
