|
@@ -12,6 +12,7 @@
 import torch
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 if TYPE_CHECKING:
@@ -63,28 +64,43 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         )
         cache_config = self.vllm_config.cache_config
 
-        shape = list(kv_caches.values())[0].shape
+        # Create ordered list of (layer_name, tensor) tuples sorted by layer index
+        ordered_kv_caches = [
+            (layer_name, tensor)
+            for layer_name, tensor in sorted(
+                kv_caches.items(), key=lambda item: extract_layer_index(item[0])
+            )
+        ]
+
+        # Get first tensor to extract common properties
+        first_tensor = ordered_kv_caches[0][1]
+        shape = first_tensor.shape
 
+        # Validate all tensors have same shape
         if not all(t.shape == shape for t in kv_caches.values()):
             raise NotImplementedError(
                 "Hybrid models with different KV cache shapes are not supported yet."
             )
 
+        # Extract parameters
         # TODO: Assume the block dimension is within the first 2. This will break if you're doing something weird like having 1 or 2 device blocks.
         num_device_blocks = max(shape[0], shape[1])
         page_size = cache_config.block_size
-        tensors = list(kv_caches.values())
+        device_id = first_tensor.device.index
 
+        # Determine cache dtype
         if cache_config.cache_dtype == "auto":
             kv_cache_dtype = self.vllm_config.model_config.dtype
         else:
             kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
-        device_id = tensors[0].device.index
-
-        # extract necessary bits to construct a KvbmWorker
+        # Register with connector using ordered data
         self._connector.register_kv_caches(
-            num_device_blocks, page_size, device_id, kv_cache_dtype.itemsize, kv_caches
+            num_device_blocks,
+            page_size,
+            device_id,
+            kv_cache_dtype.itemsize,
+            ordered_kv_caches,
         )
 
     def bind_connector_metadata(self, data: bytes) -> None:
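For readers skimming the diff: the new extract_layer_index-based sort guarantees the connector receives KV cache tensors in model layer order rather than whatever order the kv_caches dict happens to iterate in. Below is a minimal, self-contained sketch of that ordering step; the layer names and the local _layer_index helper are illustrative stand-ins, not vLLM's actual naming scheme or implementation.

import re

# Illustrative stand-in for extract_layer_index: pull the first integer out of
# a dotted layer name such as "model.layers.12.self_attn.attn".
def _layer_index(layer_name: str) -> int:
    match = re.search(r"\.(\d+)\.", layer_name)
    if match is None:
        raise ValueError(f"no layer index found in {layer_name!r}")
    return int(match.group(1))

# Hypothetical kv_caches mapping whose insertion order does not match layer order.
kv_caches = {
    "model.layers.2.self_attn.attn": "tensor_2",
    "model.layers.0.self_attn.attn": "tensor_0",
    "model.layers.10.self_attn.attn": "tensor_10",
}

ordered_kv_caches = sorted(kv_caches.items(), key=lambda item: _layer_index(item[0]))
print([name for name, _ in ordered_kv_caches])
# ['model.layers.0.self_attn.attn', 'model.layers.2.self_attn.attn',
#  'model.layers.10.self_attn.attn']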
|