Skip to content

Commit dbbcc31

Browse files
committed
refactor: separate pyo3 from rust in vllm connector worker leader
1 parent 8c4beb7 commit dbbcc31

File tree

5 files changed

+136
-64
lines changed

5 files changed

+136
-64
lines changed

lib/bindings/python/rust/llm/block_manager/vllm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> {
4747
m.add_class::<BlockStates>()?;
4848
m.add_class::<SlotUpdate>()?;
4949

50-
m.add_class::<connector::worker::KvConnectorWorker>()?;
50+
m.add_class::<connector::worker::PyKvConnectorWorker>()?;
5151
m.add_class::<connector::leader::PyKvConnectorLeader>()?;
5252
m.add_class::<connector::SchedulerOutput>()?;
5353
Ok(())

lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ use std::{
3030
};
3131
use tokio::sync::mpsc;
3232
use tokio;
33-
use pyo3_async_runtimes;
3433

3534
type VllmLocality = Logical<DistributedLeaderWorkerResources>;
3635

@@ -41,6 +40,7 @@ impl From<SlotError> for PyErr {
4140
}
4241
use dynamo_llm::recorder::Recorder;
4342
use tokio_util::sync::CancellationToken;
43+
use anyhow;
4444

4545

4646
pub trait Leader: Send + Sync + std::fmt::Debug {
@@ -49,29 +49,29 @@ pub trait Leader: Send + Sync + std::fmt::Debug {
4949
request_id: String,
5050
request_num_tokens: usize,
5151
num_computed_tokens: usize,
52-
) -> PyResult<(usize, bool)>;
52+
) -> anyhow::Result<(usize, bool)>;
5353

5454
fn update_state_after_alloc(
5555
&mut self,
5656
request_id: String,
5757
block_ids: Vec<BlockId>,
5858
num_external_tokens: usize,
59-
) -> PyResult<()>;
59+
) -> anyhow::Result<()>;
6060

6161
fn build_connector_metadata(
6262
&mut self,
6363
scheduler_output: SchedulerOutput,
64-
) -> PyResult<Vec<u8>>;
64+
) -> anyhow::Result<Vec<u8>>;
6565

6666
fn request_finished(
6767
&mut self,
6868
request_id: String,
6969
block_ids: Vec<BlockId>,
70-
) -> PyResult<bool>;
70+
) -> anyhow::Result<bool>;
7171

7272
fn has_slot(&self, request_id: String) -> bool;
7373

74-
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()>;
74+
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()>;
7575
}
7676

7777
#[derive(Debug)]
@@ -128,16 +128,16 @@ impl Leader for KvConnectorLeader {
128128
request_id: String,
129129
request_num_tokens: usize,
130130
num_computed_tokens: usize,
131-
) -> PyResult<(usize, bool)> {
131+
) -> anyhow::Result<(usize, bool)> {
132132
tracing::debug!(
133133
"request_num_tokens: {request_num_tokens}; num_computed_tokens: {num_computed_tokens}"
134134
);
135135

136136
// the number of device matched tokens should be less than or equal to the number of tokens in the request
137137
debug_assert!(num_computed_tokens % self.block_size == 0);
138138

139-
let shared_slot = self.slot_manager.get_slot(&request_id).map_err(to_pyerr)?;
140-
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
139+
let shared_slot = self.slot_manager.get_slot(&request_id)?;
140+
let mut slot = shared_slot.lock().map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
141141

142142
// vllm is telling us that the tokens have been computed, since we do not have insight into the device pool
143143
// we accept this and advance the computed position
@@ -179,16 +179,16 @@ impl Leader for KvConnectorLeader {
179179
request_id: String,
180180
block_ids: Vec<BlockId>,
181181
num_external_tokens: usize,
182-
) -> PyResult<()> {
182+
) -> anyhow::Result<()> {
183183
tracing::debug!(
184184
request_id,
185185
"num_device_blocks: {}; num_external_tokens: {}",
186186
block_ids.len(),
187187
num_external_tokens
188188
);
189189

190-
let shared_slot = self.slot_manager.get_slot(&request_id).map_err(to_pyerr)?;
191-
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
190+
let shared_slot = self.slot_manager.get_slot(&request_id)?;
191+
let mut slot = shared_slot.lock().map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
192192

193193
slot.append_mutable_device_blocks(&block_ids)?;
194194

@@ -211,7 +211,7 @@ impl Leader for KvConnectorLeader {
211211
fn build_connector_metadata(
212212
&mut self,
213213
scheduler_output: SchedulerOutput,
214-
) -> PyResult<Vec<u8>> {
214+
) -> anyhow::Result<Vec<u8>> {
215215
// the iteration counter is used to track the number of times we have built the connector metadata
216216
// all connetor operations have the iteration counter at which they were issued.
217217
// this allows operations to be lazily enqueued to the transfer engine
@@ -234,8 +234,8 @@ impl Leader for KvConnectorLeader {
234234
// This is kind of a nice abstraction as it keeps the events simplier; however, we now create the request-slot
235235
// once for onboarding (this loop), then again for prefill/decode (new_requests loop).
236236
for request_id in onboarding_slots.iter() {
237-
let shared_slot = self.slot_manager.get_slot(request_id).map_err(to_pyerr)?;
238-
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
237+
let shared_slot = self.slot_manager.get_slot(request_id)?;
238+
let mut slot = shared_slot.lock().map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
239239

240240
md.create_slot(request_id.clone());
241241

@@ -256,8 +256,8 @@ impl Leader for KvConnectorLeader {
256256
"request_id {request_id} not found in inflight_requests: "
257257
);
258258

259-
let shared_slot = self.slot_manager.get_slot(request_id).map_err(to_pyerr)?;
260-
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
259+
let shared_slot = self.slot_manager.get_slot(request_id)?;
260+
let mut slot = shared_slot.lock().map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
261261

262262
// inform the worker that a new request-slot should be created
263263
md.create_slot(new_req.request_id.clone());
@@ -297,8 +297,8 @@ impl Leader for KvConnectorLeader {
297297
"request_id {request_id} not found in inflight_requests: "
298298
);
299299

300-
let shared_slot = self.slot_manager.get_slot(request_id).map_err(to_pyerr)?;
301-
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
300+
let shared_slot = self.slot_manager.get_slot(request_id)?;
301+
let mut slot = shared_slot.lock().map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
302302

303303
let scheduled_tokens = *scheduler_output
304304
.num_scheduled_tokens
@@ -323,16 +323,16 @@ impl Leader for KvConnectorLeader {
323323
}
324324

325325
tracing::debug!("scheduler_output: {scheduler_output:#?}");
326-
serde_json::to_vec(&md).map_err(to_pyerr)
326+
serde_json::to_vec(&md).map_err(|e| anyhow::anyhow!("Failed to serialize connector metadata: {}", e))
327327
}
328328

329-
fn request_finished(&mut self, request_id: String, block_ids: Vec<BlockId>) -> PyResult<bool> {
329+
fn request_finished(&mut self, request_id: String, block_ids: Vec<BlockId>) -> anyhow::Result<bool> {
330330
tracing::debug!("Request finished: {request_id}; block_ids: {block_ids:?}");
331331
// grab the slot
332-
let shared_slot = self.slot_manager.get_slot(&request_id).map_err(to_pyerr)?;
332+
let shared_slot = self.slot_manager.get_slot(&request_id)?;
333333

334334
// mark the slot as finished
335-
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
335+
let mut slot = shared_slot.lock().map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
336336
slot.mark_as_finished(self.iteration_counter)?;
337337

338338
// todo: allow the request to resolve when it should exit
@@ -361,7 +361,7 @@ impl Leader for KvConnectorLeader {
361361

362362
/// Create a new slot for the given request ID.
363363
/// This is used to create a new slot for the request.
364-
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()> {
364+
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()> {
365365
self.slot_manager
366366
.create_slot(&request.request_id, tokens, request.salt_hash)?;
367367

@@ -414,7 +414,7 @@ impl PyKvConnectorLeader {
414414
request_num_tokens: usize,
415415
num_computed_tokens: usize,
416416
) -> PyResult<(usize, bool)> {
417-
self.connector_leader.get_num_new_matched_tokens(request_id, request_num_tokens, num_computed_tokens)
417+
self.connector_leader.get_num_new_matched_tokens(request_id, request_num_tokens, num_computed_tokens).map_err(to_pyerr)
418418
}
419419

420420
fn update_state_after_alloc(
@@ -423,25 +423,25 @@ impl PyKvConnectorLeader {
423423
block_ids: Vec<BlockId>,
424424
num_external_tokens: usize,
425425
) -> PyResult<()> {
426-
self.connector_leader.update_state_after_alloc(request_id, block_ids, num_external_tokens)
426+
self.connector_leader.update_state_after_alloc(request_id, block_ids, num_external_tokens).map_err(to_pyerr)
427427
}
428428

429429
fn build_connector_metadata(
430430
&mut self,
431431
scheduler_output: SchedulerOutput,
432432
) -> PyResult<Vec<u8>> {
433-
self.connector_leader.build_connector_metadata(scheduler_output)
433+
self.connector_leader.build_connector_metadata(scheduler_output).map_err(to_pyerr)
434434
}
435435

436436
fn request_finished(&mut self, request_id: &str, block_ids: Vec<BlockId>) -> PyResult<bool> {
437-
self.connector_leader.request_finished(request_id.to_string(), block_ids)
437+
self.connector_leader.request_finished(request_id.to_string(), block_ids).map_err(to_pyerr)
438438
}
439439

440440
fn has_slot(&self, request_id: &str) -> bool {
441441
self.connector_leader.has_slot(request_id.to_string())
442442
}
443443

444444
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()> {
445-
self.connector_leader.create_slot(request, tokens)
445+
self.connector_leader.create_slot(request, tokens).map_err(to_pyerr)
446446
}
447447
}

lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::*;
2+
use anyhow;
23

34

45
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -108,9 +109,7 @@ impl KvConnectorLeaderRecorder {
108109
let output_path = "/tmp/records.jsonl";
109110
tracing::info!("recording events to {}", output_path);
110111

111-
// Create recorder synchronously using pyo3 async runtime
112-
let runtime = pyo3_async_runtimes::tokio::get_runtime();
113-
let recorder = runtime.block_on(async {
112+
let recorder = drt.runtime().primary().block_on(async {
114113
Recorder::new(token, &output_path, None, None, None).await
115114
}).unwrap();
116115

@@ -125,7 +124,7 @@ impl KvConnectorLeaderRecorder {
125124
let (unbounded_tx, unbounded_rx) = mpsc::unbounded_channel();
126125
let recorder_tx = recorder.event_sender();
127126

128-
let _ = runtime.spawn(Self::forward_unbounded_to_sender(unbounded_rx, recorder_tx));
127+
let _ = drt.runtime().primary().spawn(Self::forward_unbounded_to_sender(unbounded_rx, recorder_tx));
129128

130129
Self {
131130
_recorder: recorder,
@@ -158,7 +157,7 @@ impl Leader for KvConnectorLeaderRecorder {
158157
request_id: String,
159158
request_num_tokens: usize,
160159
num_computed_tokens: usize,
161-
) -> PyResult<(usize, bool)> {
160+
) -> anyhow::Result<(usize, bool)> {
162161
let input_copy = GetNumNewMatchedTokensInput {
163162
request_id: request_id.clone(),
164163
request_num_tokens: request_num_tokens.clone(),
@@ -183,7 +182,7 @@ impl Leader for KvConnectorLeaderRecorder {
183182
request_id: String,
184183
block_ids: Vec<BlockId>,
185184
num_external_tokens: usize,
186-
) -> PyResult<()> {
185+
) -> anyhow::Result<()> {
187186
let input_copy = UpdateStateAfterAllocInput {
188187
request_id: request_id.clone(),
189188
block_ids: block_ids.clone(),
@@ -197,7 +196,7 @@ impl Leader for KvConnectorLeaderRecorder {
197196
fn build_connector_metadata(
198197
&mut self,
199198
scheduler_output: SchedulerOutput,
200-
) -> PyResult<Vec<u8>> {
199+
) -> anyhow::Result<Vec<u8>> {
201200
let input_copy = BuildConnectorMetaInput {
202201
scheduler_output: scheduler_output.clone(),
203202
};
@@ -210,7 +209,7 @@ impl Leader for KvConnectorLeaderRecorder {
210209
output
211210
}
212211

213-
fn request_finished(&mut self, request_id: String, block_ids: Vec<BlockId>) -> PyResult<bool> {
212+
fn request_finished(&mut self, request_id: String, block_ids: Vec<BlockId>) -> anyhow::Result<bool> {
214213
let input_copy = RequestFinishedInput {
215214
request_id: request_id.clone(),
216215
block_ids: block_ids.clone(),
@@ -239,7 +238,7 @@ impl Leader for KvConnectorLeaderRecorder {
239238

240239
/// Create a new slot for the given request ID.
241240
/// This is used to create a new slot for the request.
242-
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()> {
241+
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()> {
243242
let input_copy = CreateSlotInput {
244243
request: request.clone(),
245244
tokens: tokens.clone(),

0 commit comments

Comments
 (0)