Skip to content

Commit b1a2677

Browse files
committed
refactor: separate python and rust in vllm connector worker, similar to leader
1 parent 8c4beb7 commit b1a2677

File tree

3 files changed

+76
-10
lines changed

3 files changed

+76
-10
lines changed

lib/bindings/python/rust/llm/block_manager/vllm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> {
4747
m.add_class::<BlockStates>()?;
4848
m.add_class::<SlotUpdate>()?;
4949

50-
m.add_class::<connector::worker::KvConnectorWorker>()?;
50+
m.add_class::<connector::worker::PyKvConnectorWorker>()?;
5151
m.add_class::<connector::leader::PyKvConnectorLeader>()?;
5252
m.add_class::<connector::SchedulerOutput>()?;
5353
Ok(())

lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,28 @@ use dynamo_llm::block_manager::storage::torch::TorchTensor;
2121
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
2222
use dynamo_runtime::DistributedRuntime;
2323

24-
#[pyclass]
24+
pub trait Worker: Send + Sync {
25+
fn register_kv_caches(
26+
&mut self,
27+
num_device_blocks: usize,
28+
page_size: usize,
29+
device_id: usize,
30+
dtype_width_bytes: usize,
31+
kv_caches: Vec<(String, Py<PyAny>)>,
32+
) -> PyResult<()>;
33+
34+
fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> PyResult<()>;
35+
36+
fn clear_connector_metadata(&mut self);
37+
38+
fn save_kv_layer(&mut self, layer_name: String, kv_layer: Py<PyAny>) -> PyResult<()>;
39+
40+
fn get_finished(
41+
&mut self,
42+
finished_requests: HashSet<String>,
43+
) -> (HashSet<String>, HashSet<String>);
44+
}
45+
2546
pub struct KvConnectorWorker {
2647
drt: DistributedRuntime,
2748
kvbm_worker: OnceLock<KvbmWorker>,
@@ -45,9 +66,7 @@ pub struct KvConnectorWorker {
4566
layers_complete: usize,
4667
}
4768

48-
#[pymethods]
4969
impl KvConnectorWorker {
50-
#[new]
5170
fn new(py_drt: PyDistributedRuntime, vllm_worker_id: String) -> PyResult<Self> {
5271
let drt = py_drt.inner.clone();
5372
let runtime = drt.runtime().primary();
@@ -85,12 +104,14 @@ impl KvConnectorWorker {
85104
layers_complete: 0,
86105
})
87106
}
107+
}
88108

109+
impl Worker for KvConnectorWorker {
89110
/// Registers the KV caches with the KVBM worker.
90111
///
91112
/// The Dynamo KVBM worker is lazily initialized when the first KV cache is registered.
92113
/// This process establishes a connection between all KVBM workers and the leader.
93-
pub fn register_kv_caches(
114+
fn register_kv_caches(
94115
&mut self,
95116
num_device_blocks: usize,
96117
page_size: usize,
@@ -148,7 +169,7 @@ impl KvConnectorWorker {
148169
/// Loads the metadata from the leader.
149170
/// This action translates the metadata into a set of actions that the worker will perform.
150171
/// All actions much be assigned to a slot before [`KvConnectorWorker::clear_metadata`] is called.
151-
pub fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> PyResult<()> {
172+
fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> PyResult<()> {
152173
debug_assert!(!self.bound, "connector metadata already bound");
153174
let metadata: ConnectorMetadata = serde_json::from_slice(&metadata).map_err(to_pyerr)?;
154175
self.bound = true;
@@ -214,7 +235,7 @@ impl KvConnectorWorker {
214235
}
215236

216237
/// Clears the connector metadata and marks the iteration as complete.
217-
pub fn clear_connector_metadata(&mut self) {
238+
fn clear_connector_metadata(&mut self) {
218239
tracing::debug!(iteration = self.iteration, "clearing connector metadata");
219240
debug_assert!(self.bound, "connector metadata not bound");
220241
self.bound = false;
@@ -227,7 +248,7 @@ impl KvConnectorWorker {
227248

228249
/// Trigger layer-wise completion signals.
229250
/// Trigger block-wise completion signals afer last layer.
230-
pub fn save_kv_layer(&mut self, _layer_name: String, _kv_layer: Py<PyAny>) -> PyResult<()> {
251+
fn save_kv_layer(&mut self, layer_name: String, kv_layer: Py<PyAny>) -> PyResult<()> {
231252
self.layers_complete += 1;
232253
if self.layers_complete == self.kv_caches.len() {
233254
let offloading_operations = std::mem::take(&mut self.offloading_operations);
@@ -238,7 +259,7 @@ impl KvConnectorWorker {
238259
Ok(())
239260
}
240261

241-
pub fn get_finished(
262+
fn get_finished(
242263
&mut self,
243264
finished_requests: HashSet<String>,
244265
) -> (HashSet<String>, HashSet<String>) {
@@ -321,3 +342,48 @@ impl KvConnectorWorker {
321342
(is_finished_offloading, is_finished_onboarding)
322343
}
323344
}
345+
346+
#[pyclass]
347+
pub struct PyKvConnectorWorker {
348+
connector_worker: Box<dyn Worker>,
349+
}
350+
351+
#[pymethods]
352+
impl PyKvConnectorWorker {
353+
#[new]
354+
#[pyo3(signature = (py_drt, vllm_worker_id))]
355+
pub fn new(py_drt: PyDistributedRuntime, vllm_worker_id: String) -> PyResult<Self> {
356+
let connector_worker: Box<dyn Worker> = Box::new(KvConnectorWorker::new(py_drt, vllm_worker_id)?);
357+
Ok(Self { connector_worker })
358+
}
359+
360+
pub fn register_kv_caches(
361+
&mut self,
362+
num_device_blocks: usize,
363+
page_size: usize,
364+
device_id: usize,
365+
dtype_width_bytes: usize,
366+
kv_caches: Vec<(String, Py<PyAny>)>,
367+
) -> PyResult<()> {
368+
self.connector_worker.register_kv_caches(num_device_blocks, page_size, device_id, dtype_width_bytes, kv_caches)
369+
}
370+
371+
pub fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> PyResult<()> {
372+
self.connector_worker.bind_connector_metadata(metadata)
373+
}
374+
375+
pub fn clear_connector_metadata(&mut self) {
376+
self.connector_worker.clear_connector_metadata()
377+
}
378+
379+
pub fn save_kv_layer(&mut self, layer_name: String, kv_layer: Py<PyAny>) -> PyResult<()> {
380+
self.connector_worker.save_kv_layer(layer_name, kv_layer)
381+
}
382+
383+
pub fn get_finished(
384+
&mut self,
385+
finished_requests: HashSet<String>,
386+
) -> (HashSet<String>, HashSet<String>) {
387+
self.connector_worker.get_finished(finished_requests)
388+
}
389+
}

lib/bindings/python/src/dynamo/llm/vllm_integration/rust.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
BlockStates = getattr(_vllm_integration, "BlockStates")
1717
SlotUpdate = getattr(_vllm_integration, "SlotUpdate")
1818

19-
KvConnectorWorker = getattr(_vllm_integration, "KvConnectorWorker")
19+
KvConnectorWorker = getattr(_vllm_integration, "PyKvConnectorWorker")
2020
KvConnectorLeader = getattr(_vllm_integration, "PyKvConnectorLeader")
2121
SchedulerOutput = getattr(_vllm_integration, "SchedulerOutput")
2222

0 commit comments

Comments
 (0)