@@ -17,7 +17,7 @@ use crate::{
1717
1818use dynamo_runtime:: {
1919 pipeline:: {
20- async_trait, AsyncEngineContextProvider , ManyOut , Operator , ResponseStream ,
20+ async_trait, AsyncEngineContextProvider , Context , ManyOut , Operator , ResponseStream ,
2121 ServerStreamingEngine , SingleIn ,
2222 } ,
2323 protocols:: { annotated:: Annotated , maybe_error:: MaybeError } ,
5050 next : ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > ,
5151 ) -> Result < ManyOut < Annotated < LLMEngineOutput > > > {
5252 let ( preprocessed_request, context) = request. transfer ( ( ) ) ;
53+ let context_id = context. id ( ) . to_string ( ) ;
5354 let engine_ctx = context. context ( ) ;
5455 let engine_ctx_ = engine_ctx. clone ( ) ;
5556 let retry_manager =
56- RetryManager :: build ( preprocessed_request, next, self . migration_limit ) . await ?;
57+ RetryManager :: build ( context_id, preprocessed_request, next, self . migration_limit )
58+ . await ?;
5759 let response_stream = stream:: unfold ( retry_manager, move |mut retry_manager| {
5860 let engine_ctx = engine_ctx_. clone ( ) ;
5961 async move {
7173}
7274
7375struct RetryManager {
76+ context_id : String ,
7477 request : PreprocessedRequest ,
7578 next_generate : ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > ,
7679 next_stream : Option < ManyOut < Annotated < LLMEngineOutput > > > ,
@@ -79,11 +82,13 @@ struct RetryManager {
7982
8083impl RetryManager {
8184 pub async fn build (
85+ context_id : String ,
8286 preprocessed_request : PreprocessedRequest ,
8387 next : ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > ,
8488 retries_left : u32 ,
8589 ) -> Result < Self > {
8690 let mut slf = Self {
91+ context_id,
8792 request : preprocessed_request,
8893 next_generate : next,
8994 next_stream : None ,
@@ -127,8 +132,7 @@ impl RetryManager {
127132 let mut response_stream: Option < Result < ManyOut < Annotated < LLMEngineOutput > > > > = None ;
128133 while self . retries_left > 0 {
129134 self . retries_left -= 1 ;
130- // TODO: Is there anything needed to pass between context?
131- let request = SingleIn :: new ( self . request . clone ( ) ) ;
135+ let request = Context :: with_id ( self . request . clone ( ) , self . context_id . clone ( ) ) ;
132136 response_stream = Some ( self . next_generate . generate ( request) . await ) ;
133137 if let Some ( err) = response_stream. as_ref ( ) . unwrap ( ) . as_ref ( ) . err ( ) {
134138 if let Some ( req_err) = err. downcast_ref :: < NatsRequestError > ( ) {
@@ -235,15 +239,22 @@ mod tests {
235239 num_responses : usize ,
236240 token_offset : u32 ,
237241 call_count : Arc < AtomicU32 > ,
242+ context_id : String ,
238243 }
239244
240245 impl MockEngine {
// Test-only constructor for the mock engine.
// Diff note: the `-` line below is the previous three-argument signature; the `+` lines
// are its replacement, which additionally takes `context_id`.
// `context_id` is stored on the mock so that `generate` can assert that every incoming
// request carries this ID, and so the response-stream controllers are built with it.
241- fn new ( behavior : MockBehavior , num_responses : usize , token_offset : u32 ) -> Self {
246+ fn new (
247+ behavior : MockBehavior ,
248+ num_responses : usize ,
249+ token_offset : u32 ,
250+ context_id : String ,
251+ ) -> Self {
242252 Self {
243253 behavior,
244254 num_responses,
245255 token_offset,
// Call counter starts at zero; `generate` bumps it with `fetch_add` on each call.
246256 call_count : Arc :: new ( AtomicU32 :: new ( 0 ) ) ,
257+ context_id,
247258 }
248259 }
249260 }
@@ -261,7 +272,14 @@ mod tests {
261272 request : SingleIn < PreprocessedRequest > ,
262273 ) -> Result < ManyOut < Annotated < LLMEngineOutput > > > {
263274 let call_num = self . call_count . fetch_add ( 1 , Ordering :: SeqCst ) ;
264- let ( preprocessed_request, _) = request. transfer ( ( ) ) ;
275+ let ( preprocessed_request, context) = request. transfer ( ( ) ) ;
276+
277+ // Assert that the context_id matches the expected one
278+ assert_eq ! (
279+ context. id( ) . to_string( ) ,
280+ self . context_id,
281+ "Context ID mismatch"
282+ ) ;
265283
266284 // Calculate how many responses we've already generated based on request token_ids
267285 // Initial request has [1, 2, 3], so anything beyond that are generated responses
@@ -336,7 +354,7 @@ mod tests {
336354 }
337355
338356 let stream = tokio_stream:: wrappers:: ReceiverStream :: new ( rx) ;
339- let ctx = Arc :: new ( Controller :: default ( ) ) ;
357+ let ctx = Arc :: new ( Controller :: new ( self . context_id . clone ( ) ) ) ;
340358 Ok ( dynamo_runtime:: pipeline:: ResponseStream :: new (
341359 Box :: pin ( stream) ,
342360 ctx,
@@ -367,7 +385,7 @@ mod tests {
367385 } ) ;
368386
369387 let stream = tokio_stream:: wrappers:: ReceiverStream :: new ( rx) ;
370- let ctx = Arc :: new ( Controller :: default ( ) ) ;
388+ let ctx = Arc :: new ( Controller :: new ( self . context_id . clone ( ) ) ) ;
371389 Ok ( dynamo_runtime:: pipeline:: ResponseStream :: new (
372390 Box :: pin ( stream) ,
373391 ctx,
@@ -403,7 +421,7 @@ mod tests {
403421 } ) ;
404422
405423 let stream = tokio_stream:: wrappers:: ReceiverStream :: new ( rx) ;
406- let ctx = Arc :: new ( Controller :: default ( ) ) ;
424+ let ctx = Arc :: new ( Controller :: new ( self . context_id . clone ( ) ) ) ;
407425 Ok ( dynamo_runtime:: pipeline:: ResponseStream :: new (
408426 Box :: pin ( stream) ,
409427 ctx,
@@ -420,7 +438,7 @@ mod tests {
420438 } ) ;
421439
422440 let stream = tokio_stream:: wrappers:: ReceiverStream :: new ( rx) ;
423- let ctx = Arc :: new ( Controller :: default ( ) ) ;
441+ let ctx = Arc :: new ( Controller :: new ( self . context_id . clone ( ) ) ) ;
424442 Ok ( dynamo_runtime:: pipeline:: ResponseStream :: new (
425443 Box :: pin ( stream) ,
426444 ctx,
@@ -455,7 +473,7 @@ mod tests {
455473 } ) ;
456474
457475 let stream = tokio_stream:: wrappers:: ReceiverStream :: new ( rx) ;
458- let ctx = Arc :: new ( Controller :: default ( ) ) ;
476+ let ctx = Arc :: new ( Controller :: new ( self . context_id . clone ( ) ) ) ;
459477 Ok ( dynamo_runtime:: pipeline:: ResponseStream :: new (
460478 Box :: pin ( stream) ,
461479 ctx,
@@ -469,12 +487,18 @@ mod tests {
469487 /// Expected behavior: All 10 responses should be received successfully.
470488 #[ tokio:: test]
471489 async fn test_retry_manager_no_migration ( ) {
490+ let context_id = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
472491 let request = create_mock_request ( 10 ) ;
473- let mock_engine = Arc :: new ( MockEngine :: new ( MockBehavior :: Success , 10 , 100 ) ) ;
492+ let mock_engine = Arc :: new ( MockEngine :: new (
493+ MockBehavior :: Success ,
494+ 10 ,
495+ 100 ,
496+ context_id. clone ( ) ,
497+ ) ) ;
474498 let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
475499 mock_engine;
476500
477- let mut retry_manager = RetryManager :: build ( request, next_generate, 0 )
501+ let mut retry_manager = RetryManager :: build ( context_id , request, next_generate, 0 )
478502 . await
479503 . expect ( "Failed to build RetryManager" ) ;
480504
@@ -500,12 +524,18 @@ mod tests {
500524 /// Expected behavior: All 10 responses should be received successfully after retry.
501525 #[ tokio:: test]
502526 async fn test_retry_manager_new_request_migration ( ) {
527+ let context_id = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
503528 let request = create_mock_request ( 10 ) ;
504- let mock_engine = Arc :: new ( MockEngine :: new ( MockBehavior :: FailThenSuccess , 10 , 100 ) ) ;
529+ let mock_engine = Arc :: new ( MockEngine :: new (
530+ MockBehavior :: FailThenSuccess ,
531+ 10 ,
532+ 100 ,
533+ context_id. clone ( ) ,
534+ ) ) ;
505535 let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
506536 mock_engine;
507537
508- let mut retry_manager = RetryManager :: build ( request, next_generate, 3 )
538+ let mut retry_manager = RetryManager :: build ( context_id , request, next_generate, 3 )
509539 . await
510540 . expect ( "Failed to build RetryManager" ) ;
511541
@@ -531,16 +561,18 @@ mod tests {
531561 /// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
532562 #[ tokio:: test]
533563 async fn test_retry_manager_ongoing_request_migration ( ) {
564+ let context_id = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
534565 let request = create_mock_request ( 10 ) ;
535566 let mock_engine = Arc :: new ( MockEngine :: new (
536567 MockBehavior :: MidStreamFail { fail_after : 5 } ,
537568 10 ,
538569 100 ,
570+ context_id. clone ( ) ,
539571 ) ) ;
540572 let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
541573 mock_engine;
542574
543- let mut retry_manager = RetryManager :: build ( request, next_generate, 3 )
575+ let mut retry_manager = RetryManager :: build ( context_id , request, next_generate, 3 )
544576 . await
545577 . expect ( "Failed to build RetryManager" ) ;
546578
@@ -567,13 +599,19 @@ mod tests {
567599 /// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
568600 #[ tokio:: test]
569601 async fn test_retry_manager_new_request_migration_indefinite_failure ( ) {
602+ let context_id = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
570603 let request = create_mock_request ( 0 ) ;
571- let mock_engine = Arc :: new ( MockEngine :: new ( MockBehavior :: AlwaysFail , 0 , 100 ) ) ;
604+ let mock_engine = Arc :: new ( MockEngine :: new (
605+ MockBehavior :: AlwaysFail ,
606+ 0 ,
607+ 100 ,
608+ context_id. clone ( ) ,
609+ ) ) ;
572610 let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
573611 mock_engine;
574612
575613 // Should fail to build due to initial stream creation failure after exhausting all 3 retries
576- let retry_manager_result = RetryManager :: build ( request, next_generate, 3 ) . await ;
614+ let retry_manager_result = RetryManager :: build ( context_id , request, next_generate, 3 ) . await ;
577615
578616 assert ! ( retry_manager_result. is_err( ) ) ;
579617 if let Err ( error) = retry_manager_result {
@@ -587,16 +625,18 @@ mod tests {
587625 /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
588626 #[ tokio:: test]
589627 async fn test_retry_manager_ongoing_request_migration_indefinite_failure ( ) {
628+ let context_id = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
590629 let request = create_mock_request ( 10 ) ;
591630 let mock_engine = Arc :: new ( MockEngine :: new (
592631 MockBehavior :: MidStreamFailAlways { fail_after : 3 } ,
593632 10 ,
594633 100 ,
634+ context_id. clone ( ) ,
595635 ) ) ;
596636 let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
597637 mock_engine;
598638
599- let mut retry_manager = RetryManager :: build ( request, next_generate, 3 ) // 3 retries
639+ let mut retry_manager = RetryManager :: build ( context_id , request, next_generate, 3 ) // 3 retries
600640 . await
601641 . expect ( "Failed to build RetryManager" ) ;
602642
@@ -634,16 +674,18 @@ mod tests {
634674 /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
635675 #[ tokio:: test]
636676 async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error ( ) {
677+ let context_id = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
637678 let request = create_mock_request ( 10 ) ;
638679 let mock_engine = Arc :: new ( MockEngine :: new (
639680 MockBehavior :: MidStreamFailAlwaysStreamError { fail_after : 3 } ,
640681 10 ,
641682 100 ,
683+ context_id. clone ( ) ,
642684 ) ) ;
643685 let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
644686 mock_engine;
645687
646- let mut retry_manager = RetryManager :: build ( request, next_generate, 3 ) // 3 retries
688+ let mut retry_manager = RetryManager :: build ( context_id , request, next_generate, 3 ) // 3 retries
647689 . await
648690 . expect ( "Failed to build RetryManager" ) ;
649691
0 commit comments