@@ -18,7 +18,7 @@ use axum::{
1818 Json , Router ,
1919} ;
2020use dynamo_runtime:: {
21- pipeline:: { AsyncEngineContextProvider , Context } ,
21+ pipeline:: { AsyncEngineContext , AsyncEngineContextProvider , Context } ,
2222 protocols:: annotated:: AnnotationsProvider ,
2323} ;
2424use futures:: { stream, StreamExt } ;
@@ -155,6 +155,54 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
155155 uuid. to_string ( )
156156}
157157
158+ // A RAII guard to ensure that the context is stopped when the request is dropped by client.
159+ // Request fututures are dropped in axum when the client disconnects.
160+ // https://github.com/tokio-rs/axum/discussions/1094
161+ // may be defused to prevent stopping the context and send a control message via
162+ // stop_generating
163+ struct CtxDropGuard {
164+ ctx : Arc < dyn AsyncEngineContext > ,
165+ issue_stop_generating : bool ,
166+ verbose : bool ,
167+ }
168+
169+ impl CtxDropGuard {
170+ fn new ( ctx : Arc < dyn AsyncEngineContext > ) -> Self {
171+ CtxDropGuard {
172+ ctx,
173+ issue_stop_generating : true ,
174+ verbose : true ,
175+ }
176+ }
177+ // request succeeded, no need to stop generating
178+ // takes ownership
179+ fn defuse ( & mut self ) {
180+ self . issue_stop_generating = false ;
181+ self . verbose = false ;
182+ }
183+
184+ // no-op, moves the guard to a thread.
185+ fn mute ( & mut self ) {
186+ self . verbose = false ;
187+ }
188+ }
189+
190+ impl Drop for CtxDropGuard {
191+ fn drop ( & mut self ) {
192+ if self . issue_stop_generating {
193+ self . ctx . stop_generating ( ) ;
194+ if self . verbose {
195+ tracing:: info!( "Stopping generation for request_id: {}" , self . ctx. id( ) ) ;
196+ } else {
197+ tracing:: trace!(
198+ "Stopping generation or end of successful request_id: {}" ,
199+ self . ctx. id( )
200+ ) ;
201+ }
202+ }
203+ }
204+ }
205+
158206/// OpenAI Completions Request Handler
159207///
160208/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
@@ -232,7 +280,7 @@ async fn completions(
232280 state
233281 . metrics_clone ( )
234282 . create_inflight_guard ( model, Endpoint :: Completions , streaming) ;
235-
283+ let mut drop_guard = CtxDropGuard :: new ( request . context ( ) . clone ( ) ) ;
236284 let mut response_collector = state. metrics_clone ( ) . create_response_collector ( model) ;
237285
238286 // prepare to process any annotations
@@ -269,6 +317,7 @@ async fn completions(
269317
270318 if streaming {
271319 let stream = stream. map ( move |response| {
320+ drop_guard. mute ( ) ;
272321 process_event_converter ( EventConverter :: from ( response) , & mut response_collector)
273322 } ) ;
274323 let stream = monitor_for_disconnects ( stream, ctx, inflight_guard, stream_handle) ;
@@ -294,6 +343,7 @@ async fn completions(
294343 } ) ?;
295344
296345 inflight_guard. mark_ok ( ) ;
346+ drop_guard. defuse ( ) ;
297347 Ok ( Json ( response) . into_response ( ) )
298348 }
299349}
@@ -332,6 +382,8 @@ async fn embeddings(
332382 // todo - inherit request_id from distributed trace details
333383 let request = Context :: with_id ( request, request_id. clone ( ) ) ;
334384
385+ let mut drop_guard = CtxDropGuard :: new ( request. context ( ) . clone ( ) ) ;
386+
335387 // issue the generate call on the engine
336388 let stream = engine
337389 . generate ( request)
@@ -352,6 +404,7 @@ async fn embeddings(
352404 } ) ?;
353405
354406 inflight. mark_ok ( ) ;
407+ drop_guard. defuse ( ) ;
355408 Ok ( Json ( response) . into_response ( ) )
356409}
357410
@@ -438,6 +491,7 @@ async fn chat_completions(
438491 req. inner . stream = Some ( true ) ;
439492 req
440493 } ) ;
494+ let mut drop_guard = CtxDropGuard :: new ( request. context ( ) . clone ( ) ) ;
441495
442496 // todo - make the protocols be optional for model name
443497 // todo - when optional, if none, apply a default
@@ -494,6 +548,7 @@ async fn chat_completions(
494548 stream_handle. arm ( ) ;
495549
496550 let stream = stream. map ( move |response| {
551+ drop_guard. mute ( ) ;
497552 process_event_converter ( EventConverter :: from ( response) , & mut response_collector)
498553 } ) ;
499554 let stream = monitor_for_disconnects ( stream, ctx, inflight_guard, stream_handle) ;
@@ -522,6 +577,7 @@ async fn chat_completions(
522577 } ) ?;
523578
524579 inflight_guard. mark_ok ( ) ;
580+ drop_guard. defuse ( ) ;
525581 Ok ( Json ( response) . into_response ( ) )
526582 }
527583}
0 commit comments