Commit d52b42e

Enforced kv layer ordering (#2312)
1 parent d5c452f commit d52b42e

2 files changed: +31 additions, −15 deletions

lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs

Lines changed: 9 additions & 9 deletions
@@ -96,25 +96,25 @@ impl KvConnectorWorker {
         page_size: usize,
         device_id: usize,
         dtype_width_bytes: usize,
-        kv_caches: HashMap<String, Py<PyAny>>,
+        kv_caches: Vec<(String, Py<PyAny>)>,
     ) -> PyResult<()> {
         if self.kvbm_worker.get().is_some() {
             tracing::warn!("kvbm worker already registered");
             return Err(to_pyerr(anyhow::anyhow!("kvbm worker already registered")));
         }
 
-        // TODO: pass in the sorted (layer_name, tensor) such that the order of the list matches the order of layer execution in the model
+        // Process kv_caches in layer execution order (already sorted by layer index)
+        let mut vllm_tensors = Vec::new();
         for (layer_name, torch_tensor) in kv_caches {
             let vllm_tensor = Arc::new(VllmTensor::new(torch_tensor).map_err(to_pyerr)?);
             tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");
-            self.kv_caches.insert(layer_name, vllm_tensor);
-        }
 
-        let vllm_tensors: Vec<Arc<dyn TorchTensor>> = self
-            .kv_caches
-            .values()
-            .map(|tensor| tensor.clone() as Arc<dyn TorchTensor>)
-            .collect();
+            // Store for later lookup by name
+            self.kv_caches.insert(layer_name, vllm_tensor.clone());
+
+            // Build ordered tensor list for worker config
+            vllm_tensors.push(vllm_tensor as Arc<dyn TorchTensor>);
+        }
 
         let config = KvbmWorkerConfig::builder()
             .drt(self.drt.clone())
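The signature change above swaps the HashMap (whose iteration order is arbitrary) for a Vec of (layer_name, tensor) pairs, so the tensor list handed to KvbmWorkerConfig keeps the caller's layer-execution order while the by-name map is still available for lookups. Below is a rough Python analogue of what the loop now does in a single pass; the OrderedRegistry class and register_layers method are hypothetical illustrations, not part of the Dynamo or vLLM API.

# Rough Python analogue of the Rust loop above: consume (layer_name, tensor)
# pairs in the order they are given, filling both a by-name map (for later
# lookups) and an ordered list (for the worker config) in one pass.
# OrderedRegistry and register_layers are hypothetical names for illustration.
from typing import Any


class OrderedRegistry:
    def __init__(self) -> None:
        self.by_name: dict[str, Any] = {}  # lookup by layer name
        self.ordered: list[Any] = []       # preserves layer execution order

    def register_layers(self, kv_caches: list[tuple[str, Any]]) -> None:
        for layer_name, tensor in kv_caches:
            # Store for later lookup by name
            self.by_name[layer_name] = tensor
            # Build the ordered list in the same pass, so position i == layer i
            self.ordered.append(tensor)


# Usage: the caller supplies pairs already sorted by layer index,
# and the registry never reorders them.
reg = OrderedRegistry()
reg.register_layers([("model.layers.0.self_attn.attn", "t0"),
                     ("model.layers.1.self_attn.attn", "t1")])
assert reg.ordered == ["t0", "t1"]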

lib/bindings/python/src/dynamo/llm/vllm_integration/connector_worker.py

Lines changed: 22 additions & 6 deletions
@@ -12,6 +12,7 @@
 import torch
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 if TYPE_CHECKING:
@@ -63,28 +64,43 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         )
         cache_config = self.vllm_config.cache_config
 
-        shape = list(kv_caches.values())[0].shape
+        # Create ordered list of (layer_name, tensor) tuples sorted by layer index
+        ordered_kv_caches = [
+            (layer_name, tensor)
+            for layer_name, tensor in sorted(
+                kv_caches.items(), key=lambda item: extract_layer_index(item[0])
+            )
+        ]
+
+        # Get first tensor to extract common properties
+        first_tensor = ordered_kv_caches[0][1]
+        shape = first_tensor.shape
 
+        # Validate all tensors have same shape
         if not all(t.shape == shape for t in kv_caches.values()):
             raise NotImplementedError(
                 "Hybrid models with different KV cache shapes are not supported yet."
             )
 
+        # Extract parameters
         # TODO: Assume the block dimension is within the first 2. This will break if you're doing something weird like having 1 or 2 device blocks.
         num_device_blocks = max(shape[0], shape[1])
         page_size = cache_config.block_size
-        tensors = list(kv_caches.values())
+        device_id = first_tensor.device.index
 
+        # Determine cache dtype
         if cache_config.cache_dtype == "auto":
            kv_cache_dtype = self.vllm_config.model_config.dtype
         else:
             kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
-        device_id = tensors[0].device.index
-
-        # extract necessary bits to construct a KvbmWorker
+        # Register with connector using ordered data
         self._connector.register_kv_caches(
-            num_device_blocks, page_size, device_id, kv_cache_dtype.itemsize, kv_caches
+            num_device_blocks,
+            page_size,
+            device_id,
+            kv_cache_dtype.itemsize,
+            ordered_kv_caches,
         )
 
     def bind_connector_metadata(self, data: bytes) -> None:
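The key step on the Python side is sorting kv_caches by numeric layer index before registration: sorting the names themselves would be lexicographic and would place layer 10 before layer 2. Below is a self-contained sketch of that ordering step; the regex-based extract_layer_index here is a simplified stand-in (an assumption) for vllm.model_executor.models.utils.extract_layer_index, and the layer names follow the common "model.layers.<i>…" pattern.

# Self-contained sketch of the ordering step, assuming layer names of the form
# "model.layers.<i>.self_attn.attn". The extract_layer_index below is a
# simplified regex stand-in for vLLM's helper of the same name.
import re


def extract_layer_index(layer_name: str) -> int:
    # Pull the first integer path component out of the dotted layer name.
    match = re.search(r"\.(\d+)\.", layer_name)
    if match is None:
        raise ValueError(f"no layer index in {layer_name!r}")
    return int(match.group(1))


kv_caches = {
    "model.layers.10.self_attn.attn": "tensor_10",
    "model.layers.2.self_attn.attn": "tensor_2",
    "model.layers.0.self_attn.attn": "tensor_0",
}

# Sorting by name would be lexicographic (layer 10 before layer 2);
# sorting by the numeric layer index restores execution order.
ordered_kv_caches = sorted(
    kv_caches.items(), key=lambda item: extract_layer_index(item[0])
)
assert [name for name, _ in ordered_kv_caches] == [
    "model.layers.0.self_attn.attn",
    "model.layers.2.self_attn.attn",
    "model.layers.10.self_attn.attn",
]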
