Skip to content

Commit 4ed5695

Browse files
committed
fix: Retain context_id between requests
1 parent b8b6edc commit 4ed5695

File tree

1 file changed

+62
-20
lines changed

1 file changed

+62
-20
lines changed

lib/llm/src/migration.rs

Lines changed: 62 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ use crate::{
1717

1818
use dynamo_runtime::{
1919
pipeline::{
20-
async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
20+
async_trait, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
2121
ServerStreamingEngine, SingleIn,
2222
},
2323
protocols::{annotated::Annotated, maybe_error::MaybeError},
@@ -50,10 +50,12 @@ impl
5050
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
5151
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
5252
let (preprocessed_request, context) = request.transfer(());
53+
let context_id = context.id().to_string();
5354
let engine_ctx = context.context();
5455
let engine_ctx_ = engine_ctx.clone();
5556
let retry_manager =
56-
RetryManager::build(preprocessed_request, next, self.migration_limit).await?;
57+
RetryManager::build(context_id, preprocessed_request, next, self.migration_limit)
58+
.await?;
5759
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| {
5860
let engine_ctx = engine_ctx_.clone();
5961
async move {
@@ -71,6 +73,7 @@ impl
7173
}
7274

7375
struct RetryManager {
76+
context_id: String,
7477
request: PreprocessedRequest,
7578
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
7679
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
@@ -79,11 +82,13 @@ struct RetryManager {
7982

8083
impl RetryManager {
8184
pub async fn build(
85+
context_id: String,
8286
preprocessed_request: PreprocessedRequest,
8387
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
8488
retries_left: u32,
8589
) -> Result<Self> {
8690
let mut slf = Self {
91+
context_id,
8792
request: preprocessed_request,
8893
next_generate: next,
8994
next_stream: None,
@@ -127,8 +132,7 @@ impl RetryManager {
127132
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
128133
while self.retries_left > 0 {
129134
self.retries_left -= 1;
130-
// TODO: Is there anything needed to pass between context?
131-
let request = SingleIn::new(self.request.clone());
135+
let request = Context::with_id(self.request.clone(), self.context_id.clone());
132136
response_stream = Some(self.next_generate.generate(request).await);
133137
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() {
134138
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
@@ -235,15 +239,22 @@ mod tests {
235239
num_responses: usize,
236240
token_offset: u32,
237241
call_count: Arc<AtomicU32>,
242+
context_id: String,
238243
}
239244

240245
impl MockEngine {
241-
fn new(behavior: MockBehavior, num_responses: usize, token_offset: u32) -> Self {
246+
fn new(
247+
behavior: MockBehavior,
248+
num_responses: usize,
249+
token_offset: u32,
250+
context_id: String,
251+
) -> Self {
242252
Self {
243253
behavior,
244254
num_responses,
245255
token_offset,
246256
call_count: Arc::new(AtomicU32::new(0)),
257+
context_id,
247258
}
248259
}
249260
}
@@ -261,7 +272,14 @@ mod tests {
261272
request: SingleIn<PreprocessedRequest>,
262273
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
263274
let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
264-
let (preprocessed_request, _) = request.transfer(());
275+
let (preprocessed_request, context) = request.transfer(());
276+
277+
// Assert that the context_id matches the expected one
278+
assert_eq!(
279+
context.id().to_string(),
280+
self.context_id,
281+
"Context ID mismatch"
282+
);
265283

266284
// Calculate how many responses we've already generated based on request token_ids
267285
// Initial request has [1, 2, 3], so anything beyond that are generated responses
@@ -336,7 +354,7 @@ mod tests {
336354
}
337355

338356
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
339-
let ctx = Arc::new(Controller::default());
357+
let ctx = Arc::new(Controller::new(self.context_id.clone()));
340358
Ok(dynamo_runtime::pipeline::ResponseStream::new(
341359
Box::pin(stream),
342360
ctx,
@@ -367,7 +385,7 @@ mod tests {
367385
});
368386

369387
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
370-
let ctx = Arc::new(Controller::default());
388+
let ctx = Arc::new(Controller::new(self.context_id.clone()));
371389
Ok(dynamo_runtime::pipeline::ResponseStream::new(
372390
Box::pin(stream),
373391
ctx,
@@ -403,7 +421,7 @@ mod tests {
403421
});
404422

405423
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
406-
let ctx = Arc::new(Controller::default());
424+
let ctx = Arc::new(Controller::new(self.context_id.clone()));
407425
Ok(dynamo_runtime::pipeline::ResponseStream::new(
408426
Box::pin(stream),
409427
ctx,
@@ -420,7 +438,7 @@ mod tests {
420438
});
421439

422440
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
423-
let ctx = Arc::new(Controller::default());
441+
let ctx = Arc::new(Controller::new(self.context_id.clone()));
424442
Ok(dynamo_runtime::pipeline::ResponseStream::new(
425443
Box::pin(stream),
426444
ctx,
@@ -455,7 +473,7 @@ mod tests {
455473
});
456474

457475
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
458-
let ctx = Arc::new(Controller::default());
476+
let ctx = Arc::new(Controller::new(self.context_id.clone()));
459477
Ok(dynamo_runtime::pipeline::ResponseStream::new(
460478
Box::pin(stream),
461479
ctx,
@@ -469,12 +487,18 @@ mod tests {
469487
/// Expected behavior: All 10 responses should be received successfully.
470488
#[tokio::test]
471489
async fn test_retry_manager_no_migration() {
490+
let context_id = uuid::Uuid::new_v4().to_string();
472491
let request = create_mock_request(10);
473-
let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100));
492+
let mock_engine = Arc::new(MockEngine::new(
493+
MockBehavior::Success,
494+
10,
495+
100,
496+
context_id.clone(),
497+
));
474498
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
475499
mock_engine;
476500

477-
let mut retry_manager = RetryManager::build(request, next_generate, 0)
501+
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 0)
478502
.await
479503
.expect("Failed to build RetryManager");
480504

@@ -500,12 +524,18 @@ mod tests {
500524
/// Expected behavior: All 10 responses should be received successfully after retry.
501525
#[tokio::test]
502526
async fn test_retry_manager_new_request_migration() {
527+
let context_id = uuid::Uuid::new_v4().to_string();
503528
let request = create_mock_request(10);
504-
let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100));
529+
let mock_engine = Arc::new(MockEngine::new(
530+
MockBehavior::FailThenSuccess,
531+
10,
532+
100,
533+
context_id.clone(),
534+
));
505535
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
506536
mock_engine;
507537

508-
let mut retry_manager = RetryManager::build(request, next_generate, 3)
538+
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
509539
.await
510540
.expect("Failed to build RetryManager");
511541

@@ -531,16 +561,18 @@ mod tests {
531561
/// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
532562
#[tokio::test]
533563
async fn test_retry_manager_ongoing_request_migration() {
564+
let context_id = uuid::Uuid::new_v4().to_string();
534565
let request = create_mock_request(10);
535566
let mock_engine = Arc::new(MockEngine::new(
536567
MockBehavior::MidStreamFail { fail_after: 5 },
537568
10,
538569
100,
570+
context_id.clone(),
539571
));
540572
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
541573
mock_engine;
542574

543-
let mut retry_manager = RetryManager::build(request, next_generate, 3)
575+
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
544576
.await
545577
.expect("Failed to build RetryManager");
546578

@@ -567,13 +599,19 @@ mod tests {
567599
/// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
568600
#[tokio::test]
569601
async fn test_retry_manager_new_request_migration_indefinite_failure() {
602+
let context_id = uuid::Uuid::new_v4().to_string();
570603
let request = create_mock_request(0);
571-
let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100));
604+
let mock_engine = Arc::new(MockEngine::new(
605+
MockBehavior::AlwaysFail,
606+
0,
607+
100,
608+
context_id.clone(),
609+
));
572610
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
573611
mock_engine;
574612

575613
// Should fail to build due to initial stream creation failure after exhausting all 3 retries
576-
let retry_manager_result = RetryManager::build(request, next_generate, 3).await;
614+
let retry_manager_result = RetryManager::build(context_id, request, next_generate, 3).await;
577615

578616
assert!(retry_manager_result.is_err());
579617
if let Err(error) = retry_manager_result {
@@ -587,16 +625,18 @@ mod tests {
587625
/// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
588626
#[tokio::test]
589627
async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
628+
let context_id = uuid::Uuid::new_v4().to_string();
590629
let request = create_mock_request(10);
591630
let mock_engine = Arc::new(MockEngine::new(
592631
MockBehavior::MidStreamFailAlways { fail_after: 3 },
593632
10,
594633
100,
634+
context_id.clone(),
595635
));
596636
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
597637
mock_engine;
598638

599-
let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
639+
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
600640
.await
601641
.expect("Failed to build RetryManager");
602642

@@ -634,16 +674,18 @@ mod tests {
634674
/// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
635675
#[tokio::test]
636676
async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
677+
let context_id = uuid::Uuid::new_v4().to_string();
637678
let request = create_mock_request(10);
638679
let mock_engine = Arc::new(MockEngine::new(
639680
MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
640681
10,
641682
100,
683+
context_id.clone(),
642684
));
643685
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
644686
mock_engine;
645687

646-
let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
688+
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
647689
.await
648690
.expect("Failed to build RetryManager");
649691

0 commit comments

Comments
 (0)