Skip to content

Commit 343a481

Browse files
authored
feat: http disconnects (#2014)
1 parent e330d96 commit 343a481

File tree

9 files changed

+888
-189
lines changed

9 files changed

+888
-189
lines changed

lib/llm/src/http/service.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
mod openai;
2222

23+
pub mod disconnect;
2324
pub mod error;
2425
pub mod health;
2526
pub mod metrics;
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! The `disconnect` module provides a mechanism for our axum http services to monitor and respond
5+
//! to disconnects from the client.
6+
//!
7+
//! There are two potential phases in any request where we need to handle the disconnect.
8+
//!
9+
//! For unary, request-response, there is just a single phase where the primary task that axum kicks off
10+
//! to handle the request will be dropped if the client disconnects. In order for us to have a long running
11+
//! task, like an LLM request, we need to spawn our long running task in a separate task and then spawn
12+
//! a second task that will monitor for disconnects from the client. The primary task which spawned the
13+
//! two tasks will hold an "armed" [`ConnectionHandle`] which will issue a [`ConnectionStatus::ClosedUnexpectedly`]
14+
//! if the task is dropped before it is [`ConnectionHandle::disarm`]ed.
15+
//!
16+
//! For the streaming case, request in - stream out, we need a second [`ConnectionHandle`] which will be owned
17+
//! by the stream. A streaming response is when the [`axum::response::Response`] is a [`axum::response::Sse`] stream.
18+
//! This means the primary task handle will go out of scope when it returns the stream. When we create our
19+
//! SSE stream, we capture the second [`ConnectionHandle`] and arm it. If the stream closes gracefully, the
20+
//! second handle will be disarmed, otherwise, the stream was dropped and the [`Drop`] trait on the [`ConnectionHandle`]
21+
//! triggers a [`ConnectionStatus::ClosedUnexpectedly`] signal.
22+
//!
23+
//! The [`ConnectionHandle`] is a simple wrapper around a [`tokio::sync::oneshot::Sender`] which will send a
24+
//! [`ConnectionStatus`] enum to the primary task. The primary task will then use this to determine if it should
25+
//! cancel the request or not.
26+
//!
27+
//! The [`ConnectionHandle`] is also used to signal to the client that the request has been cancelled. This is
28+
//! done by sending a [`axum::response::sse::Event`] with the event type "error" and the data "[DONE]".
29+
//!
30+
31+
use axum::response::sse::Event;
32+
use dynamo_runtime::engine::AsyncEngineContext;
33+
use futures::{Stream, StreamExt};
34+
use std::sync::Arc;
35+
36+
use crate::http::service::metrics::InflightGuard;
37+
38+
/// Status reported by a [`ConnectionHandle`] when it is dropped.
///
/// The value is received by the `connection_monitor` task, which decides whether the engine
/// context of the request has to be killed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// The handle does not take part in disconnect detection; the monitor takes no action.
    Disabled,
    /// The handle was dropped while armed; the monitor kills the engine context.
    ClosedUnexpectedly,
    /// The handle was disarmed before it was dropped; the monitor only traces the closure.
    ClosedGracefully,
}
44+
45+
pub struct ConnectionHandle {
46+
sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
47+
on_drop: ConnectionStatus,
48+
}
49+
50+
impl ConnectionHandle {
51+
/// Handle which by default will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
52+
pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
53+
Self {
54+
sender: Some(sender),
55+
on_drop: ConnectionStatus::ClosedGracefully,
56+
}
57+
}
58+
59+
/// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
60+
pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
61+
Self {
62+
sender: Some(sender),
63+
on_drop: ConnectionStatus::ClosedUnexpectedly,
64+
}
65+
}
66+
67+
/// Handle which will not issue a signal when dropped.
68+
pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
69+
Self {
70+
sender: Some(sender),
71+
on_drop: ConnectionStatus::Disabled,
72+
}
73+
}
74+
75+
/// Handle which will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
76+
pub fn disarm(&mut self) {
77+
self.on_drop = ConnectionStatus::ClosedGracefully;
78+
}
79+
80+
/// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
81+
pub fn arm(&mut self) {
82+
self.on_drop = ConnectionStatus::ClosedUnexpectedly;
83+
}
84+
}
85+
86+
impl Drop for ConnectionHandle {
87+
fn drop(&mut self) {
88+
if let Some(sender) = self.sender.take() {
89+
let _ = sender.send(self.on_drop);
90+
}
91+
}
92+
}
93+
94+
/// Creates a pair of handles which will monitor for disconnects from the client.
95+
///
96+
/// The first handle is armed and will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
97+
/// The second handle is disarmed and will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
98+
///
99+
/// The handles are returned in the order of the first being armed and the second being disarmed.
100+
pub async fn create_connection_monitor(
101+
engine_context: Arc<dyn AsyncEngineContext>,
102+
) -> (ConnectionHandle, ConnectionHandle) {
103+
// these oneshot channels monitor possible disconnects from the client in two different scopes:
104+
// - the local task (connection_handle)
105+
// - an optionally streaming response (stream_handle)
106+
let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
107+
let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();
108+
109+
// detached task that will naturally close when both handles are dropped
110+
tokio::spawn(connection_monitor(
111+
engine_context.clone(),
112+
connection_rx,
113+
stream_rx,
114+
));
115+
116+
// Two handles, the first is armed, the second is disarmed
117+
(
118+
ConnectionHandle::create_armed(connection_tx),
119+
ConnectionHandle::create_disabled(stream_tx),
120+
)
121+
}
122+
123+
#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
124+
async fn connection_monitor(
125+
engine_context: Arc<dyn AsyncEngineContext>,
126+
connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
127+
stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
128+
) {
129+
match connection_rx.await {
130+
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
131+
// the client has disconnected, no need to gracefully cancel, just kill the context
132+
tracing::trace!("Connection closed unexpectedly; issuing cancellation");
133+
engine_context.kill();
134+
}
135+
Ok(ConnectionStatus::ClosedGracefully) => {
136+
tracing::trace!("Connection closed gracefully");
137+
}
138+
Ok(ConnectionStatus::Disabled) => {}
139+
}
140+
141+
match stream_rx.await {
142+
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
143+
tracing::trace!("Stream closed unexpectedly; issuing cancellation");
144+
engine_context.kill();
145+
}
146+
Ok(ConnectionStatus::ClosedGracefully) => {
147+
tracing::trace!("Stream closed gracefully");
148+
}
149+
Ok(ConnectionStatus::Disabled) => {}
150+
}
151+
}
152+
153+
/// This method will consume a stream of SSE events and monitor for disconnects or context cancellation.
154+
///
155+
/// Uses `tokio::select!` to choose between receiving events from the source stream or detecting when
156+
/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
157+
/// naturally, we mark the request as successful and send the final `[DONE]` event.
158+
pub fn monitor_for_disconnects(
159+
stream: impl Stream<Item = Result<Event, axum::Error>>,
160+
context: Arc<dyn AsyncEngineContext>,
161+
mut inflight_guard: InflightGuard,
162+
mut stream_handle: ConnectionHandle,
163+
) -> impl Stream<Item = Result<Event, axum::Error>> {
164+
stream_handle.arm();
165+
async_stream::try_stream! {
166+
tokio::pin!(stream);
167+
loop {
168+
tokio::select! {
169+
event = stream.next() => {
170+
match event {
171+
Some(Ok(event)) => {
172+
yield event;
173+
}
174+
Some(Err(err)) => {
175+
yield Event::default().event("error").comment(err.to_string());
176+
}
177+
None => {
178+
// Stream ended normally
179+
inflight_guard.mark_ok();
180+
stream_handle.disarm();
181+
182+
// todo: if we yield a dynamo sentinel event, we need to do it before the done or the
183+
// async-openai client will chomp it.
184+
yield Event::default().data("[DONE]");
185+
break;
186+
}
187+
}
188+
}
189+
_ = context.stopped() => {
190+
tracing::trace!("Context stopped; breaking stream");
191+
break;
192+
}
193+
}
194+
}
195+
}
196+
}

0 commit comments

Comments
 (0)