Skip to content

Commit e75ca6d

Browse files
authored
feat: FT Python Context and Unit Tests (#2677)
1 parent 8f12b18 commit e75ca6d

File tree

15 files changed

+612
-33
lines changed

15 files changed

+612
-33
lines changed

lib/bindings/python/rust/context.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,46 @@
11
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4-
// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
4+
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
55

6+
use dynamo_runtime::pipeline::context::Controller;
67
pub use dynamo_runtime::pipeline::AsyncEngineContext;
78
use pyo3::prelude::*;
89
use std::sync::Arc;
910

10-
// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
11+
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
1112
// Not all methods of the AsyncEngineContext are exposed, just the primary ones for tracing + cancellation.
1213
// Kept as class, to allow for future expansion if needed.
14+
#[derive(Clone)]
1315
#[pyclass]
14-
pub struct PyContext {
15-
pub inner: Arc<dyn AsyncEngineContext>,
16+
pub struct Context {
17+
inner: Arc<dyn AsyncEngineContext>,
1618
}
1719

18-
impl PyContext {
20+
impl Context {
1921
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
2022
Self { inner }
2123
}
24+
25+
pub fn inner(&self) -> Arc<dyn AsyncEngineContext> {
26+
self.inner.clone()
27+
}
2228
}
2329

2430
#[pymethods]
25-
impl PyContext {
31+
impl Context {
32+
#[new]
33+
#[pyo3(signature = (id=None))]
34+
fn py_new(id: Option<String>) -> Self {
35+
let controller = match id {
36+
Some(id) => Controller::new(id),
37+
None => Controller::default(),
38+
};
39+
Self {
40+
inner: Arc::new(controller),
41+
}
42+
}
43+
2644
// Synchronous counterpart of `await async_is_stopped()`
2745
fn is_stopped(&self) -> bool {
2846
self.inner.is_stopped()

lib/bindings/python/rust/engine.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16-
use super::context::{callable_accepts_kwarg, PyContext};
16+
use super::context::{callable_accepts_kwarg, Context};
1717
use pyo3::prelude::*;
1818
use pyo3::types::{PyDict, PyModule};
1919
use pyo3::{PyAny, PyErr};
@@ -114,7 +114,7 @@ pub struct PythonServerStreamingEngine {
114114
_cancel_token: CancellationToken,
115115
generator: Arc<PyObject>,
116116
event_loop: Arc<PyObject>,
117-
has_pycontext: bool,
117+
has_context: bool,
118118
}
119119

120120
impl PythonServerStreamingEngine {
@@ -123,7 +123,7 @@ impl PythonServerStreamingEngine {
123123
generator: Arc<PyObject>,
124124
event_loop: Arc<PyObject>,
125125
) -> Self {
126-
let has_pycontext = Python::with_gil(|py| {
126+
let has_context = Python::with_gil(|py| {
127127
let callable = generator.bind(py);
128128
callable_accepts_kwarg(py, callable, "context").unwrap_or(false)
129129
});
@@ -132,7 +132,7 @@ impl PythonServerStreamingEngine {
132132
_cancel_token: cancel_token,
133133
generator,
134134
event_loop,
135-
has_pycontext,
135+
has_context,
136136
}
137137
}
138138
}
@@ -175,7 +175,7 @@ where
175175
let generator = self.generator.clone();
176176
let event_loop = self.event_loop.clone();
177177
let ctx_python = ctx.clone();
178-
let has_pycontext = self.has_pycontext;
178+
let has_context = self.has_context;
179179

180180
// Acquiring the GIL is similar to acquiring a standard lock/mutex
181181
// Performing this in a tokio async task could block the thread for an undefined amount of time
@@ -190,9 +190,9 @@ where
190190
let stream = tokio::task::spawn_blocking(move || {
191191
Python::with_gil(|py| {
192192
let py_request = pythonize(py, &request)?;
193-
let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?;
193+
let py_ctx = Py::new(py, Context::new(ctx_python.clone()))?;
194194

195-
let gen = if has_pycontext {
195+
let gen = if has_context {
196196
// Pass context as a kwarg
197197
let kwarg = PyDict::new(py);
198198
kwarg.set_item("context", &py_ctx)?;

lib/bindings/python/rust/lib.rs

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ use tokio::sync::Mutex;
1616
use dynamo_runtime::{
1717
self as rs, logging,
1818
pipeline::{
19-
network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn,
19+
context::Context as RsContext, network::egress::push_router::RouterMode as RsRouterMode,
20+
EngineStream, ManyOut, SingleIn,
2021
},
2122
protocols::annotated::Annotated as RsAnnotated,
2223
traits::DistributedRuntimeProvider,
@@ -104,7 +105,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
104105
m.add_class::<http::HttpService>()?;
105106
m.add_class::<http::HttpError>()?;
106107
m.add_class::<http::HttpAsyncEngine>()?;
107-
m.add_class::<context::PyContext>()?;
108+
m.add_class::<context::Context>()?;
108109
m.add_class::<EtcdKvCache>()?;
109110
m.add_class::<ModelType>()?;
110111
m.add_class::<llm::kv::ForwardPassMetrics>()?;
@@ -697,27 +698,29 @@ impl Client {
697698
}
698699

699700
/// Issue a request to the endpoint using the default routing strategy.
700-
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
701+
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
701702
fn generate<'p>(
702703
&self,
703704
py: Python<'p>,
704705
request: PyObject,
705706
annotated: Option<bool>,
707+
context: Option<context::Context>,
706708
) -> PyResult<Bound<'p, PyAny>> {
707709
if self.router.client.is_static() {
708-
self.r#static(py, request, annotated)
710+
self.r#static(py, request, annotated, context)
709711
} else {
710-
self.random(py, request, annotated)
712+
self.random(py, request, annotated, context)
711713
}
712714
}
713715

714716
/// Send a request to the next endpoint in a round-robin fashion.
715-
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
717+
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
716718
fn round_robin<'p>(
717719
&self,
718720
py: Python<'p>,
719721
request: PyObject,
720722
annotated: Option<bool>,
723+
context: Option<context::Context>,
721724
) -> PyResult<Bound<'p, PyAny>> {
722725
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
723726
let annotated = annotated.unwrap_or(false);
@@ -726,7 +729,15 @@ impl Client {
726729
let client = self.router.clone();
727730

728731
pyo3_async_runtimes::tokio::future_into_py(py, async move {
729-
let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?;
732+
let stream = match context {
733+
Some(context) => {
734+
let request = RsContext::with_id(request, context.inner().id().to_string());
735+
let stream = client.round_robin(request).await.map_err(to_pyerr)?;
736+
context.inner().link_child(stream.context());
737+
stream
738+
}
739+
_ => client.round_robin(request.into()).await.map_err(to_pyerr)?,
740+
};
730741
tokio::spawn(process_stream(stream, tx));
731742
Ok(AsyncResponseStream {
732743
rx: Arc::new(Mutex::new(rx)),
@@ -736,12 +747,13 @@ impl Client {
736747
}
737748

738749
/// Send a request to a random endpoint.
739-
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
750+
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
740751
fn random<'p>(
741752
&self,
742753
py: Python<'p>,
743754
request: PyObject,
744755
annotated: Option<bool>,
756+
context: Option<context::Context>,
745757
) -> PyResult<Bound<'p, PyAny>> {
746758
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
747759
let annotated = annotated.unwrap_or(false);
@@ -750,7 +762,15 @@ impl Client {
750762
let client = self.router.clone();
751763

752764
pyo3_async_runtimes::tokio::future_into_py(py, async move {
753-
let stream = client.random(request.into()).await.map_err(to_pyerr)?;
765+
let stream = match context {
766+
Some(context) => {
767+
let request = RsContext::with_id(request, context.inner().id().to_string());
768+
let stream = client.random(request).await.map_err(to_pyerr)?;
769+
context.inner().link_child(stream.context());
770+
stream
771+
}
772+
_ => client.random(request.into()).await.map_err(to_pyerr)?,
773+
};
754774
tokio::spawn(process_stream(stream, tx));
755775
Ok(AsyncResponseStream {
756776
rx: Arc::new(Mutex::new(rx)),
@@ -760,13 +780,14 @@ impl Client {
760780
}
761781

762782
/// Directly send a request to a specific endpoint.
763-
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING))]
783+
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
764784
fn direct<'p>(
765785
&self,
766786
py: Python<'p>,
767787
request: PyObject,
768788
instance_id: i64,
769789
annotated: Option<bool>,
790+
context: Option<context::Context>,
770791
) -> PyResult<Bound<'p, PyAny>> {
771792
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
772793
let annotated = annotated.unwrap_or(false);
@@ -775,10 +796,21 @@ impl Client {
775796
let client = self.router.clone();
776797

777798
pyo3_async_runtimes::tokio::future_into_py(py, async move {
778-
let stream = client
779-
.direct(request.into(), instance_id)
780-
.await
781-
.map_err(to_pyerr)?;
799+
let stream = match context {
800+
Some(context) => {
801+
let request = RsContext::with_id(request, context.inner().id().to_string());
802+
let stream = client
803+
.direct(request, instance_id)
804+
.await
805+
.map_err(to_pyerr)?;
806+
context.inner().link_child(stream.context());
807+
stream
808+
}
809+
_ => client
810+
.direct(request.into(), instance_id)
811+
.await
812+
.map_err(to_pyerr)?,
813+
};
782814

783815
tokio::spawn(process_stream(stream, tx));
784816

@@ -790,12 +822,13 @@ impl Client {
790822
}
791823

792824
/// Directly send a request to a pre-defined static worker
793-
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
825+
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
794826
fn r#static<'p>(
795827
&self,
796828
py: Python<'p>,
797829
request: PyObject,
798830
annotated: Option<bool>,
831+
context: Option<context::Context>,
799832
) -> PyResult<Bound<'p, PyAny>> {
800833
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
801834
let annotated = annotated.unwrap_or(false);
@@ -804,7 +837,15 @@ impl Client {
804837
let client = self.router.clone();
805838

806839
pyo3_async_runtimes::tokio::future_into_py(py, async move {
807-
let stream = client.r#static(request.into()).await.map_err(to_pyerr)?;
840+
let stream = match context {
841+
Some(context) => {
842+
let request = RsContext::with_id(request, context.inner().id().to_string());
843+
let stream = client.r#static(request).await.map_err(to_pyerr)?;
844+
context.inner().link_child(stream.context());
845+
stream
846+
}
847+
_ => client.r#static(request.into()).await.map_err(to_pyerr)?,
848+
};
808849

809850
tokio::spawn(process_stream(stream, tx));
810851

lib/bindings/python/src/dynamo/runtime/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
from dynamo._core import Backend as Backend
2626
from dynamo._core import Client as Client
2727
from dynamo._core import Component as Component
28+
from dynamo._core import Context as Context
2829
from dynamo._core import DistributedRuntime as DistributedRuntime
2930
from dynamo._core import Endpoint as Endpoint
3031
from dynamo._core import EtcdKvCache as EtcdKvCache
3132
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
3233
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
33-
from dynamo._core import PyContext as PyContext
3434

3535

3636
def dynamo_worker(static=False):

0 commit comments

Comments
 (0)