Skip to content

Commit b4ddca9

Browse files
authored
feat: Failure Detection while Responses are returning (#1671)
1 parent bd91dca commit b4ddca9

File tree

9 files changed

+361
-111
lines changed

9 files changed

+361
-111
lines changed

lib/bindings/python/rust/lib.rs

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ struct Endpoint {
214214
#[pyclass]
215215
#[derive(Clone)]
216216
struct Client {
217-
router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>,
217+
router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
218218
}
219219

220220
#[pyclass(eq, eq_int)]
@@ -485,13 +485,12 @@ impl Endpoint {
485485
let inner = self.inner.clone();
486486
pyo3_async_runtimes::tokio::future_into_py(py, async move {
487487
let client = inner.client().await.map_err(to_pyerr)?;
488-
let push_router =
489-
rs::pipeline::PushRouter::<serde_json::Value, serde_json::Value>::from_client(
490-
client,
491-
Default::default(),
492-
)
493-
.await
494-
.map_err(to_pyerr)?;
488+
let push_router = rs::pipeline::PushRouter::<
489+
serde_json::Value,
490+
RsAnnotated<serde_json::Value>,
491+
>::from_client(client, Default::default())
492+
.await
493+
.map_err(to_pyerr)?;
495494
Ok(Client {
496495
router: push_router,
497496
})
@@ -757,23 +756,13 @@ impl Client {
757756
}
758757

759758
async fn process_stream(
760-
stream: EngineStream<serde_json::Value>,
759+
stream: EngineStream<RsAnnotated<serde_json::Value>>,
761760
tx: tokio::sync::mpsc::Sender<RsAnnotated<PyObject>>,
762761
) {
763762
let mut stream = stream;
764763
while let Some(response) = stream.next().await {
765764
// Convert the response to a PyObject using Python's GIL
766-
// TODO: Remove the clone, but still log the full JSON string on error. But how?
767-
let annotated: RsAnnotated<serde_json::Value> = match serde_json::from_value(
768-
response.clone(),
769-
) {
770-
Ok(a) => a,
771-
Err(err) => {
772-
tracing::error!(%err, %response, "process_stream: Failed de-serializing JSON into RsAnnotated");
773-
break;
774-
}
775-
};
776-
765+
let annotated: RsAnnotated<serde_json::Value> = response;
777766
let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
778767
let result = Python::with_gil(|py| match pythonize::pythonize(py, &data) {
779768
Ok(pyobj) => Ok(pyobj.into()),

lib/llm/src/protocols/common/llm_backend.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
1818
pub use super::preprocessor::PreprocessedRequest;
1919
pub use super::FinishReason;
2020
use crate::protocols::TokenIdType;
21+
use dynamo_runtime::protocols::maybe_error::MaybeError;
2122

2223
pub type TokenType = Option<String>;
2324
pub type LogProbs = Vec<f64>;
@@ -134,6 +135,20 @@ impl LLMEngineOutput {
134135
}
135136
}
136137

138+
impl MaybeError for LLMEngineOutput {
139+
fn from_err(err: Box<dyn std::error::Error>) -> Self {
140+
LLMEngineOutput::error(format!("{:?}", err))
141+
}
142+
143+
fn err(&self) -> Option<Box<dyn std::error::Error>> {
144+
if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
145+
Some(anyhow::Error::msg(err_msg.clone()).into())
146+
} else {
147+
None
148+
}
149+
}
150+
}
151+
137152
/// Raw output from embedding engines containing embedding vectors
138153
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
139154
pub struct EmbeddingsEngineOutput {
@@ -144,3 +159,26 @@ pub struct EmbeddingsEngineOutput {
144159
pub prompt_tokens: u32,
145160
pub total_tokens: u32,
146161
}
162+
163+
#[cfg(test)]
164+
mod tests {
165+
use super::*;
166+
167+
#[test]
168+
fn test_maybe_error() {
169+
let output = LLMEngineOutput::stop();
170+
assert!(output.err().is_none());
171+
assert!(output.is_ok());
172+
assert!(!output.is_err());
173+
174+
let output = LLMEngineOutput::error("Test error".to_string());
175+
assert_eq!(format!("{}", output.err().unwrap()), "Test error");
176+
assert!(!output.is_ok());
177+
assert!(output.is_err());
178+
179+
let output = LLMEngineOutput::from_err(anyhow::Error::msg("Test error 2").into());
180+
assert_eq!(format!("{}", output.err().unwrap()), "Test error 2");
181+
assert!(!output.is_ok());
182+
assert!(output.is_err());
183+
}
184+
}

lib/runtime/src/pipeline/network.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,54 @@ impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
323323
pub trait PushWorkHandler: Send + Sync {
324324
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
325325
}
326+
327+
/*
328+
/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
329+
/// in network communication between ingress and egress components.
330+
///
331+
/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
332+
/// gracefully or was cut off prematurely (e.g., due to network issues).
333+
///
334+
/// **Design Rationale**:
335+
/// - Cannot use `Annotated` directly because the generic type `U` varies:
336+
/// - Sometimes `U = Annotated<...>`
337+
/// - Sometimes `U = LLMEngineOutput`
338+
/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
339+
/// - A simple wrapper is cleaner and more straightforward
340+
///
341+
/// **Stream Flow**:
342+
/// ```
343+
/// At AsyncEngine:
344+
/// response 1 -> response 2 -> response 3 -> <end>
345+
///
346+
/// Between ingress/egress:
347+
/// response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
348+
///
349+
/// At client:
350+
/// response 1 -> response 2 -> response 3 -> <end>
351+
/// ```
352+
///
353+
/// **Error Handling**:
354+
/// If the stream is cut off before proper termination, the egress is responsible for
355+
/// injecting an error response to communicate the incomplete stream to the client:
356+
/// ```
357+
/// At AsyncEngine:
358+
/// response 1 -> ... <without end flag>
359+
///
360+
/// At egress:
361+
/// response 1 <end=false> -> <stream ended without end flag -> convert to error>
362+
///
363+
/// At client:
364+
/// response 1 -> error response
365+
/// ```
366+
///
367+
/// The detection must be done at egress level because premature stream termination
368+
/// can be due to network issues that only the egress component can detect.
369+
*/
370+
/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This wrapper will be removed once that is in place.
371+
#[derive(Serialize, Deserialize, Debug)]
372+
pub struct NetworkStreamWrapper<U> {
373+
#[serde(skip_serializing_if = "Option::is_none")]
374+
pub data: Option<U>,
375+
pub complete_final: bool,
376+
}

lib/runtime/src/pipeline/network/egress/addressed_router.rs

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ use async_nats::client::Client;
1717
use tracing as log;
1818

1919
use super::*;
20-
use crate::Result;
20+
use crate::{protocols::maybe_error::MaybeError, Result};
21+
use tokio_stream::{wrappers::ReceiverStream, StreamExt, StreamNotifyClose};
2122

2223
#[derive(Debug, Clone, Serialize, Deserialize)]
2324
#[serde(rename_all = "snake_case")]
@@ -80,7 +81,7 @@ impl AddressedPushRouter {
8081
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
8182
where
8283
T: Data + Serialize,
83-
U: Data + for<'de> Deserialize<'de>,
84+
U: Data + for<'de> Deserialize<'de> + MaybeError,
8485
{
8586
async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
8687
let request_id = request.context().id().to_string();
@@ -160,16 +161,49 @@ where
160161
.map_err(|_| PipelineError::DetatchedStreamReceiver)?
161162
.map_err(PipelineError::ConnectionFailed)?;
162163

163-
let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx);
164-
165-
let stream = stream.filter_map(|msg| async move {
166-
match serde_json::from_slice::<U>(&msg) {
167-
Ok(r) => Some(r),
168-
Err(err) => {
169-
let json_str = String::from_utf8_lossy(&msg);
170-
log::warn!(%err, %json_str, "Failed deserializing JSON to response");
171-
None
164+
// TODO: Detect end-of-stream using Server-Sent Events (SSE)
165+
let mut is_complete_final = false;
166+
let stream = tokio_stream::StreamNotifyClose::new(
167+
tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
168+
)
169+
.filter_map(move |res| {
170+
if let Some(res_bytes) = res {
171+
if is_complete_final {
172+
return Some(U::from_err(
173+
Error::msg(
174+
"Response received after generation ended - this should never happen",
175+
)
176+
.into(),
177+
));
178+
}
179+
match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
180+
Ok(item) => {
181+
is_complete_final = item.complete_final;
182+
if let Some(data) = item.data {
183+
Some(data)
184+
} else if is_complete_final {
185+
None
186+
} else {
187+
Some(U::from_err(
188+
Error::msg("Empty response received - this should never happen")
189+
.into(),
190+
))
191+
}
192+
}
193+
Err(err) => {
194+
// Keep the warning log from the previous implementation so the raw payload is still recorded on failure
195+
let json_str = String::from_utf8_lossy(&res_bytes);
196+
log::warn!(%err, %json_str, "Failed deserializing JSON to response");
197+
198+
Some(U::from_err(Error::new(err).into()))
199+
}
172200
}
201+
} else if is_complete_final {
202+
None
203+
} else {
204+
Some(U::from_err(
205+
Error::msg("Stream ended before generation completed").into(),
206+
))
173207
}
174208
});
175209

0 commit comments

Comments
 (0)