
Commit c47741c

NickLucche and heheda12345 authored and committed
Add support for MLA caches with different latent dim (vllm-project#25902)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
1 parent a731678 commit c47741c

File tree

2 files changed: +66 -42 lines

tests/v1/kv_connector/unit/test_nixl_connector.py
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 7 additions & 6 deletions
@@ -255,8 +255,9 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
         # gpu_model_runner. Here we just hardcode some dummy values.
-        self.slot_size_bytes = 4096
-        self.block_len = self.slot_size_bytes * self.block_size
+        slot_size_bytes = 4096
+        self.slot_size_per_layer = [slot_size_bytes]
+        self.block_len_per_layer = [slot_size_bytes * self.block_size]
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks

@@ -268,7 +269,7 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             # `self.kv_cache_layout` is only forced to HND when vllm engine
             # is started. We mock HND here.
@@ -485,8 +486,8 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
         worker = connector.connector_worker

         # Minimal local registration params used by add_remote_agent
-        worker.slot_size_bytes = 4096
-        worker.block_len = worker.slot_size_bytes * worker.block_size
+        worker.slot_size_per_layer = [4096]
+        worker.block_len_per_layer = [4096 * worker.block_size]
         worker.num_blocks = 1
         worker.dst_num_blocks[worker.engine_id] = worker.num_blocks

@@ -498,7 +499,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=worker.block_len,
+            block_lens=worker.block_len_per_layer,
             attn_backend_name=worker.backend_name,
             kv_cache_layout=mismatched_layout,
         )
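In short, the fake handshake and the layout-mismatch test now advertise a list of per-layer values instead of a single scalar. A minimal standalone sketch of the bookkeeping these tests stub out, with assumed toy values rather than anything from the repo:

# Hypothetical toy values, not vLLM code: one entry per KV cache layer.
block_size = 16                               # slots per block (assumed)
slot_size_per_layer = [4096, 2048]            # bytes per slot; may differ per layer under MLA
block_len_per_layer = [s * block_size for s in slot_size_per_layer]
assert block_len_per_layer == [65536, 32768]  # bytes per block, per layer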

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 59 additions & 36 deletions
@@ -84,7 +84,7 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    block_len: int
+    block_lens: list[int]
     attn_backend_name: str
     kv_cache_layout: str

@@ -766,6 +766,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         split_k_and_v = not (self.use_mla or self._use_pallas
                              or self._use_flashinfer)
         tensor_size_bytes = None
+        # Enable different block lengths for different layers when MLA is used.
+        self.block_len_per_layer = list[int]()
+        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [
                 cache_or_caches
@@ -783,10 +786,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     tensor_size_bytes = curr_tensor_size_bytes
                     self.num_blocks = cache.shape[0]

-                assert tensor_size_bytes == curr_tensor_size_bytes, \
-                    "All kv cache tensors must have the same size"
+                assert cache.shape[0] == self.num_blocks, \
+                    "All kv cache tensors must have the same number of blocks"
+
+                self.block_len_per_layer.append(curr_tensor_size_bytes //
+                                                self.num_blocks)
+                self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
+                                                self.block_size)
+
+                if not self.use_mla:
+                    # Different kv cache shape is not supported by HeteroTP
+                    assert tensor_size_bytes == curr_tensor_size_bytes, \
+                        "All kv cache tensors must have the same size"
                 caches_data.append(
-                    (base_addr, tensor_size_bytes, self.tp_rank, ""))
+                    (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
+
+        logger.debug("Different block lengths collected: %s",
+                     set(self.block_len_per_layer))
+        assert len(self.block_len_per_layer) == len(seen_base_addresses)
+        assert self.num_blocks != 0

         self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
         self.num_regions = len(caches_data)
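Put differently, the loop above now derives one (block_len, slot_size) pair per KV cache tensor rather than a single global value. A standalone sketch of that arithmetic under assumed toy shapes (the dims and dtype size below are illustrative, not taken from vLLM):

# Assumed toy values: two MLA-style layers whose latent dims differ.
num_blocks, block_size, dtype_bytes = 128, 16, 2
latent_dims = [576, 512]                     # per-layer latent dim (assumed)
block_len_per_layer, slot_size_per_layer = [], []
for dim in latent_dims:
    tensor_size_bytes = num_blocks * block_size * dim * dtype_bytes
    block_len_per_layer.append(tensor_size_bytes // num_blocks)        # bytes per block
    slot_size_per_layer.append(block_len_per_layer[-1] // block_size)  # bytes per slot
assert slot_size_per_layer == [1152, 1024]   # 576*2 and 512*2 bytes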
@@ -799,16 +817,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)

-        assert tensor_size_bytes is not None
-        assert self.num_blocks != 0
-        assert tensor_size_bytes % self.num_blocks == 0
-        self.block_len = tensor_size_bytes // self.num_blocks
-        self.slot_size_bytes = self.block_len // self.block_size
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
-            assert self.slot_size_bytes % 2 == 0
-            self.slot_size_bytes /= 2
+            for i in range(len(self.slot_size_per_layer)):
+                assert self.slot_size_per_layer[i] % 2 == 0
+                self.slot_size_per_layer[i] //= 2

             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
@@ -818,17 +832,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             # of 'virtual' regions here and halve `block_len` below.
             self.num_regions *= 2

-        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
-        for base_addr in seen_base_addresses:
+        for i, base_addr in enumerate(seen_base_addresses):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
             # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.tp_rank))
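So each (addr, len, device_id) descriptor is now built with its layer's own block stride. A toy version of that loop under assumed base addresses and per-layer lengths (not vLLM code):

# Assumed toy values: two layers with different block lengths.
base_addrs = [0x1000, 0x9000]            # per-layer base addresses (assumed)
block_len_per_layer = [0x100, 0x80]      # bytes per block, per layer (assumed)
num_blocks, tp_rank = 4, 0
blocks_data = []
for i, base_addr in enumerate(base_addrs):
    for block_id in range(num_blocks):
        addr = base_addr + block_id * block_len_per_layer[i]
        blocks_data.append((addr, block_len_per_layer[i], tp_rank))
assert blocks_data[1] == (0x1100, 0x100, 0)   # layer 0, block 1
assert blocks_data[5] == (0x9080, 0x80, 0)    # layer 1, block 1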
@@ -838,7 +852,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 # descs ordering. This is needed for selecting contiguous heads
                 # when split across TP ranks.
                 for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
+                    block_offset = block_id * self.block_len_per_layer[i]
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
@@ -878,7 +892,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
@@ -903,7 +917,7 @@ def add_remote_agent(self,
         The latter, assuming D.world_size > P.world_size, requires that two or
         more local TP worker share the xfer from a single TP worker.

-        Here's an example:
+        Here's an example (non-MLA case):

         rank_offset     p_remote_tp_rank
         (kv split no)
@@ -959,14 +973,20 @@ def add_remote_agent(self,
         total_num_kv_heads = self.model_config.get_total_num_kv_heads()
         is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1

+        remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or is_kv_replicated:
-            # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes)
-            assert self.block_len == nixl_agent_meta.block_len
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
+                "KV cache sizes must match between P and D when replicated"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0])
         else:
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes * tp_ratio)
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, \
+                    "All remote layers must have the same block size"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio)
             if self._use_flashinfer:
                 # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
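For the heterogeneous-TP branch, the remote block length advertised by the P worker is expected to be tp_ratio times the local per-rank block length, so the remote page size derived from it matches the local one. A toy consistency check under assumed sizes (not vLLM code):

# Assumed toy values: each D rank holds 1/tp_ratio of the P worker's heads.
block_size = 16                   # slots per block
tp_ratio = 2                      # P tp size / D tp size (assumed)
local_slot_size = 1024            # bytes per slot on this D rank (assumed)
remote_block_len = local_slot_size * tp_ratio * block_size
remote_block_size = remote_block_len // (local_slot_size * tp_ratio)
assert remote_block_size == block_size    # page sizes must match across P and D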
@@ -977,14 +997,14 @@ def add_remote_agent(self,
                 raise ValueError(
                     "Heterogeneous TP is not supported on XPU")

-            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )

         assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}")

         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -999,13 +1019,16 @@ def add_remote_agent(self,
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        kv_block_len = self.get_backend_aware_kv_block_len()
-        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
-            if not (self.use_mla or is_kv_replicated) else 0
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(
+            self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
-        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
+                if not (self.use_mla or is_kv_replicated) else 0
             for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_len
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
                 # For each block, grab the heads chunk belonging to rank_i
                 # of size remote_nheads // tp_ratio, which correspond to
                 # self.block_len == remote_block_len//tp_ratio bytes.
@@ -1016,9 +1039,9 @@ def add_remote_agent(self,
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                     blocks_data.append((v_addr, kv_block_len, remote_tp_rank))

         logger.debug(
@@ -1345,7 +1368,7 @@ def _get_block_descs_ids(self,
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()

-    def get_backend_aware_kv_block_len(self):
+    def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).

@@ -1356,9 +1379,9 @@ def get_backend_aware_kv_block_len(self):
         """
         if self._use_flashinfer:
             # For indexing only half (either just the K or V part).
-            block_len = self.block_len // 2
+            block_len = self.block_len_per_layer[layer_idx] // 2
         else:
-            block_len = self.block_len
+            block_len = self.block_len_per_layer[layer_idx]
         return block_len

     def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
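The helper now simply indexes the per-layer list; under FlashInfer, where K and V share one joint region per block, only half of a layer's block length is addressed at a time. A simplified standalone mirror of that selection, with assumed toy values rather than the actual method:

# Standalone sketch of the same per-layer lookup; the lengths below are assumed.
def backend_aware_kv_block_len(block_len_per_layer, layer_idx, use_flashinfer):
    # With FlashInfer, K and V live in one joint region, so address half of it.
    block_len = block_len_per_layer[layer_idx]
    return block_len // 2 if use_flashinfer else block_len

assert backend_aware_kv_block_len([32768, 16384], 1, use_flashinfer=True) == 8192
assert backend_aware_kv_block_len([32768, 16384], 0, use_flashinfer=False) == 32768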
