Skip to content

Commit 1d85aec

Browse files
committed
new generate method
Signed-off-by: Brian Larson <brian.larson@baseten.co>
1 parent 1477f6e commit 1d85aec

File tree

2 files changed

+59
-28
lines changed

2 files changed

+59
-28
lines changed

lib/bindings/python/rust/engine.rs

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
// limitations under the License.
1515

1616
use super::context::{callable_accepts_kwarg, Context};
17+
use dynamo_llm::protocols::DataStream;
18+
use dynamo_runtime::engine::AsyncEngineContext;
1719
use pyo3::prelude::*;
1820
use pyo3::types::{PyDict, PyModule};
1921
use pyo3::{PyAny, PyErr};
@@ -73,7 +75,7 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
7375
/// ```
7476
#[pyclass]
7577
#[derive(Clone)]
76-
pub struct PythonAsyncEngine(PythonServerStreamingEngine);
78+
pub struct PythonAsyncEngine(pub PythonServerStreamingEngine);
7779

7880
#[pymethods]
7981
impl PythonAsyncEngine {
@@ -135,31 +137,16 @@ impl PythonServerStreamingEngine {
135137
has_context,
136138
}
137139
}
138-
}
139140

140-
#[derive(Debug, thiserror::Error)]
141-
enum ResponseProcessingError {
142-
#[error("python exception: {0}")]
143-
PythonException(String),
144-
145-
#[error("python generator exit: {0}")]
146-
PyGeneratorExit(String),
147-
148-
#[error("deserialize error: {0}")]
149-
DeserializeError(String),
150-
151-
#[error("gil offload error: {0}")]
152-
OffloadError(String),
153-
}
154-
155-
#[async_trait]
156-
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error>
157-
for PythonServerStreamingEngine
158-
where
159-
Req: Data + Serialize,
160-
Resp: Data + for<'de> Deserialize<'de>,
161-
{
162-
async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> {
141+
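/// Generate the response in parts: the response stream and its engine context are returned separately (as a tuple) rather than wrapped in a `ResponseStream`.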
142+
pub async fn generate_in_parts<Req, Resp>(
143+
&self,
144+
request: SingleIn<Req>,
145+
) -> Result<(DataStream<Annotated<Resp>>, Arc<dyn AsyncEngineContext>), Error>
146+
where
147+
Req: Data + Serialize,
148+
Resp: Data + for<'de> Deserialize<'de>,
149+
{
163150
// Create a context
164151
let (request, context) = request.transfer(());
165152
let ctx = context.context();
@@ -290,8 +277,36 @@ where
290277
});
291278

292279
let stream = ReceiverStream::new(rx);
280+
let context = context.context();
281+
Ok((Box::pin(stream), context))
282+
}
283+
}
284+
285+
#[derive(Debug, thiserror::Error)]
286+
enum ResponseProcessingError {
287+
#[error("python exception: {0}")]
288+
PythonException(String),
289+
290+
#[error("python generator exit: {0}")]
291+
PyGeneratorExit(String),
292+
293+
#[error("deserialize error: {0}")]
294+
DeserializeError(String),
293295

294-
Ok(ResponseStream::new(Box::pin(stream), context.context()))
296+
#[error("gil offload error: {0}")]
297+
OffloadError(String),
298+
}
299+
300+
/// `AsyncEngine` implementation for `PythonServerStreamingEngine`.
///
/// The trait method is a thin wrapper: the work is delegated to the inherent
/// `generate_in_parts` method, and the two values it returns are recombined
/// into a single `ResponseStream`.
#[async_trait]
301+
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error>
302+
for PythonServerStreamingEngine
303+
where
304+
Req: Data + Serialize,
305+
Resp: Data + for<'de> Deserialize<'de>,
306+
{
307+
/// Calls `generate_in_parts` with `request`, propagating any error via `?`,
/// and wraps the resulting stream and context in a `ResponseStream`.
async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> {
308+
let (stream, context) = self.generate_in_parts(request).await?;
309+
// NOTE(review): `generate_in_parts` already returns its stream via `Box::pin(..)`,
// so this `Box::pin` appears to box the stream a second time. Presumably
// `ResponseStream::new(stream, context)` would be enough -- confirm against the
// signature of `ResponseStream::new`.
Ok(ResponseStream::new(Box::pin(stream), context))
295310
}
296311
}
297312

lib/bindings/python/rust/http.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,24 @@ where
177177
Resp: Data + for<'de> Deserialize<'de>,
178178
{
179179
async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> {
180-
match self.0.generate(request).await {
181-
Ok(res) => Ok(res),
180+
match self.0 .0.generate_in_parts(request).await {
181+
Ok((mut stream, context)) => {
182+
let first_item = match futures::StreamExt::next(&mut stream).await {
183+
Some(item) => item,
184+
None => {
185+
return Err(Error::new(std::io::Error::new(
186+
std::io::ErrorKind::UnexpectedEof,
187+
"python async generator stream ended before processing started",
188+
)));
189+
}
190+
};
191+
192+
// Create a new stream that yields the first item followed by the rest of the original stream
193+
let once_stream = futures::stream::once(async { first_item });
194+
let stream = futures::StreamExt::chain(once_stream, stream);
195+
196+
Ok(ResponseStream::new(Box::pin(stream), context))
197+
}
182198

183199
// Inspect the error - if it was an HttpError from Python, extract the code and message
184200
// and return the rust version of HttpError

0 commit comments

Comments
 (0)