@@ -16,7 +16,8 @@ use tokio::sync::Mutex;
1616use dynamo_runtime:: {
1717 self as rs, logging,
1818 pipeline:: {
19- network:: egress:: push_router:: RouterMode as RsRouterMode , EngineStream , ManyOut , SingleIn ,
19+ context:: Context as RsContext , network:: egress:: push_router:: RouterMode as RsRouterMode ,
20+ EngineStream , ManyOut , SingleIn ,
2021 } ,
2122 protocols:: annotated:: Annotated as RsAnnotated ,
2223 traits:: DistributedRuntimeProvider ,
@@ -104,7 +105,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
104105 m. add_class :: < http:: HttpService > ( ) ?;
105106 m. add_class :: < http:: HttpError > ( ) ?;
106107 m. add_class :: < http:: HttpAsyncEngine > ( ) ?;
107- m. add_class :: < context:: PyContext > ( ) ?;
108+ m. add_class :: < context:: Context > ( ) ?;
108109 m. add_class :: < EtcdKvCache > ( ) ?;
109110 m. add_class :: < ModelType > ( ) ?;
110111 m. add_class :: < llm:: kv:: ForwardPassMetrics > ( ) ?;
@@ -697,27 +698,29 @@ impl Client {
697698 }
698699
699700 /// Issue a request to the endpoint using the default routing strategy.
700- #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING ) ) ]
701+ #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING , context= None ) ) ]
701702 fn generate < ' p > (
702703 & self ,
703704 py : Python < ' p > ,
704705 request : PyObject ,
705706 annotated : Option < bool > ,
707+ context : Option < context:: Context > ,
706708 ) -> PyResult < Bound < ' p , PyAny > > {
707709 if self . router . client . is_static ( ) {
708- self . r#static ( py, request, annotated)
710+ self . r#static ( py, request, annotated, context )
709711 } else {
710- self . random ( py, request, annotated)
712+ self . random ( py, request, annotated, context )
711713 }
712714 }
713715
714716 /// Send a request to the next endpoint in a round-robin fashion.
715- #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING ) ) ]
717+ #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING , context= None ) ) ]
716718 fn round_robin < ' p > (
717719 & self ,
718720 py : Python < ' p > ,
719721 request : PyObject ,
720722 annotated : Option < bool > ,
723+ context : Option < context:: Context > ,
721724 ) -> PyResult < Bound < ' p , PyAny > > {
722725 let request: serde_json:: Value = pythonize:: depythonize ( & request. into_bound ( py) ) ?;
723726 let annotated = annotated. unwrap_or ( false ) ;
@@ -726,7 +729,15 @@ impl Client {
726729 let client = self . router . clone ( ) ;
727730
728731 pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
729- let stream = client. round_robin ( request. into ( ) ) . await . map_err ( to_pyerr) ?;
732+ let stream = match context {
733+ Some ( context) => {
734+ let request = RsContext :: with_id ( request, context. inner ( ) . id ( ) . to_string ( ) ) ;
735+ let stream = client. round_robin ( request) . await . map_err ( to_pyerr) ?;
736+ context. inner ( ) . link_child ( stream. context ( ) ) ;
737+ stream
738+ }
739+ _ => client. round_robin ( request. into ( ) ) . await . map_err ( to_pyerr) ?,
740+ } ;
730741 tokio:: spawn ( process_stream ( stream, tx) ) ;
731742 Ok ( AsyncResponseStream {
732743 rx : Arc :: new ( Mutex :: new ( rx) ) ,
@@ -736,12 +747,13 @@ impl Client {
736747 }
737748
738749 /// Send a request to a random endpoint.
739- #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING ) ) ]
750+ #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING , context= None ) ) ]
740751 fn random < ' p > (
741752 & self ,
742753 py : Python < ' p > ,
743754 request : PyObject ,
744755 annotated : Option < bool > ,
756+ context : Option < context:: Context > ,
745757 ) -> PyResult < Bound < ' p , PyAny > > {
746758 let request: serde_json:: Value = pythonize:: depythonize ( & request. into_bound ( py) ) ?;
747759 let annotated = annotated. unwrap_or ( false ) ;
@@ -750,7 +762,15 @@ impl Client {
750762 let client = self . router . clone ( ) ;
751763
752764 pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
753- let stream = client. random ( request. into ( ) ) . await . map_err ( to_pyerr) ?;
765+ let stream = match context {
766+ Some ( context) => {
767+ let request = RsContext :: with_id ( request, context. inner ( ) . id ( ) . to_string ( ) ) ;
768+ let stream = client. random ( request) . await . map_err ( to_pyerr) ?;
769+ context. inner ( ) . link_child ( stream. context ( ) ) ;
770+ stream
771+ }
772+ _ => client. random ( request. into ( ) ) . await . map_err ( to_pyerr) ?,
773+ } ;
754774 tokio:: spawn ( process_stream ( stream, tx) ) ;
755775 Ok ( AsyncResponseStream {
756776 rx : Arc :: new ( Mutex :: new ( rx) ) ,
@@ -760,13 +780,14 @@ impl Client {
760780 }
761781
762782 /// Directly send a request to a specific endpoint.
763- #[ pyo3( signature = ( request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING ) ) ]
783+ #[ pyo3( signature = ( request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING , context= None ) ) ]
764784 fn direct < ' p > (
765785 & self ,
766786 py : Python < ' p > ,
767787 request : PyObject ,
768788 instance_id : i64 ,
769789 annotated : Option < bool > ,
790+ context : Option < context:: Context > ,
770791 ) -> PyResult < Bound < ' p , PyAny > > {
771792 let request: serde_json:: Value = pythonize:: depythonize ( & request. into_bound ( py) ) ?;
772793 let annotated = annotated. unwrap_or ( false ) ;
@@ -775,10 +796,21 @@ impl Client {
775796 let client = self . router . clone ( ) ;
776797
777798 pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
778- let stream = client
779- . direct ( request. into ( ) , instance_id)
780- . await
781- . map_err ( to_pyerr) ?;
799+ let stream = match context {
800+ Some ( context) => {
801+ let request = RsContext :: with_id ( request, context. inner ( ) . id ( ) . to_string ( ) ) ;
802+ let stream = client
803+ . direct ( request, instance_id)
804+ . await
805+ . map_err ( to_pyerr) ?;
806+ context. inner ( ) . link_child ( stream. context ( ) ) ;
807+ stream
808+ }
809+ _ => client
810+ . direct ( request. into ( ) , instance_id)
811+ . await
812+ . map_err ( to_pyerr) ?,
813+ } ;
782814
783815 tokio:: spawn ( process_stream ( stream, tx) ) ;
784816
@@ -790,12 +822,13 @@ impl Client {
790822 }
791823
792824 /// Directly send a request to a pre-defined static worker
793- #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING ) ) ]
825+ #[ pyo3( signature = ( request, annotated=DEFAULT_ANNOTATED_SETTING , context= None ) ) ]
794826 fn r#static < ' p > (
795827 & self ,
796828 py : Python < ' p > ,
797829 request : PyObject ,
798830 annotated : Option < bool > ,
831+ context : Option < context:: Context > ,
799832 ) -> PyResult < Bound < ' p , PyAny > > {
800833 let request: serde_json:: Value = pythonize:: depythonize ( & request. into_bound ( py) ) ?;
801834 let annotated = annotated. unwrap_or ( false ) ;
@@ -804,7 +837,15 @@ impl Client {
804837 let client = self . router . clone ( ) ;
805838
806839 pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
807- let stream = client. r#static ( request. into ( ) ) . await . map_err ( to_pyerr) ?;
840+ let stream = match context {
841+ Some ( context) => {
842+ let request = RsContext :: with_id ( request, context. inner ( ) . id ( ) . to_string ( ) ) ;
843+ let stream = client. r#static ( request) . await . map_err ( to_pyerr) ?;
844+ context. inner ( ) . link_child ( stream. context ( ) ) ;
845+ stream
846+ }
847+ _ => client. r#static ( request. into ( ) ) . await . map_err ( to_pyerr) ?,
848+ } ;
808849
809850 tokio:: spawn ( process_stream ( stream, tx) ) ;
810851
0 commit comments