@@ -21,7 +21,28 @@ use dynamo_llm::block_manager::storage::torch::TorchTensor;
2121use dynamo_runtime:: utils:: task:: CriticalTaskExecutionHandle ;
2222use dynamo_runtime:: DistributedRuntime ;
2323
24- #[ pyclass]
24+ pub trait Worker : Send + Sync {
25+ fn register_kv_caches (
26+ & mut self ,
27+ num_device_blocks : usize ,
28+ page_size : usize ,
29+ device_id : usize ,
30+ dtype_width_bytes : usize ,
31+ kv_caches : Vec < ( String , Py < PyAny > ) > ,
32+ ) -> PyResult < ( ) > ;
33+
34+ fn bind_connector_metadata ( & mut self , metadata : Vec < u8 > ) -> PyResult < ( ) > ;
35+
36+ fn clear_connector_metadata ( & mut self ) ;
37+
38+ fn save_kv_layer ( & mut self , layer_name : String , kv_layer : Py < PyAny > ) -> PyResult < ( ) > ;
39+
40+ fn get_finished (
41+ & mut self ,
42+ finished_requests : HashSet < String > ,
43+ ) -> ( HashSet < String > , HashSet < String > ) ;
44+ }
45+
2546pub struct KvConnectorWorker {
2647 drt : DistributedRuntime ,
2748 kvbm_worker : OnceLock < KvbmWorker > ,
@@ -45,9 +66,7 @@ pub struct KvConnectorWorker {
4566 layers_complete : usize ,
4667}
4768
48- #[ pymethods]
4969impl KvConnectorWorker {
50- #[ new]
5170 fn new ( py_drt : PyDistributedRuntime , vllm_worker_id : String ) -> PyResult < Self > {
5271 let drt = py_drt. inner . clone ( ) ;
5372 let runtime = drt. runtime ( ) . primary ( ) ;
@@ -85,12 +104,14 @@ impl KvConnectorWorker {
85104 layers_complete : 0 ,
86105 } )
87106 }
107+ }
88108
109+ impl Worker for KvConnectorWorker {
89110 /// Registers the KV caches with the KVBM worker.
90111 ///
91112 /// The Dynamo KVBM worker is lazily initialized when the first KV cache is registered.
92113 /// This process establishes a connection between all KVBM workers and the leader.
93- pub fn register_kv_caches (
114+ fn register_kv_caches (
94115 & mut self ,
95116 num_device_blocks : usize ,
96117 page_size : usize ,
@@ -148,7 +169,7 @@ impl KvConnectorWorker {
148169 /// Loads the metadata from the leader.
149170 /// This action translates the metadata into a set of actions that the worker will perform.
150171 /// All actions must be assigned to a slot before [`KvConnectorWorker::clear_connector_metadata`] is called.
151- pub fn bind_connector_metadata ( & mut self , metadata : Vec < u8 > ) -> PyResult < ( ) > {
172+ fn bind_connector_metadata ( & mut self , metadata : Vec < u8 > ) -> PyResult < ( ) > {
152173 debug_assert ! ( !self . bound, "connector metadata already bound" ) ;
153174 let metadata: ConnectorMetadata = serde_json:: from_slice ( & metadata) . map_err ( to_pyerr) ?;
154175 self . bound = true ;
@@ -214,7 +235,7 @@ impl KvConnectorWorker {
214235 }
215236
216237 /// Clears the connector metadata and marks the iteration as complete.
217- pub fn clear_connector_metadata ( & mut self ) {
238+ fn clear_connector_metadata ( & mut self ) {
218239 tracing:: debug!( iteration = self . iteration, "clearing connector metadata" ) ;
219240 debug_assert ! ( self . bound, "connector metadata not bound" ) ;
220241 self . bound = false ;
@@ -227,7 +248,7 @@ impl KvConnectorWorker {
227248
228249 /// Trigger layer-wise completion signals.
229250 /// Trigger block-wise completion signals after last layer.
230- pub fn save_kv_layer ( & mut self , _layer_name : String , _kv_layer : Py < PyAny > ) -> PyResult < ( ) > {
251+ fn save_kv_layer ( & mut self , layer_name : String , kv_layer : Py < PyAny > ) -> PyResult < ( ) > {
231252 self . layers_complete += 1 ;
232253 if self . layers_complete == self . kv_caches . len ( ) {
233254 let offloading_operations = std:: mem:: take ( & mut self . offloading_operations ) ;
@@ -238,7 +259,7 @@ impl KvConnectorWorker {
238259 Ok ( ( ) )
239260 }
240261
241- pub fn get_finished (
262+ fn get_finished (
242263 & mut self ,
243264 finished_requests : HashSet < String > ,
244265 ) -> ( HashSet < String > , HashSet < String > ) {
@@ -321,3 +342,48 @@ impl KvConnectorWorker {
321342 ( is_finished_offloading, is_finished_onboarding)
322343 }
323344}
345+
346+ #[ pyclass]
347+ pub struct PyKvConnectorWorker {
348+ connector_worker : Box < dyn Worker > ,
349+ }
350+
351+ #[ pymethods]
352+ impl PyKvConnectorWorker {
353+ #[ new]
354+ #[ pyo3( signature = ( py_drt, vllm_worker_id) ) ]
355+ pub fn new ( py_drt : PyDistributedRuntime , vllm_worker_id : String ) -> PyResult < Self > {
356+ let connector_worker: Box < dyn Worker > = Box :: new ( KvConnectorWorker :: new ( py_drt, vllm_worker_id) ?) ;
357+ Ok ( Self { connector_worker } )
358+ }
359+
360+ pub fn register_kv_caches (
361+ & mut self ,
362+ num_device_blocks : usize ,
363+ page_size : usize ,
364+ device_id : usize ,
365+ dtype_width_bytes : usize ,
366+ kv_caches : Vec < ( String , Py < PyAny > ) > ,
367+ ) -> PyResult < ( ) > {
368+ self . connector_worker . register_kv_caches ( num_device_blocks, page_size, device_id, dtype_width_bytes, kv_caches)
369+ }
370+
371+ pub fn bind_connector_metadata ( & mut self , metadata : Vec < u8 > ) -> PyResult < ( ) > {
372+ self . connector_worker . bind_connector_metadata ( metadata)
373+ }
374+
375+ pub fn clear_connector_metadata ( & mut self ) {
376+ self . connector_worker . clear_connector_metadata ( )
377+ }
378+
379+ pub fn save_kv_layer ( & mut self , layer_name : String , kv_layer : Py < PyAny > ) -> PyResult < ( ) > {
380+ self . connector_worker . save_kv_layer ( layer_name, kv_layer)
381+ }
382+
383+ pub fn get_finished (
384+ & mut self ,
385+ finished_requests : HashSet < String > ,
386+ ) -> ( HashSet < String > , HashSet < String > ) {
387+ self . connector_worker . get_finished ( finished_requests)
388+ }
389+ }
0 commit comments