Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions lib/bindings/python/rust/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
// limitations under the License.

use super::context::{callable_accepts_kwarg, PyContext};
use futures::stream::{self, StreamExt as FuturesStreamExt};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use pyo3::{PyAny, PyErr};
use pyo3_async_runtimes::TaskLocals;
use pythonize::{depythonize, pythonize};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use tokio_stream::wrappers::ReceiverStream;

pub use dynamo_runtime::{
pipeline::{
Expand Down Expand Up @@ -96,6 +97,10 @@ impl PythonAsyncEngine {
Arc::new(event_loop),
)))
}

pub fn block_until_stream_item(&mut self, enabled: bool) {
self.0.block_until_stream_item(enabled);
}
}

#[async_trait]
Expand All @@ -115,6 +120,7 @@ pub struct PythonServerStreamingEngine {
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
has_pycontext: bool,
block_until_stream_item: bool,
}

impl PythonServerStreamingEngine {
Expand All @@ -133,8 +139,13 @@ impl PythonServerStreamingEngine {
generator,
event_loop,
has_pycontext,
block_until_stream_item: false,
}
}

pub fn block_until_stream_item(&mut self, enabled: bool) {
self.block_until_stream_item = enabled;
}
}

#[derive(Debug, thiserror::Error)]
Expand Down Expand Up @@ -208,14 +219,46 @@ where
})
.await??;

let stream = Box::pin(stream);

// process the stream
// any error thrown in the stream will be caught and complete the processing task
// errors are captured by a task that is watching the processing task
// the error will be emitted as an annotated error
let request_id = id.clone();

let mut stream = Box::pin(stream);

let stream = if self.block_until_stream_item {
let first_item = match FuturesStreamExt::next(&mut stream).await {
Some(Ok(item)) => item,
Some(Err(e)) => {
// Any Python exception (including HttpError) is already wrapped in PyErr
// The HttpAsyncEngine will inspect this PyErr later to see if it's an HttpError
tracing::warn!(
request_id,
"Python exception occurred before finish of first iteration: {}",
e
);
return Err(Error::new(e));
}
None => {
tracing::warn!(
request_id,
"python async generator stream ended before processing started"
);
return Err(Error::new(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"python async generator stream ended before processing started",
)));
}
};
// Create a new stream that yields the first item followed by the rest of the original stream
let stream =
futures::StreamExt::chain(stream::once(futures::future::ok(first_item)), stream);
FuturesStreamExt::boxed(stream)
} else {
stream
};

tokio::spawn(async move {
tracing::debug!(
request_id,
Expand All @@ -225,7 +268,8 @@ where
let mut stream = stream;
let mut count = 0;

while let Some(item) = stream.next().await {
// Fix the third error by explicitly using FuturesStreamExt::next
while let Some(item) = FuturesStreamExt::next(&mut stream).await {
count += 1;
tracing::trace!(
request_id,
Expand Down
4 changes: 4 additions & 0 deletions lib/bindings/python/rust/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ impl HttpAsyncEngine {
pub fn new(generator: PyObject, event_loop: PyObject) -> PyResult<Self> {
Ok(PythonAsyncEngine::new(generator, event_loop)?.into())
}

pub fn block_until_stream_item(&mut self, enabled: bool) {
self.0.block_until_stream_item(enabled);
}
}

#[async_trait]
Expand Down
Loading