Skip to content

Commit b23be13

Browse files
committed
Add RAII guard for monitoring the connection
1 parent 096d117 commit b23be13

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

lib/llm/src/http/service/openai.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use axum::{
1818
Json, Router,
1919
};
2020
use dynamo_runtime::{
21-
pipeline::{AsyncEngineContextProvider, Context},
21+
pipeline::{AsyncEngineContext, AsyncEngineContextProvider, Context},
2222
protocols::annotated::AnnotationsProvider,
2323
};
2424
use futures::{stream, StreamExt};
@@ -155,6 +155,54 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
155155
uuid.to_string()
156156
}
157157

158+
// A RAII guard to ensure that the context is stopped when the request is dropped by client.
159+
// Request fututures are dropped in axum when the client disconnects.
160+
// https://github.com/tokio-rs/axum/discussions/1094
161+
// may be defused to prevent stopping the context and send a control message via
162+
// stop_generating
163+
struct CtxDropGuard {
164+
ctx: Arc<dyn AsyncEngineContext>,
165+
issue_stop_generating: bool,
166+
verbose: bool,
167+
}
168+
169+
impl CtxDropGuard {
170+
fn new(ctx: Arc<dyn AsyncEngineContext>) -> Self {
171+
CtxDropGuard {
172+
ctx,
173+
issue_stop_generating: true,
174+
verbose: true,
175+
}
176+
}
177+
// request succeeded, no need to stop generating
178+
// takes ownership
179+
fn defuse(&mut self) {
180+
self.issue_stop_generating = false;
181+
self.verbose = false;
182+
}
183+
184+
// no-op, moves the guard to a thread.
185+
fn mute(&mut self) {
186+
self.verbose = false;
187+
}
188+
}
189+
190+
impl Drop for CtxDropGuard {
191+
fn drop(&mut self) {
192+
if self.issue_stop_generating {
193+
self.ctx.stop_generating();
194+
if self.verbose {
195+
tracing::info!("Stopping generation for request_id: {}", self.ctx.id());
196+
} else {
197+
tracing::trace!(
198+
"Stopping generation or end of successful request_id: {}",
199+
self.ctx.id()
200+
);
201+
}
202+
}
203+
}
204+
}
205+
158206
/// OpenAI Completions Request Handler
159207
///
160208
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
@@ -232,7 +280,7 @@ async fn completions(
232280
state
233281
.metrics_clone()
234282
.create_inflight_guard(model, Endpoint::Completions, streaming);
235-
283+
let mut drop_guard = CtxDropGuard::new(request.context().clone());
236284
let mut response_collector = state.metrics_clone().create_response_collector(model);
237285

238286
// prepare to process any annotations
@@ -269,6 +317,7 @@ async fn completions(
269317

270318
if streaming {
271319
let stream = stream.map(move |response| {
320+
drop_guard.mute();
272321
process_event_converter(EventConverter::from(response), &mut response_collector)
273322
});
274323
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
@@ -294,6 +343,7 @@ async fn completions(
294343
})?;
295344

296345
inflight_guard.mark_ok();
346+
drop_guard.defuse();
297347
Ok(Json(response).into_response())
298348
}
299349
}
@@ -332,6 +382,8 @@ async fn embeddings(
332382
// todo - inherit request_id from distributed trace details
333383
let request = Context::with_id(request, request_id.clone());
334384

385+
let mut drop_guard = CtxDropGuard::new(request.context().clone());
386+
335387
// issue the generate call on the engine
336388
let stream = engine
337389
.generate(request)
@@ -352,6 +404,7 @@ async fn embeddings(
352404
})?;
353405

354406
inflight.mark_ok();
407+
drop_guard.defuse();
355408
Ok(Json(response).into_response())
356409
}
357410

@@ -438,6 +491,7 @@ async fn chat_completions(
438491
req.inner.stream = Some(true);
439492
req
440493
});
494+
let mut drop_guard = CtxDropGuard::new(request.context().clone());
441495

442496
// todo - make the protocols be optional for model name
443497
// todo - when optional, if none, apply a default
@@ -494,6 +548,7 @@ async fn chat_completions(
494548
stream_handle.arm();
495549

496550
let stream = stream.map(move |response| {
551+
drop_guard.mute();
497552
process_event_converter(EventConverter::from(response), &mut response_collector)
498553
});
499554
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
@@ -522,6 +577,7 @@ async fn chat_completions(
522577
})?;
523578

524579
inflight_guard.mark_ok();
580+
drop_guard.defuse();
525581
Ok(Json(response).into_response())
526582
}
527583
}

0 commit comments

Comments
 (0)