diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs index 871f4ef2c9..c3da870ebf 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs @@ -96,25 +96,25 @@ impl KvConnectorWorker { page_size: usize, device_id: usize, dtype_width_bytes: usize, - kv_caches: HashMap>, + kv_caches: Vec<(String, Py)>, ) -> PyResult<()> { if self.kvbm_worker.get().is_some() { tracing::warn!("kvbm worker already registered"); return Err(to_pyerr(anyhow::anyhow!("kvbm worker already registered"))); } - // TODO: pass in the sorted (layer_name, tensor) such that the order of the list matches the order of layer execution in the model + // Process kv_caches in layer execution order (already sorted by layer index) + let mut vllm_tensors = Vec::new(); for (layer_name, torch_tensor) in kv_caches { let vllm_tensor = Arc::new(VllmTensor::new(torch_tensor).map_err(to_pyerr)?); tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}"); - self.kv_caches.insert(layer_name, vllm_tensor); - } - let vllm_tensors: Vec> = self - .kv_caches - .values() - .map(|tensor| tensor.clone() as Arc) - .collect(); + // Store for later lookup by name + self.kv_caches.insert(layer_name, vllm_tensor.clone()); + + // Build ordered tensor list for worker config + vllm_tensors.push(vllm_tensor as Arc); + } let config = KvbmWorkerConfig::builder() .drt(self.drt.clone()) diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_worker.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_worker.py index 21a29268e6..3892e74ea8 100644 --- a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_worker.py +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_worker.py @@ -12,6 +12,7 @@ import torch from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.model_executor.models.utils import extract_layer_index from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE if TYPE_CHECKING: @@ -63,28 +64,43 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) cache_config = self.vllm_config.cache_config - shape = list(kv_caches.values())[0].shape + # Create ordered list of (layer_name, tensor) tuples sorted by layer index + ordered_kv_caches = [ + (layer_name, tensor) + for layer_name, tensor in sorted( + kv_caches.items(), key=lambda item: extract_layer_index(item[0]) + ) + ] + + # Get first tensor to extract common properties + first_tensor = ordered_kv_caches[0][1] + shape = first_tensor.shape + # Validate all tensors have same shape if not all(t.shape == shape for t in kv_caches.values()): raise NotImplementedError( "Hybrid models with different KV cache shapes are not supported yet." ) + # Extract parameters # TODO: Assume the block dimension is within the first 2. This will break if you're doing something weird like having 1 or 2 device blocks. num_device_blocks = max(shape[0], shape[1]) page_size = cache_config.block_size - tensors = list(kv_caches.values()) + device_id = first_tensor.device.index + # Determine cache dtype if cache_config.cache_dtype == "auto": kv_cache_dtype = self.vllm_config.model_config.dtype else: kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - device_id = tensors[0].device.index - - # extract necessary bits to construct a KvbmWorker + # Register with connector using ordered data self._connector.register_kv_caches( - num_device_blocks, page_size, device_id, kv_cache_dtype.itemsize, kv_caches + num_device_blocks, + page_size, + device_id, + kv_cache_dtype.itemsize, + ordered_kv_caches, ) def bind_connector_metadata(self, data: bytes) -> None: