@@ -96,25 +96,25 @@ impl KvConnectorWorker {
         page_size: usize,
         device_id: usize,
         dtype_width_bytes: usize,
-        kv_caches: HashMap<String, Py<PyAny>>,
+        kv_caches: Vec<(String, Py<PyAny>)>,
     ) -> PyResult<()> {
         if self.kvbm_worker.get().is_some() {
             tracing::warn!("kvbm worker already registered");
             return Err(to_pyerr(anyhow::anyhow!("kvbm worker already registered")));
         }

-        // TODO: pass in the sorted (layer_name, tensor) such that the order of the list matches the order of layer execution in the model
+        // Process kv_caches in layer execution order (already sorted by layer index)
+        let mut vllm_tensors = Vec::new();
         for (layer_name, torch_tensor) in kv_caches {
             let vllm_tensor = Arc::new(VllmTensor::new(torch_tensor).map_err(to_pyerr)?);
             tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");
-            self.kv_caches.insert(layer_name, vllm_tensor);
-        }
-
-        let vllm_tensors: Vec<Arc<dyn TorchTensor>> = self
-            .kv_caches
-            .values()
-            .map(|tensor| tensor.clone() as Arc<dyn TorchTensor>)
-            .collect();
+            // Store for later lookup by name
+            self.kv_caches.insert(layer_name, vllm_tensor.clone());
+
+            // Build ordered tensor list for worker config
+            vllm_tensors.push(vllm_tensor as Arc<dyn TorchTensor>);
+        }

         let config = KvbmWorkerConfig::builder()
             .drt(self.drt.clone())
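The signature change in this hunk is worth making concrete: Rust's `std::collections::HashMap` iterates in an unspecified order (the default hasher is randomized per process), so a `HashMap<String, Py<PyAny>>` could not deliver the layer tensors in execution order no matter how the Python side sorted them before the call. A `Vec<(String, Py<PyAny>)>` preserves exactly the order the caller built. A minimal, self-contained sketch of the difference, using illustrative layer names rather than real tensors:

```rust
use std::collections::HashMap;

fn main() {
    // Layer names in execution order, as the Python side would send them.
    let layers = ["model.layers.0.attn", "model.layers.1.attn", "model.layers.2.attn"];

    // HashMap: iteration order is unspecified and unrelated to insertion order,
    // so a worker consuming map.values() may see the layers scrambled.
    let map: HashMap<&str, usize> = layers.iter().copied().zip(0..).collect();
    println!("HashMap order: {:?}", map.keys().collect::<Vec<_>>());

    // Vec of pairs: iteration order is exactly insertion order, which is the
    // guarantee the new Vec<(String, Py<PyAny>)> parameter relies on.
    let vec: Vec<(&str, usize)> = layers.iter().copied().zip(0..).collect();
    println!("Vec order:     {:?}", vec.iter().map(|(name, _)| name).collect::<Vec<_>>());
}
```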
@@ -12,6 +12,7 @@
 import torch
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

 if TYPE_CHECKING:

@@ -63,28 +64,43 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         )
         cache_config = self.vllm_config.cache_config

-        shape = list(kv_caches.values())[0].shape
+        # Create ordered list of (layer_name, tensor) tuples sorted by layer index
+        ordered_kv_caches = [
+            (layer_name, tensor)
+            for layer_name, tensor in sorted(
+                kv_caches.items(), key=lambda item: extract_layer_index(item[0])
+            )
+        ]
+
+        # Get first tensor to extract common properties
+        first_tensor = ordered_kv_caches[0][1]
+        shape = first_tensor.shape

+        # Validate all tensors have same shape
         if not all(t.shape == shape for t in kv_caches.values()):
             raise NotImplementedError(
                 "Hybrid models with different KV cache shapes are not supported yet."
             )

+        # Extract parameters
         # TODO: Assume the block dimension is within the first 2. This will break if you're doing something weird like having 1 or 2 device blocks.
         num_device_blocks = max(shape[0], shape[1])
         page_size = cache_config.block_size
-        tensors = list(kv_caches.values())
+        device_id = first_tensor.device.index

+        # Determine cache dtype
         if cache_config.cache_dtype == "auto":
             kv_cache_dtype = self.vllm_config.model_config.dtype
         else:
             kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

-        device_id = tensors[0].device.index
-
-        # extract necessary bits to construct a KvbmWorker
+        # Register with connector using ordered data
         self._connector.register_kv_caches(
-            num_device_blocks, page_size, device_id, kv_cache_dtype.itemsize, kv_caches
+            num_device_blocks,
+            page_size,
+            device_id,
+            kv_cache_dtype.itemsize,
+            ordered_kv_caches,
         )

     def bind_connector_metadata(self, data: bytes) -> None:
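To see what the `extract_layer_index` sort buys, note that lexicographic ordering of layer names puts `model.layers.10...` before `model.layers.2...`, so neither dict ordering nor a plain string sort recovers execution order. A rough sketch of the numeric sort, where `extract_layer_index` below is a simplified stand-in for vLLM's Python helper of the same name (which parses the integer component out of the dotted module path):

```rust
// Simplified stand-in for vLLM's extract_layer_index: pull the first integer
// component out of a dotted module path like "model.layers.12.self_attn.attn".
fn extract_layer_index(layer_name: &str) -> usize {
    layer_name
        .split('.')
        .find_map(|part| part.parse::<usize>().ok())
        .expect("layer name contains no index")
}

fn main() {
    // Lexicographic order would yield 0, 10, 2 — not execution order.
    let mut kv_caches = vec![
        ("model.layers.10.self_attn.attn", "tensor_10"),
        ("model.layers.0.self_attn.attn", "tensor_0"),
        ("model.layers.2.self_attn.attn", "tensor_2"),
    ];

    // Numeric sort by extracted layer index restores execution order: 0, 2, 10.
    kv_caches.sort_by_key(|(name, _)| extract_layer_index(name));

    for (name, tensor) in &kv_caches {
        println!("{name} -> {tensor}");
    }
}
```

With the list sorted this way on the Python side and carried through a `Vec` on the Rust side, the worker config receives the tensors in the same order the model executes its layers.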