diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs
index acc909a0a2..4bb086410a 100644
--- a/launch/dynamo-run/src/flags.rs
+++ b/launch/dynamo-run/src/flags.rs
@@ -162,6 +162,11 @@ pub struct Flags {
     #[arg(long)]
     pub request_template: Option<PathBuf>,

+    /// How many times a request can be migrated to another worker if the HTTP server loses
+    /// connection to the current worker.
+    #[arg(long, value_parser = clap::value_parser!(u32).range(0..1024))]
+    pub migration_limit: Option<u32>,
+
     /// Everything after a `--`.
     /// These are the command line arguments to the python engine when using `pystr` or `pytok`.
     #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
@@ -180,6 +185,9 @@ impl Flags {
                 if self.kv_cache_block_size.is_some() {
                     anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
                 }
+                if self.migration_limit.is_some() {
+                    anyhow::bail!("'--migration-limit' flag should only be used on the worker node, not on the ingress");
+                }
             }
             Output::EchoFull => {}
             Output::EchoCore => {
diff --git a/launch/dynamo-run/src/lib.rs b/launch/dynamo-run/src/lib.rs
index b9525b40e3..5662db762c 100644
--- a/launch/dynamo-run/src/lib.rs
+++ b/launch/dynamo-run/src/lib.rs
@@ -45,7 +45,8 @@ pub async fn run(
         .context_length(flags.context_length)
         .http_port(Some(flags.http_port))
         .router_config(Some(flags.router_config()))
-        .request_template(flags.request_template.clone());
+        .request_template(flags.request_template.clone())
+        .migration_limit(flags.migration_limit);

     // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
     // If not, then the endpoint isn't exposed so we let LocalModel invent one.
diff --git a/launch/dynamo-run/src/subprocess.rs b/launch/dynamo-run/src/subprocess.rs
index 73ed24f7bd..0e714918b5 100644
--- a/launch/dynamo-run/src/subprocess.rs
+++ b/launch/dynamo-run/src/subprocess.rs
@@ -48,6 +48,8 @@ pub async fn start(
         card.kv_cache_block_size.to_string(),
         "--context-length".to_string(),
         card.context_length.to_string(),
+        "--migration-limit".to_string(),
+        card.migration_limit.to_string(),
     ];
     // TRTLLM only
     // The worker node will only publish events and metrics if the router mode is KV
diff --git a/launch/dynamo-run/src/subprocess/sglang_inc.py b/launch/dynamo-run/src/subprocess/sglang_inc.py
index 43827677e0..2d97c29116 100644
--- a/launch/dynamo-run/src/subprocess/sglang_inc.py
+++ b/launch/dynamo-run/src/subprocess/sglang_inc.py
@@ -42,6 +42,7 @@ class Config:
     nnodes: int
     node_rank: int
     dist_init_addr: str
+    migration_limit: int
     extra_engine_args: str


@@ -202,7 +203,13 @@ async def init(runtime: DistributedRuntime, config: Config):
     model_type = (
         ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
     )
-    await register_llm(model_type, endpoint, config.model_path, config.model_name)
+    await register_llm(
+        model_type,
+        endpoint,
+        config.model_path,
+        config.model_name,
+        migration_limit=config.migration_limit,
+    )

     # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
     # after the lease is revoked
@@ -268,6 +275,12 @@ def cmd_line_args():
         default="",
         help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
     )
+    parser.add_argument(
+        "--migration-limit",
+        type=int,
+        default=0,
+        help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
+    )
     parser.add_argument(
         "--extra-engine-args",
         type=str,
@@ -304,6 +317,7 @@ def cmd_line_args():
     config.nnodes = args.nnodes
     config.node_rank = args.node_rank
     config.dist_init_addr = args.dist_init_addr
+    config.migration_limit = args.migration_limit
     config.extra_engine_args = args.extra_engine_args

     return config
diff --git a/launch/dynamo-run/src/subprocess/trtllm_inc.py b/launch/dynamo-run/src/subprocess/trtllm_inc.py
index 4a009ee343..5d79de4467 100644
--- a/launch/dynamo-run/src/subprocess/trtllm_inc.py
+++ b/launch/dynamo-run/src/subprocess/trtllm_inc.py
@@ -122,6 +122,7 @@ class Config:
     model_name: Optional[str] = None
     tensor_parallel_size: int
     kv_block_size: int
+    migration_limit: int
     extra_engine_args: str
     publish_events_and_metrics: bool
     disaggregation_mode: str
@@ -136,6 +137,7 @@ def __str__(self) -> str:
             f"model_name={self.model_name}, "
             f"tensor_parallel_size={self.tensor_parallel_size}, "
             f"kv_block_size={self.kv_block_size}, "
+            f"migration_limit={self.migration_limit}, "
             f"extra_engine_args={self.extra_engine_args}, "
             f"publish_events_and_metrics={self.publish_events_and_metrics}, "
             f"disaggregation_mode={self.disaggregation_mode}, "
@@ -404,6 +406,7 @@ async def init(runtime: DistributedRuntime, config: Config):
         config.model_path,
         config.model_name,
         kv_cache_block_size=config.kv_block_size,
+        migration_limit=config.migration_limit,
     )

     # publisher will be set later if publishing is enabled.
@@ -476,6 +479,12 @@ def cmd_line_args():
         default=None,
         help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
     )
+    parser.add_argument(
+        "--migration-limit",
+        type=int,
+        default=0,
+        help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
+    )
     parser.add_argument(
         "--extra-engine-args",
         type=str,
@@ -557,6 +566,7 @@ def cmd_line_args():
     config.endpoint = parsed_endpoint_name
     config.tensor_parallel_size = args.tensor_parallel_size
     config.kv_block_size = args.kv_block_size
+    config.migration_limit = args.migration_limit
     config.extra_engine_args = args.extra_engine_args
     config.publish_events_and_metrics = args.publish_events_and_metrics
     config.disaggregation_mode = disaggregation_mode
diff --git a/launch/dynamo-run/src/subprocess/vllm_inc.py b/launch/dynamo-run/src/subprocess/vllm_inc.py
index 9085af9af1..583b1a311e 100644
--- a/launch/dynamo-run/src/subprocess/vllm_inc.py
+++ b/launch/dynamo-run/src/subprocess/vllm_inc.py
@@ -56,6 +56,7 @@ class Config:
     tensor_parallel_size: int
     kv_block_size: int
     context_length: int
+    migration_limit: int
     extra_engine_args: str


@@ -233,6 +234,7 @@ async def init(runtime: DistributedRuntime, config: Config):
             "max_model_len", None
         ),  # if None, takes length from tokenizer
         kv_cache_block_size=arg_map["block_size"],
+        migration_limit=config.migration_limit,
     )
     handler = RequestHandler(component, engine_client, default_sampling_params)
     handler.setup_kv_metrics()
@@ -276,6 +278,12 @@ def cmd_line_args():
         default=None,
         help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
     )
+    parser.add_argument(
+        "--migration-limit",
+        type=int,
+        default=0,
+        help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
+    )
     parser.add_argument(
         "--extra-engine-args",
         type=str,
@@ -308,6 +316,7 @@ def cmd_line_args():
     config.tensor_parallel_size = args.tensor_parallel_size
     config.kv_block_size = args.kv_block_size
     config.context_length = args.context_length
+    config.migration_limit = args.migration_limit
     config.extra_engine_args = args.extra_engine_args

     return config
diff --git a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py
index 0fc956cc49..c5e9465434 100644
--- a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py
+++ b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py
@@ -65,6 +65,7 @@ class Config:
     tensor_parallel_size: int
     kv_block_size: int
     context_length: int
+    migration_limit: int
     extra_engine_args: str


@@ -218,6 +219,7 @@ async def init(runtime: DistributedRuntime, config: Config):
         config.model_path,
         config.model_name,
         kv_cache_block_size=config.kv_block_size,
+        migration_limit=config.migration_limit,
     )

     arg_map = {
@@ -333,6 +335,12 @@ def cmd_line_args():
         default=None,
         help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
     )
+    parser.add_argument(
+        "--migration-limit",
+        type=int,
+        default=0,
+        help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
+    )
     parser.add_argument(
         "--extra-engine-args",
         type=str,
@@ -365,6 +373,7 @@ def cmd_line_args():
     config.tensor_parallel_size = args.tensor_parallel_size
     config.kv_block_size = args.kv_block_size
     config.context_length = args.context_length
+    config.migration_limit = args.migration_limit
     config.extra_engine_args = args.extra_engine_args

     return config
diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs
index 12d639c4d6..5b548352f8 100644
--- a/lib/bindings/python/rust/lib.rs
+++ b/lib/bindings/python/rust/lib.rs
@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
 }

 #[pyfunction]
-#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None))]
+#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0))]
 #[allow(clippy::too_many_arguments)]
 fn register_llm<'p>(
     py: Python<'p>,
@@ -142,6 +142,7 @@ fn register_llm<'p>(
     context_length: Option<u32>,
     kv_cache_block_size: Option<u32>,
     router_mode: Option<RouterMode>,
+    migration_limit: u32,
 ) -> PyResult<Bound<'p, PyAny>> {
     let model_type_obj = match model_type {
         ModelType::Chat => llm_rs::model_type::ModelType::Chat,
@@ -162,7 +163,8 @@ fn register_llm<'p>(
         .model_name(model_name)
         .context_length(context_length)
         .kv_cache_block_size(kv_cache_block_size)
-        .router_config(Some(router_config));
+        .router_config(Some(router_config))
+        .migration_limit(Some(migration_limit));
     // Download from HF, load the ModelDeploymentCard
     let mut local_model = builder.build().await.map_err(to_pyerr)?;
     // Advertise ourself on etcd so ingress can find us
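The binding defaults `migration_limit` to 0, so existing Python workers keep today's no-migration behavior unless they opt in. On the Rust side the same opt-in reduces to the builder call below, a minimal sketch using only the methods this diff shows (the surrounding async fn and error handling are elided):

    // Sketch: opting a worker into migration via the builder.
    let mut builder = LocalModelBuilder::default();
    builder.migration_limit(Some(3)); // None and Some(0) both mean "never migrate"
    let model = builder.build().await?; // build() stamps the limit into the ModelDeploymentCard
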
diff --git a/lib/llm/src/discovery/watcher.rs b/lib/llm/src/discovery/watcher.rs
index 12b497b88f..4b111fe18c 100644
--- a/lib/llm/src/discovery/watcher.rs
+++ b/lib/llm/src/discovery/watcher.rs
@@ -19,6 +19,7 @@ use dynamo_runtime::{
 use crate::{
     backend::Backend,
     kv_router::{KvPushRouter, KvRouterConfig},
+    migration::Migration,
     model_type::ModelType,
     preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
     protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
@@ -197,12 +198,14 @@ impl ModelWatcher {
         // function. Needs checking carefully, possibly we need to store it in state.
         let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

+        // Chat Completions
         let frontend = SegmentSource::<
             SingleIn<NvCreateChatCompletionRequest>,
             ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         >::new();
         let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
         let backend = Backend::from_mdc(card.clone()).await?.into_operator();
+        let migration = Migration::from_mdc(card.clone()).await?.into_operator();
         let router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
             client.clone(),
@@ -231,19 +234,23 @@ impl ModelWatcher {
         let chat_engine = frontend
             .link(preprocessor.forward_edge())?
             .link(backend.forward_edge())?
+            .link(migration.forward_edge())?
             .link(service_backend)?
+            .link(migration.backward_edge())?
             .link(backend.backward_edge())?
             .link(preprocessor.backward_edge())?
             .link(frontend)?;
         self.manager
             .add_chat_completions_model(&model_entry.name, chat_engine)?;

+        // Completions
         let frontend = SegmentSource::<
             SingleIn<NvCreateCompletionRequest>,
             ManyOut<Annotated<NvCreateCompletionResponse>>,
         >::new();
         let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
         let backend = Backend::from_mdc(card.clone()).await?.into_operator();
+        let migration = Migration::from_mdc(card.clone()).await?.into_operator();
         let router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
             client,
@@ -272,7 +279,9 @@
         let completions_engine = frontend
             .link(preprocessor.forward_edge())?
             .link(backend.forward_edge())?
+            .link(migration.forward_edge())?
             .link(service_backend)?
+            .link(migration.backward_edge())?
             .link(backend.backward_edge())?
             .link(preprocessor.backward_edge())?
             .link(frontend)?;
diff --git a/lib/llm/src/lib.rs b/lib/llm/src/lib.rs
index e5e2e50fb6..19d1ef76a1 100644
--- a/lib/llm/src/lib.rs
+++ b/lib/llm/src/lib.rs
@@ -22,6 +22,7 @@ pub mod hub;
 // pub mod key_value_store;
 pub mod kv_router;
 pub mod local_model;
+pub mod migration;
 pub mod mocker;
 pub mod model_card;
 pub mod model_type;
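The linking order above is the important part: the migration operator is the innermost layer, wrapping only the network hop. For reference, the wiring shape from watcher.rs with comments added:

    let migration = Migration::from_mdc(card.clone()).await?.into_operator();
    let engine = frontend
        .link(preprocessor.forward_edge())?
        .link(backend.forward_edge())?
        .link(migration.forward_edge())?   // request path: last stop before the wire
        .link(service_backend)?
        .link(migration.backward_edge())?  // response path: first stop after the wire
        .link(backend.backward_edge())?
        .link(preprocessor.backward_edge())?
        .link(frontend)?;

Because migration sits directly around `service_backend`, a dropped worker connection surfaces to it before the backend or preprocessor see anything, so it can splice in a replacement stream transparently.
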
diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs
index 7a0f6636f7..c32ca25bdb 100644
--- a/lib/llm/src/local_model.rs
+++ b/lib/llm/src/local_model.rs
@@ -46,6 +46,7 @@ pub struct LocalModelBuilder {
     router_config: Option<RouterConfig>,
     kv_cache_block_size: u32,
     http_port: u16,
+    migration_limit: u32,
 }

 impl Default for LocalModelBuilder {
@@ -60,6 +61,7 @@ impl Default for LocalModelBuilder {
             context_length: Default::default(),
             template_file: Default::default(),
             router_config: Default::default(),
+            migration_limit: Default::default(),
         }
     }
 }
@@ -112,6 +114,11 @@ impl LocalModelBuilder {
         self
     }

+    pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
+        self.migration_limit = migration_limit.unwrap_or(0);
+        self
+    }
+
     /// Make an LLM ready for use:
     /// - Download it from Hugging Face (and NGC in future) if necessary
     /// - Resolve the path
@@ -137,10 +144,12 @@ impl LocalModelBuilder {
         // echo_full engine doesn't need a path. It's an edge case, move it out of the way.
         if self.model_path.is_none() {
+            let mut card = ModelDeploymentCard::with_name_only(
+                self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
+            );
+            card.migration_limit = self.migration_limit;
             return Ok(LocalModel {
-                card: ModelDeploymentCard::with_name_only(
-                    self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
-                ),
+                card,
                 full_path: PathBuf::new(),
                 endpoint_id,
                 template,
@@ -194,6 +203,8 @@ impl LocalModelBuilder {
             card.context_length = context_length;
         }

+        card.migration_limit = self.migration_limit;
+
         Ok(LocalModel {
             card,
             full_path,
diff --git a/lib/llm/src/migration.rs b/lib/llm/src/migration.rs
new file mode 100644
index 0000000000..2703ca2922
--- /dev/null
+++ b/lib/llm/src/migration.rs
@@ -0,0 +1,662 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use anyhow::{Error, Result};
+use futures::{stream, stream::StreamExt};
+
+use async_nats::client::{
+    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
+};
+
+use crate::{
+    model_card::model::ModelDeploymentCard,
+    protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
+};
+
+use dynamo_runtime::{
+    pipeline::{
+        async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
+        ServerStreamingEngine, SingleIn,
+    },
+    protocols::{annotated::Annotated, maybe_error::MaybeError},
+};
+
+pub struct Migration {
+    migration_limit: u32,
+}
+
+impl Migration {
+    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
+        Ok(Arc::new(Self {
+            migration_limit: mdc.migration_limit,
+        }))
+    }
+}
+
+#[async_trait]
+impl
+    Operator<
+        SingleIn<PreprocessedRequest>,
+        ManyOut<Annotated<LLMEngineOutput>>,
+        SingleIn<PreprocessedRequest>,
+        ManyOut<Annotated<LLMEngineOutput>>,
+    > for Migration
+{
+    async fn generate(
+        &self,
+        request: SingleIn<PreprocessedRequest>,
+        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
+    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
+        let (preprocessed_request, context) = request.transfer(());
+        let engine_ctx = context.context();
+        let retry_manager =
+            RetryManager::build(preprocessed_request, next, self.migration_limit).await?;
+        let response_stream = stream::unfold(retry_manager, |mut retry_manager| async move {
+            retry_manager
+                .next()
+                .await
+                .map(|response| (response, retry_manager))
+        });
+        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx))
+    }
+}
+
+struct RetryManager {
+    request: PreprocessedRequest,
+    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
+    next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
+    retries_left: u32,
+}
+
+impl RetryManager {
+    pub async fn build(
+        preprocessed_request: PreprocessedRequest,
+        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
+        retries_left: u32,
+    ) -> Result<Self> {
+        let mut slf = Self {
+            request: preprocessed_request,
+            next_generate: next,
+            next_stream: None,
+            retries_left: retries_left + 1, // +1 to account for the initial attempt
+        };
+        slf.new_stream().await?;
+        Ok(slf)
+    }
+
+    pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> {
+        loop {
+            let response_stream = match self.next_stream.as_mut() {
+                Some(stream) => stream,
+                None => {
+                    tracing::error!("next() called while next_stream is None - should not happen");
+                    return Some(Annotated::from_err(
+                        Error::msg("next_stream is None").into(),
+                    ));
+                }
+            };
+            if let Some(response) = response_stream.next().await {
+                if let Some(err) = response.err() {
+                    const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
+                    if format!("{:?}", err) == STREAM_ERR_MSG {
+                        tracing::info!("Stream disconnected... recreating stream...");
+                        if let Err(err) = self.new_stream().await {
+                            tracing::info!("Cannot recreate stream: {:?}", err);
+                        } else {
+                            continue;
+                        }
+                    }
+                }
+                self.track_response(&response);
+                return Some(response);
+            }
+            return None;
+        }
+    }
+
+    async fn new_stream(&mut self) -> Result<()> {
+        let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
+        while self.retries_left > 0 {
+            self.retries_left -= 1;
+            // TODO: Is there anything needed to pass between context?
+            let request = SingleIn::new(self.request.clone());
+            response_stream = Some(self.next_generate.generate(request).await);
+            if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() {
+                if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
+                    if matches!(req_err.kind(), NatsNoResponders) {
+                        tracing::info!("Creating new stream... retrying...");
+                        continue;
+                    }
+                }
+            }
+            break;
+        }
+        match response_stream {
+            Some(Ok(next_stream)) => {
+                self.next_stream = Some(next_stream);
+                Ok(())
+            }
+            Some(Err(err)) => Err(err), // should propagate streaming error if stream started
+            None => Err(Error::msg(
+                "Retries exhausted - should propagate streaming error",
+            )),
+        }
+    }
+
+    fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) {
+        if self.retries_left == 0 {
+            return;
+        }
+        let llm_engine_output = match response.data.as_ref() {
+            Some(output) => output,
+            None => return,
+        };
+        for token_id in llm_engine_output.token_ids.iter() {
+            self.request.token_ids.push(*token_id);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocols::common::{SamplingOptions, StopConditions};
+    use dynamo_runtime::pipeline::context::Controller;
+    use dynamo_runtime::pipeline::AsyncEngine;
+    use std::sync::atomic::{AtomicU32, Ordering};
+    use tokio::sync::mpsc;
+
+    // Helper to create a mock preprocessed request
+    fn create_mock_request() -> PreprocessedRequest {
+        PreprocessedRequest {
+            token_ids: vec![1, 2, 3],
+            batch_token_ids: None,
+            stop_conditions: StopConditions::default(),
+            sampling_options: SamplingOptions::default(),
+            eos_token_ids: vec![],
+            mdc_sum: None,
+            annotations: vec![],
+            estimated_prefix_hit_num_blocks: None,
+        }
+    }
+
+    // Helper to create mock LLM engine output
+    fn create_mock_output(token_id: u32) -> Annotated<LLMEngineOutput> {
+        Annotated::from_data(LLMEngineOutput {
+            token_ids: vec![token_id],
+            tokens: None,
+            text: Some(format!("token_{}", token_id)),
+            cum_log_probs: None,
+            log_probs: None,
+            finish_reason: None,
+            index: None,
+        })
+    }
+
+    #[derive(Debug, Clone)]
+    enum MockBehavior {
+        /// Always succeeds with all responses
+        Success,
+        /// Fails on first call with NoResponders error, then succeeds on subsequent calls
+        FailThenSuccess,
+        /// Succeeds initially, fails mid-stream with specific error, then succeeds on retry
+        MidStreamFail { fail_after: usize },
+        /// Succeeds initially, fails mid-stream with specific error, then always fails on retry attempts
+        MidStreamFailAlways { fail_after: usize },
+        /// Succeeds initially, fails mid-stream, then always fails with stream error on retry attempts
+        MidStreamFailAlwaysStreamError { fail_after: usize },
+        /// Always fails with NoResponders error (same as FailThenSuccess first call)
+        AlwaysFail,
+    }
+
+    // Unified mock server streaming engine that can simulate different scenarios
+    struct MockEngine {
+        behavior: MockBehavior,
+        num_responses: usize,
+        token_offset: u32,
+        call_count: Arc<AtomicU32>,
+    }
+
+    impl MockEngine {
+        fn new(behavior: MockBehavior, num_responses: usize, token_offset: u32) -> Self {
+            Self {
+                behavior,
+                num_responses,
+                token_offset,
+                call_count: Arc::new(AtomicU32::new(0)),
+            }
+        }
+    }
+
+    #[async_trait]
+    impl
+        AsyncEngine<
+            SingleIn<PreprocessedRequest>,
+            ManyOut<Annotated<LLMEngineOutput>>,
+            anyhow::Error,
+        > for MockEngine
+    {
+        async fn generate(
+            &self,
+            request: SingleIn<PreprocessedRequest>,
+        ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
+            let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
+            let (preprocessed_request, _) = request.transfer(());
+
+            // Calculate how many responses we've already generated based on request token_ids
+            // Initial request has [1, 2, 3], so anything beyond that are generated responses
+            let initial_tokens = 3; // [1, 2, 3]
+            let responses_already_generated = preprocessed_request
+                .token_ids
+                .len()
+                .saturating_sub(initial_tokens);
+            let _responses_remaining = self
+                .num_responses
+                .saturating_sub(responses_already_generated);
+
+            match &self.behavior {
+                MockBehavior::Success => {
+                    // Always succeed with remaining responses
+                    self.send_responses(responses_already_generated, self.num_responses)
+                        .await
+                }
+                MockBehavior::FailThenSuccess => {
+                    if call_num == 0 {
+                        // First call - return "No responders available" error to trigger retry
+                        let nats_error: NatsRequestError = NatsNoResponders.into();
+                        return Err(nats_error.into());
+                    } else {
+                        // Subsequent calls - succeed with remaining responses
+                        self.send_responses(responses_already_generated, self.num_responses)
+                            .await
+                    }
+                }
+                MockBehavior::MidStreamFail { fail_after } => {
+                    let (tx, rx) = mpsc::channel(1);
+                    let token_offset = self.token_offset;
+                    let fail_after = *fail_after;
+                    let num_responses = self.num_responses;
+
+                    if call_num == 0 {
+                        // First call - send some responses then an error to simulate disconnection
+                        tokio::spawn(async move {
+                            // Send responses from current position to fail_after
+                            for i in responses_already_generated..fail_after.min(num_responses) {
+                                let response = create_mock_output(token_offset + 1 + i as u32);
+                                if tx.send(response).await.is_err() {
+                                    break;
+                                }
+                            }
+                            // Send the specific error that triggers retry logic
+                            let error_response = Annotated::from_err(
+                                anyhow::Error::msg("Stream ended before generation completed")
+                                    .into(),
+                            );
+                            let _ = tx.send(error_response).await;
+                        });
+                    } else {
+                        // Second call - send remaining responses from where we left off
+                        tokio::spawn(async move {
+                            for i in responses_already_generated..num_responses {
+                                let response = create_mock_output(token_offset + 1 + i as u32);
+                                if tx.send(response).await.is_err() {
+                                    break;
+                                }
+                            }
+                        });
+                    }
+
+                    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
+                    let ctx = Arc::new(Controller::default());
+                    Ok(dynamo_runtime::pipeline::ResponseStream::new(
+                        Box::pin(stream),
+                        ctx,
+                    ))
+                }
+                MockBehavior::MidStreamFailAlways { fail_after } => {
+                    if call_num == 0 {
+                        // First call - send some responses then an error to simulate disconnection
+                        let (tx, rx) = mpsc::channel(1);
+                        let token_offset = self.token_offset;
+                        let fail_after = *fail_after;
+                        let num_responses = self.num_responses;
+
+                        tokio::spawn(async move {
+                            // Send responses from current position to fail_after
+                            for i in responses_already_generated..fail_after.min(num_responses) {
+                                let response = create_mock_output(token_offset + 1 + i as u32);
+                                if tx.send(response).await.is_err() {
+                                    break;
+                                }
+                            }
+                            // Send the specific error that triggers retry logic
+                            let error_response = Annotated::from_err(
+                                anyhow::Error::msg("Stream ended before generation completed")
+                                    .into(),
+                            );
+                            let _ = tx.send(error_response).await;
+                        });
+
+                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
+                        let ctx = Arc::new(Controller::default());
+                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
+                            Box::pin(stream),
+                            ctx,
+                        ))
+                    } else {
+                        // Subsequent calls - always fail with NoResponders error (same as AlwaysFail)
+                        let nats_error: NatsRequestError = NatsNoResponders.into();
+                        Err(nats_error.into())
+                    }
+                }
+                MockBehavior::MidStreamFailAlwaysStreamError { fail_after } => {
+                    let (tx, rx) = mpsc::channel(1);
+                    let token_offset = self.token_offset;
+                    let fail_after = *fail_after;
+                    let num_responses = self.num_responses;
+
+                    if call_num == 0 {
+                        // First call - send some responses then an error to simulate disconnection
+                        tokio::spawn(async move {
+                            // Send responses from current position to fail_after
+                            for i in responses_already_generated..fail_after.min(num_responses) {
+                                let response = create_mock_output(token_offset + 1 + i as u32);
+                                if tx.send(response).await.is_err() {
+                                    break;
+                                }
+                            }
+                            // Send the specific error that triggers retry logic
+                            let error_response = Annotated::from_err(
+                                anyhow::Error::msg("Stream ended before generation completed")
+                                    .into(),
+                            );
+                            let _ = tx.send(error_response).await;
+                        });
+                    } else {
+                        // Subsequent calls - immediately send stream error (no successful responses)
+                        tokio::spawn(async move {
+                            // Send the stream error immediately
+                            let error_response = Annotated::from_err(
+                                anyhow::Error::msg("Stream ended before generation completed")
+                                    .into(),
+                            );
+                            let _ = tx.send(error_response).await;
+                        });
+                    }
+
+                    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
+                    let ctx = Arc::new(Controller::default());
+                    Ok(dynamo_runtime::pipeline::ResponseStream::new(
+                        Box::pin(stream),
+                        ctx,
+                    ))
+                }
+                MockBehavior::AlwaysFail => {
+                    // Always fail with NoResponders error (same as FailThenSuccess first call)
+                    let nats_error: NatsRequestError = NatsNoResponders.into();
+                    Err(nats_error.into())
+                }
+            }
+        }
+    }
+
+    impl MockEngine {
+        async fn send_responses(
+            &self,
+            start: usize,
+            end: usize,
+        ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
+            let (tx, rx) = mpsc::channel(1);
+            let token_offset = self.token_offset;
+
+            tokio::spawn(async move {
+                for i in start..end {
+                    let response = create_mock_output(token_offset + 1 + i as u32);
+                    if tx.send(response).await.is_err() {
+                        break;
+                    }
+                }
+            });
+
+            let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
+            let ctx = Arc::new(Controller::default());
+            Ok(dynamo_runtime::pipeline::ResponseStream::new(
+                Box::pin(stream),
+                ctx,
+            ))
+        }
+    }
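The mocks lean on one property of the design: because track_response appends every generated token to the stored request, request.token_ids doubles as a progress cursor, and a retried request tells the (mock or real) engine where to resume. A worked example of the arithmetic, using the constants from create_mock_request() and create_mock_output():

    // Prompt is [1, 2, 3]; token_offset is 100 in these tests.
    let token_ids: Vec<u32> = vec![1, 2, 3, 101, 102, 103, 104, 105]; // prompt + 5 generated
    let initial_tokens = 3;
    let already_generated = token_ids.len().saturating_sub(initial_tokens);
    assert_eq!(already_generated, 5); // a replacement stream resumes at token 106, not 101
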
+
+    /// Test case 1: No migration needed
+    /// Tests the normal case where the RetryManager successfully processes all responses
+    /// from a single stream without any failures or need for retries/migration.
+    /// Expected behavior: All 10 responses should be received successfully.
+    #[tokio::test]
+    async fn test_retry_manager_no_migration() {
+        let request = create_mock_request();
+        let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100));
+        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
+            mock_engine;
+
+        let mut retry_manager = RetryManager::build(request, next_generate, 0)
+            .await
+            .expect("Failed to build RetryManager");
+
+        let mut responses = Vec::new();
+        while let Some(response) = retry_manager.next().await {
+            responses.push(response);
+        }
+
+        assert_eq!(responses.len(), 10);
+        for (i, response) in responses.iter().enumerate() {
+            assert!(response.err().is_none());
+            if let Some(output) = &response.data {
+                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
+            }
+        }
+    }
+
+    /// Test case 2: New request migration
+    /// Tests the scenario where a worker becomes unreachable for new requests initially,
+    /// triggering the RetryManager to retry the request. The MockEngine with FailThenSuccess
+    /// fails on the first call with a "No responders available" error, then succeeds
+    /// on subsequent calls, simulating a worker becoming available after initial failure.
+    /// Expected behavior: All 10 responses should be received successfully after retry.
+    #[tokio::test]
+    async fn test_retry_manager_new_request_migration() {
+        let request = create_mock_request();
+        let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100));
+        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
+            mock_engine;
+
+        let mut retry_manager = RetryManager::build(request, next_generate, 3)
+            .await
+            .expect("Failed to build RetryManager");
+
+        let mut responses = Vec::new();
+        while let Some(response) = retry_manager.next().await {
+            responses.push(response);
+        }
+
+        assert_eq!(responses.len(), 10);
+        for (i, response) in responses.iter().enumerate() {
+            assert!(response.err().is_none());
+            if let Some(output) = &response.data {
+                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
+            }
+        }
+    }
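Note how the retry budget is spent in these tests: RetryManager::build adds 1 to the caller's limit, and new_stream decrements once per connection attempt, including the very first one. So migration_limit = 3 allows four stream creations in total:

    // Attempt accounting for RetryManager::build(request, next, 3):
    // build()          retries_left = 4
    // initial stream   retries_left = 3
    // 1st migration    retries_left = 2
    // 2nd migration    retries_left = 1
    // 3rd migration    retries_left = 0  -> the next failure is surfaced to the caller
    //                                       (track_response also stops accumulating tokens)
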
+
+    /// Test case 3: Ongoing request migration
+    /// Tests the scenario where a worker fails mid-stream during an ongoing request.
+    /// This simulates a connection being lost after partial response delivery, requiring
+    /// the RetryManager to detect the failure (via "Stream ended before generation completed" error),
+    /// create a new stream, and continue from where it left off.
+    /// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
+    #[tokio::test]
+    async fn test_retry_manager_ongoing_request_migration() {
+        let request = create_mock_request();
+        let mock_engine = Arc::new(MockEngine::new(
+            MockBehavior::MidStreamFail { fail_after: 5 },
+            10,
+            100,
+        ));
+        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
+            mock_engine;
+
+        let mut retry_manager = RetryManager::build(request, next_generate, 3)
+            .await
+            .expect("Failed to build RetryManager");
+
+        let mut responses = Vec::new();
+        while let Some(response) = retry_manager.next().await {
+            responses.push(response);
+        }
+
+        // Should have received all 10 responses (5 from first stream + 5 from second stream)
+        assert_eq!(responses.len(), 10);
+
+        // Check that we received responses from both streams
+        for (i, response) in responses.iter().enumerate() {
+            assert!(response.err().is_none());
+            if let Some(output) = &response.data {
+                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
+            }
+        }
+    }
+
+    /// Test case 4: New request migration - indefinite failure
+    /// Tests the scenario where a worker becomes unreachable for new requests indefinitely.
+    /// The RetryManager should exhaust all retries and return the original error from the first attempt.
+    /// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
+    #[tokio::test]
+    async fn test_retry_manager_new_request_migration_indefinite_failure() {
+        let request = create_mock_request();
+        let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100));
+        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
+            mock_engine;
+
+        // Should fail to build due to initial stream creation failure after exhausting all 3 retries
+        let retry_manager_result = RetryManager::build(request, next_generate, 3).await;
+
+        assert!(retry_manager_result.is_err());
+        if let Err(error) = retry_manager_result {
+            assert!(error.to_string().contains("no responders"));
+        }
+    }
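The two failure modes exercised above are distinguished purely by where the error shows up: a NATS NoResponders error returned from generate() means no worker ever took the request, while a mid-stream drop arrives in-band as a response item whose error matches a sentinel string. A hypothetical helper making that second check explicit (the constant is the one migration.rs and push_router.rs both hard-code):

    // Sketch only: the crates compare formatted error text rather than a typed error.
    const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
    fn is_migratable_disconnect(err: &anyhow::Error) -> bool {
        format!("{:?}", err) == STREAM_ERR_MSG
    }

String matching keeps the two crates decoupled, but it also means the producer and consumer of this sentinel must stay in sync by convention; a shared typed error would be a reasonable follow-up.
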
+
+    /// Test case 5: Ongoing request migration - indefinite failure
+    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests.
+    /// The RetryManager should exhaust all retries and return the original stream disconnection error.
+    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
+    #[tokio::test]
+    async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
+        let request = create_mock_request();
+        let mock_engine = Arc::new(MockEngine::new(
+            MockBehavior::MidStreamFailAlways { fail_after: 3 },
+            10,
+            100,
+        ));
+        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
+            mock_engine;
+
+        let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
+            .await
+            .expect("Failed to build RetryManager");
+
+        let mut responses = Vec::new();
+
+        // Collect all responses (both successful and error responses)
+        while let Some(response) = retry_manager.next().await {
+            responses.push(response);
+        }
+
+        // Should have received 4 total responses: 3 successful + 1 error
+        assert_eq!(responses.len(), 4);
+
+        // First 3 responses should be successful with tokens 101, 102, 103
+        for (i, response) in responses[0..3].iter().enumerate() {
+            assert!(response.err().is_none());
+            if let Some(output) = &response.data {
+                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
+            }
+        }
+
+        // 4th response should be an error after retries are exhausted
+        let error_response = &responses[3];
+        assert!(error_response.err().is_some());
+        if let Some(error) = error_response.err() {
+            assert!(error
+                .to_string()
+                .contains("Stream ended before generation completed"));
+        }
+    }
+
+    /// Test case 6: Ongoing request migration - indefinite failure with stream errors
+    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests,
+    /// and all retry attempts also fail with stream errors instead of NATS errors.
+    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
+    #[tokio::test]
+    async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
+        let request = create_mock_request();
+        let mock_engine = Arc::new(MockEngine::new(
+            MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
+            10,
+            100,
+        ));
+        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
+            mock_engine;
+
+        let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
+            .await
+            .expect("Failed to build RetryManager");
+
+        let mut responses = Vec::new();
+
+        // Collect all responses (both successful and error responses)
+        while let Some(response) = retry_manager.next().await {
+            responses.push(response);
+        }
+
+        // Should have received 4 total responses: 3 successful + 1 error
+        assert_eq!(responses.len(), 4);
+
+        // First 3 responses should be successful with tokens 101, 102, 103
+        for (i, response) in responses[0..3].iter().enumerate() {
+            assert!(response.err().is_none());
+            if let Some(output) = &response.data {
+                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
+            }
+        }
+
+        // 4th response should be an error after retries are exhausted
+        let error_response = &responses[3];
+        assert!(error_response.err().is_some());
+        if let Some(error) = error_response.err() {
+            assert!(error
+                .to_string()
+                .contains("Stream ended before generation completed"));
+        }
+    }
+}
diff --git a/lib/llm/src/model_card/create.rs b/lib/llm/src/model_card/create.rs
index b69525beb7..d628a0178b 100644
--- a/lib/llm/src/model_card/create.rs
+++ b/lib/llm/src/model_card/create.rs
@@ -92,6 +92,7 @@ impl ModelDeploymentCard {
             last_published: None,
             context_length,
             kv_cache_block_size: 0,
+            migration_limit: 0,
         })
     }

@@ -131,6 +132,7 @@ impl ModelDeploymentCard {
             last_published: None,
             context_length,
             kv_cache_block_size: 0, // set later
+            migration_limit: 0,
         })
     }
 }
diff --git a/lib/llm/src/model_card/model.rs b/lib/llm/src/model_card/model.rs
index 6fd6efe38d..08efc79484 100644
--- a/lib/llm/src/model_card/model.rs
+++ b/lib/llm/src/model_card/model.rs
@@ -127,6 +127,10 @@ pub struct ModelDeploymentCard {
     /// Size of a KV cache block - vllm only currently
     /// Passed to the engine and the KV router.
     pub kv_cache_block_size: u32,
+
+    /// How many times a request can be migrated to another worker if the HTTP server loses
+    /// connection to the current worker.
+    pub migration_limit: u32,
 }

 impl ModelDeploymentCard {
diff --git a/lib/llm/src/protocols/common/llm_backend.rs b/lib/llm/src/protocols/common/llm_backend.rs
index 17fe94970f..0583d18347 100644
--- a/lib/llm/src/protocols/common/llm_backend.rs
+++ b/lib/llm/src/protocols/common/llm_backend.rs
@@ -136,11 +136,11 @@ impl LLMEngineOutput {
 }

 impl MaybeError for LLMEngineOutput {
-    fn from_err(err: Box<dyn std::error::Error>) -> Self {
+    fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
         LLMEngineOutput::error(format!("{:?}", err))
     }

-    fn err(&self) -> Option<Box<dyn std::error::Error>> {
+    fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
         if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
             Some(anyhow::Error::msg(err_msg.clone()).into())
         } else {
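Assuming the widened bounds here are `+ Send + Sync` (the usual reason these signatures change, and consistent with how migration.rs builds errors via `anyhow::Error::msg(..).into()`), the point is that error values must now cross task boundaries: the RetryManager constructs them on one task and the pipeline observes them on another. A minimal illustration, hypothetical and not from this diff:

    // Box<dyn Error> cannot cross a tokio::spawn boundary; the + Send + Sync form can.
    fn must_be_send<T: Send>(_: T) {}
    fn demo(err: Box<dyn std::error::Error + Send + Sync>) {
        must_be_send(err); // compiles; drop the Send + Sync bounds and it would not
    }
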
diff --git a/lib/runtime/src/component/client.rs b/lib/runtime/src/component/client.rs
index e5c9941751..ab57a1e3a6 100644
--- a/lib/runtime/src/component/client.rs
+++ b/lib/runtime/src/component/client.rs
@@ -6,14 +6,8 @@ use crate::pipeline::{
     SingleIn,
 };
 use arc_swap::ArcSwap;
-use rand::Rng;
 use std::collections::HashMap;
-use std::sync::RwLock;
-use std::sync::{
-    atomic::{AtomicU64, Ordering},
-    Arc, Mutex,
-};
-use std::time::Instant;
+use std::sync::Arc;
 use tokio::net::unix::pipe::Receiver;

 use crate::{
@@ -48,10 +42,8 @@ pub struct Client {
     pub endpoint: Endpoint,
     // These are the remotes I know about from watching etcd
     pub instance_source: Arc<InstanceSource>,
-    // These are the instances that are reported as down from sending rpc
-    instance_inhibited: Arc<Mutex<HashMap<i64, Instant>>>,
-    // The current active IDs
-    instance_cache: Arc<ArcSwap<Vec<i64>>>,
+    // These are the instance source ids less those reported as down from sending rpc
+    instance_avail: Arc<ArcSwap<Vec<i64>>>,
 }

 #[derive(Clone, Debug)]
@@ -60,16 +52,13 @@ pub enum InstanceSource {
     Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
 }

-// TODO: Avoid returning a full clone of `Vec<Instance>` everytime from Client
-// See instances() and instances_avail() methods
 impl Client {
     // Client will only talk to a single static endpoint
     pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Client> {
         Ok(Client {
             endpoint,
             instance_source: Arc::new(InstanceSource::Static),
-            instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
-            instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
+            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
         })
     }

@@ -85,26 +74,12 @@ impl Client {
         let instance_source =
             Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;

-        let cancel_token = endpoint.drt().primary_token();
         let client = Client {
             endpoint,
             instance_source,
-            instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
-            instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
+            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
         };
-
-        let instance_source_c = client.instance_source.clone();
-        let instance_inhibited_c = Arc::clone(&client.instance_inhibited);
-        let instance_cache_c = Arc::clone(&client.instance_cache);
-        tokio::task::spawn(async move {
-            while !cancel_token.is_cancelled() {
-                refresh_instances(&instance_source_c, &instance_inhibited_c, &instance_cache_c);
-                tokio::select! {
-                    _ = cancel_token.cancelled() => {}
-                    _ = tokio::time::sleep(INSTANCE_REFRESH_PERIOD) => {}
-                }
-            }
-        });
+        client.monitor_instance_source();
         Ok(client)
     }

@@ -119,13 +94,20 @@ impl Client {

     /// Instances available from watching etcd
     pub fn instances(&self) -> Vec<Instance> {
-        instances_inner(self.instance_source.as_ref())
+        match self.instance_source.as_ref() {
+            InstanceSource::Static => vec![],
+            InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
+        }
     }

     pub fn instance_ids(&self) -> Vec<i64> {
         self.instances().into_iter().map(|ep| ep.id()).collect()
     }

+    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
+        self.instance_avail.load()
+    }
+
     /// Wait for at least one Instance to be available for this Endpoint
     pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
         let mut instances: Vec<Instance> = vec![];
@@ -143,24 +125,51 @@ impl Client {
         Ok(instances)
     }

-    /// Instances available from watching etcd minus those reported as down
-    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
-        self.instance_cache.load()
+    /// Is this component known at startup and not discovered via etcd?
+    pub fn is_static(&self) -> bool {
+        matches!(self.instance_source.as_ref(), InstanceSource::Static)
     }

     /// Mark an instance as down/unavailable
     pub fn report_instance_down(&self, instance_id: i64) {
-        self.instance_inhibited
-            .lock()
-            .unwrap()
-            .insert(instance_id, Instant::now());
+        let filtered = self
+            .instance_ids_avail()
+            .iter()
+            .filter_map(|&id| if id == instance_id { None } else { Some(id) })
+            .collect::<Vec<i64>>();
+        self.instance_avail.store(Arc::new(filtered));

         tracing::debug!("inhibiting instance {instance_id}");
     }

-    /// Is this component know at startup and not discovered via etcd?
-    pub fn is_static(&self) -> bool {
-        matches!(self.instance_source.as_ref(), InstanceSource::Static)
+    /// Monitor the etcd instance source and update instance_avail.
+    fn monitor_instance_source(&self) {
+        let cancel_token = self.endpoint.drt().primary_token();
+        let client = self.clone();
+        tokio::task::spawn(async move {
+            let mut rx = match client.instance_source.as_ref() {
+                InstanceSource::Static => {
+                    tracing::error!("Static instance source is not watchable");
+                    return;
+                }
+                InstanceSource::Dynamic(rx) => rx.clone(),
+            };
+            while !cancel_token.is_cancelled() {
+                let instance_ids: Vec<i64> = rx
+                    .borrow_and_update()
+                    .iter()
+                    .map(|instance| instance.id())
+                    .collect();
+                client.instance_avail.store(Arc::new(instance_ids));
+
+                tracing::debug!("instance source updated");
+
+                if let Err(err) = rx.changed().await {
+                    tracing::error!("The Sender is dropped: {}", err);
+                    cancel_token.cancel();
+                }
+            }
+        });
     }

     async fn get_or_create_dynamic_instance_source(
@@ -253,49 +262,3 @@ impl Client {
         Ok(instance_source)
     }
 }
-
-/// Update the instance id cache
-fn refresh_instances(
-    instance_source: &InstanceSource,
-    instance_inhibited: &Arc<Mutex<HashMap<i64, Instant>>>,
-    instance_cache: &Arc<ArcSwap<Vec<i64>>>,
-) {
-    const ETCD_LEASE_TTL: u64 = 10; // seconds
-
-    // TODO: Can we get the remaining TTL from the lease for the instance?
-    let now = Instant::now();
-
-    let instances = instances_inner(instance_source);
-    let mut inhibited = instance_inhibited.lock().unwrap();
-
-    // 1. Remove inhibited instances that are no longer in `self.instances()`
-    // 2. Remove inhibited instances that have expired
-    // 3. Only return instances that are not inhibited after removals
-    let mut new_inhibited = HashMap::<i64, Instant>::new();
-    let filtered: Vec<i64> = instances
-        .into_iter()
-        .filter_map(|instance| {
-            let id = instance.id();
-            if let Some(&timestamp) = inhibited.get(&id) {
-                if now.duration_since(timestamp).as_secs() > ETCD_LEASE_TTL {
-                    Some(id)
-                } else {
-                    new_inhibited.insert(id, timestamp);
-                    None
-                }
-            } else {
-                Some(id)
-            }
-        })
-        .collect();
-
-    *inhibited = new_inhibited;
-    instance_cache.store(Arc::new(filtered));
-}
-
-fn instances_inner(instance_source: &InstanceSource) -> Vec<Instance> {
-    match instance_source {
-        InstanceSource::Static => vec![],
-        InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
-    }
-}
diff --git a/lib/runtime/src/pipeline/network/egress/push_router.rs b/lib/runtime/src/pipeline/network/egress/push_router.rs
index 804261e5b1..a55365eac7 100644
--- a/lib/runtime/src/pipeline/network/egress/push_router.rs
+++ b/lib/runtime/src/pipeline/network/egress/push_router.rs
@@ -178,20 +178,14 @@ where
             Ok(stream) => {
                 let engine_ctx = stream.context();
                 let client = self.client.clone();
-                let stream = stream.then(move |res| {
-                    let mut report_instance_down: Option<(Client, i64)> = None;
+                let stream = stream.map(move |res| {
                     if let Some(err) = res.err() {
                         const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
                         if format!("{:?}", err) == STREAM_ERR_MSG {
-                            report_instance_down = Some((client.clone(), instance_id));
-                        }
-                    }
-                    async move {
-                        if let Some((client, instance_id)) = report_instance_down {
                             client.report_instance_down(instance_id);
                         }
-                        res
                     }
+                    res
                 });
                 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
             }
diff --git a/lib/runtime/src/protocols/annotated.rs b/lib/runtime/src/protocols/annotated.rs
index 0aa7305b0b..4e9cffc418 100644
--- a/lib/runtime/src/protocols/annotated.rs
+++ b/lib/runtime/src/protocols/annotated.rs
@@ -151,11 +151,11 @@ impl<R> MaybeError for Annotated<R>
 where
     R: for<'de> Deserialize<'de> + Serialize,
 {
-    fn from_err(err: Box<dyn std::error::Error>) -> Self {
+    fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
         Annotated::from_error(format!("{:?}", err))
     }

-    fn err(&self) -> Option<Box<dyn std::error::Error>> {
+    fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
         if self.is_error() {
             if let Some(comment) = &self.comment {
                 if !comment.is_empty() {
diff --git a/lib/runtime/src/protocols/maybe_error.rs b/lib/runtime/src/protocols/maybe_error.rs
index 8c3e45357b..068fbadc60 100644
--- a/lib/runtime/src/protocols/maybe_error.rs
+++ b/lib/runtime/src/protocols/maybe_error.rs
@@ -17,10 +17,10 @@ use std::error::Error;

 pub trait MaybeError {
     /// Construct an instance from an error.
-    fn from_err(err: Box<dyn Error>) -> Self;
+    fn from_err(err: Box<dyn Error + Send + Sync>) -> Self;

     /// Construct into an error instance.
-    fn err(&self) -> Option<Box<dyn Error>>;
+    fn err(&self) -> Option<Box<dyn Error + Send + Sync>>;

     /// Check if the current instance represents a success.
     fn is_ok(&self) -> bool {
@@ -41,12 +41,12 @@ mod tests {
         message: String,
     }
     impl MaybeError for TestError {
-        fn from_err(err: Box<dyn Error>) -> Self {
+        fn from_err(err: Box<dyn Error + Send + Sync>) -> Self {
             TestError {
                 message: err.to_string(),
             }
         }
-        fn err(&self) -> Option<Box<dyn Error>> {
+        fn err(&self) -> Option<Box<dyn Error + Send + Sync>> {
             Some(anyhow::Error::msg(self.message.clone()).into())
         }
     }
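The client.rs change replaces a mutex-guarded inhibit map plus a 10-second polling task with a single copy-on-write snapshot kept fresh by the etcd watch: reads are lock-free, and report_instance_down takes effect immediately instead of on the next refresh tick. A standalone sketch of the pattern using the same arc-swap crate the diff uses:

    use arc_swap::ArcSwap;
    use std::sync::Arc;

    fn main() {
        let avail: ArcSwap<Vec<i64>> = ArcSwap::from(Arc::new(vec![1, 2, 3]));
        // Readers take a cheap snapshot; no lock, no clone of the Vec:
        let snapshot = avail.load();
        assert_eq!(**snapshot, vec![1, 2, 3]);
        // report_instance_down(2) becomes: filter into a new Vec and swap it in.
        let filtered: Vec<i64> = snapshot.iter().copied().filter(|&id| id != 2).collect();
        avail.store(Arc::new(filtered));
        assert_eq!(**avail.load(), vec![1, 3]);
    }

Like report_instance_down in the diff, this is last-writer-wins rather than a compare-and-swap loop; that is acceptable here because the watch task rebuilds the whole list from etcd on every instance-source update anyway.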