|
14 | 14 | // limitations under the License. |
15 | 15 |
|
16 | 16 | use super::context::{callable_accepts_kwarg, Context}; |
| 17 | +use dynamo_llm::protocols::DataStream; |
| 18 | +use dynamo_runtime::engine::AsyncEngineContext; |
17 | 19 | use pyo3::prelude::*; |
18 | 20 | use pyo3::types::{PyDict, PyModule}; |
19 | 21 | use pyo3::{PyAny, PyErr}; |
@@ -73,7 +75,7 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> { |
73 | 75 | /// ``` |
74 | 76 | #[pyclass] |
75 | 77 | #[derive(Clone)] |
76 | | -pub struct PythonAsyncEngine(PythonServerStreamingEngine); |
| 78 | +pub struct PythonAsyncEngine(pub PythonServerStreamingEngine); |
77 | 79 |
|
78 | 80 | #[pymethods] |
79 | 81 | impl PythonAsyncEngine { |
@@ -135,31 +137,16 @@ impl PythonServerStreamingEngine { |
135 | 137 | has_context, |
136 | 138 | } |
137 | 139 | } |
138 | | -} |
139 | 140 |
|
140 | | -#[derive(Debug, thiserror::Error)] |
141 | | -enum ResponseProcessingError { |
142 | | - #[error("python exception: {0}")] |
143 | | - PythonException(String), |
144 | | - |
145 | | - #[error("python generator exit: {0}")] |
146 | | - PyGeneratorExit(String), |
147 | | - |
148 | | - #[error("deserialize error: {0}")] |
149 | | - DeserializeError(String), |
150 | | - |
151 | | - #[error("gil offload error: {0}")] |
152 | | - OffloadError(String), |
153 | | -} |
154 | | - |
155 | | -#[async_trait] |
156 | | -impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> |
157 | | - for PythonServerStreamingEngine |
158 | | -where |
159 | | - Req: Data + Serialize, |
160 | | - Resp: Data + for<'de> Deserialize<'de>, |
161 | | -{ |
162 | | - async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> { |
| 141 | + /// Generate the response in parts. |
| 142 | + pub async fn generate_in_parts<Req, Resp>( |
| 143 | + &self, |
| 144 | + request: SingleIn<Req>, |
| 145 | + ) -> Result<(DataStream<Annotated<Resp>>, Arc<dyn AsyncEngineContext>), Error> |
| 146 | + where |
| 147 | + Req: Data + Serialize, |
| 148 | + Resp: Data + for<'de> Deserialize<'de>, |
| 149 | + { |
163 | 150 | // Create a context |
164 | 151 | let (request, context) = request.transfer(()); |
165 | 152 | let ctx = context.context(); |
@@ -290,8 +277,36 @@ where |
290 | 277 | }); |
291 | 278 |
|
292 | 279 | let stream = ReceiverStream::new(rx); |
| 280 | + let context = context.context(); |
| 281 | + Ok((Box::pin(stream), context)) |
| 282 | + } |
| 283 | +} |
| 284 | + |
| 285 | +#[derive(Debug, thiserror::Error)] |
| 286 | +enum ResponseProcessingError { |
| 287 | + #[error("python exception: {0}")] |
| 288 | + PythonException(String), |
| 289 | + |
| 290 | + #[error("python generator exit: {0}")] |
| 291 | + PyGeneratorExit(String), |
| 292 | + |
| 293 | + #[error("deserialize error: {0}")] |
| 294 | + DeserializeError(String), |
293 | 295 |
|
294 | | - Ok(ResponseStream::new(Box::pin(stream), context.context())) |
| 296 | + #[error("gil offload error: {0}")] |
| 297 | + OffloadError(String), |
| 298 | +} |
| 299 | + |
| 300 | +#[async_trait] |
| 301 | +impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> |
| 302 | + for PythonServerStreamingEngine |
| 303 | +where |
| 304 | + Req: Data + Serialize, |
| 305 | + Resp: Data + for<'de> Deserialize<'de>, |
| 306 | +{ |
| 307 | + async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> { |
| 308 | + let (stream, context) = self.generate_in_parts(request).await?; |
| 309 | + Ok(ResponseStream::new(Box::pin(stream), context)) |
295 | 310 | } |
296 | 311 | } |
297 | 312 |
|
|
0 commit comments