diff --git a/components/frontend/src/dynamo/frontend/main.py b/components/frontend/src/dynamo/frontend/main.py index f8960cf0f5..a52ad577b2 100644 --- a/components/frontend/src/dynamo/frontend/main.py +++ b/components/frontend/src/dynamo/frontend/main.py @@ -143,6 +143,12 @@ def parse_args(): default=False, help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.", ) + parser.add_argument( + "--busy-threshold", + type=float, + default=None, + help="Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. If not set, busy detection is disabled.", + ) parser.add_argument( "--static-endpoint", type=validate_static_endpoint, @@ -205,7 +211,9 @@ async def async_main(): kwargs = { "http_port": flags.http_port, "kv_cache_block_size": flags.kv_cache_block_size, - "router_config": RouterConfig(router_mode, kv_router_config), + "router_config": RouterConfig( + router_mode, kv_router_config, flags.busy_threshold + ), } if flags.static_endpoint:
diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index 2d51e2bf8e..6f0e73481d 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -60,16 +60,22 @@ impl KvRouterConfig { pub struct RouterConfig { router_mode: RouterMode, kv_router_config: KvRouterConfig, + busy_threshold: Option<f64>, } #[pymethods] impl RouterConfig { #[new] - #[pyo3(signature = (mode, config=None))] - pub fn new(mode: RouterMode, config: Option<KvRouterConfig>) -> Self { + #[pyo3(signature = (mode, config=None, busy_threshold=None))] + pub fn new( + mode: RouterMode, + config: Option<KvRouterConfig>, + busy_threshold: Option<f64>, + ) -> Self { Self { router_mode: mode, kv_router_config: config.unwrap_or_default(), + busy_threshold, } } } @@ -79,6 +85,7 @@ impl From<RouterConfig> for RsRouterConfig { RsRouterConfig { router_mode: rc.router_mode.into(), kv_router_config: rc.kv_router_config.inner, + busy_threshold: rc.busy_threshold, } } }
diff --git a/lib/llm/src/discovery/watcher.rs b/lib/llm/src/discovery/watcher.rs index 2642ceee7e..c05ad190d2 100644 --- a/lib/llm/src/discovery/watcher.rs +++ b/lib/llm/src/discovery/watcher.rs @@ -50,6 +50,7 @@ pub struct ModelWatcher { notify_on_model: Notify, model_update_tx: Option<Sender<ModelUpdate>>, kv_router_config: Option<KvRouterConfig>, + busy_threshold: Option<f64>, } const ALL_MODEL_TYPES: &[ModelType] = @@ -61,6 +62,7 @@ impl ModelWatcher { model_manager: Arc<ModelManager>, router_mode: RouterMode, kv_router_config: Option<KvRouterConfig>, + busy_threshold: Option<f64>, ) -> ModelWatcher { Self { manager: model_manager, @@ -69,6 +71,7 @@ impl ModelWatcher { notify_on_model: Notify::new(), model_update_tx: None, kv_router_config, + busy_threshold, } } @@ -316,21 +319,31 @@ impl ModelWatcher { None }; - let chat_engine = - entrypoint::build_routed_pipeline::< - NvCreateChatCompletionRequest, - NvCreateChatCompletionStreamResponse, - >(&card, &client, self.router_mode, kv_chooser.clone()) - .await?; + let chat_engine = entrypoint::build_routed_pipeline::< + NvCreateChatCompletionRequest, + NvCreateChatCompletionStreamResponse, + >( + &card, + &client, + self.router_mode, + self.busy_threshold, + kv_chooser.clone(), + ) + .await?; self.manager .add_chat_completions_model(&model_entry.name, chat_engine)?; - let completions_engine = - entrypoint::build_routed_pipeline::< - NvCreateCompletionRequest, - NvCreateCompletionResponse, - >(&card, &client, self.router_mode, kv_chooser) - .await?; + let completions_engine =
entrypoint::build_routed_pipeline::< + NvCreateCompletionRequest, + NvCreateCompletionResponse, + >( + &card, + &client, + self.router_mode, + self.busy_threshold, + kv_chooser, + ) + .await?; self.manager .add_completions_model(&model_entry.name, completions_engine)?; } @@ -338,7 +351,9 @@ impl ModelWatcher { let push_router = PushRouter::< NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>, - >::from_client(client, Default::default()) + >::from_client_with_threshold( + client, Default::default(), self.busy_threshold + ) .await?; let engine = Arc::new(push_router); self.manager @@ -348,7 +363,9 @@ impl ModelWatcher { let push_router = PushRouter::< NvCreateCompletionRequest, Annotated<NvCreateCompletionResponse>, - >::from_client(client, Default::default()) + >::from_client_with_threshold( + client, Default::default(), self.busy_threshold + ) .await?; let engine = Arc::new(push_router); self.manager @@ -374,7 +391,9 @@ impl ModelWatcher { let router = PushRouter::< PreprocessedEmbeddingRequest, Annotated<PreprocessedEmbeddingResponse>, - >::from_client(client, self.router_mode) + >::from_client_with_threshold( + client, self.router_mode, self.busy_threshold + ) .await?; // Note: Embeddings don't need KV routing complexity
diff --git a/lib/llm/src/entrypoint.rs b/lib/llm/src/entrypoint.rs index 381aab567a..d81cc7cb66 100644 --- a/lib/llm/src/entrypoint.rs +++ b/lib/llm/src/entrypoint.rs @@ -21,6 +21,7 @@ use crate::{ pub struct RouterConfig { pub router_mode: RouterMode, pub kv_router_config: KvRouterConfig, + pub busy_threshold: Option<f64>, } impl RouterConfig { @@ -28,8 +29,14 @@ Self { router_mode, kv_router_config, + busy_threshold: None, } } + + pub fn with_busy_threshold(mut self, threshold: Option<f64>) -> Self { + self.busy_threshold = threshold; + self + } } #[derive(Clone)]
diff --git a/lib/llm/src/entrypoint/input/common.rs b/lib/llm/src/entrypoint/input/common.rs index b060957797..16aa6ad5ee 100644 --- a/lib/llm/src/entrypoint/input/common.rs +++ b/lib/llm/src/entrypoint/input/common.rs @@ -71,6 +71,7 @@ pub async fn prepare_engine( model_manager.clone(), dynamo_runtime::pipeline::RouterMode::RoundRobin, None, + None, )); let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?; let (_prefix, _watcher, receiver) = models_watcher.dissolve(); @@ -133,7 +134,7 @@ pub async fn prepare_engine( let chat_engine = entrypoint::build_routed_pipeline::< NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, - >(card, &client, router_mode, kv_chooser.clone()) + >(card, &client, router_mode, None, kv_chooser.clone()) .await?; let service_name = local_model.service_name().to_string(); @@ -216,6 +217,7 @@ pub async fn build_routed_pipeline<Req, Resp>( card: &ModelDeploymentCard, client: &Client, router_mode: RouterMode, + busy_threshold: Option<f64>, chooser: Option<Arc<KvRouter>>, ) -> anyhow::Result<ExecutionContext<SingleIn<Req>, ManyOut<Annotated<Resp>>>> where @@ -232,11 +234,13 @@ where let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let migration = Migration::from_mdc(card.clone()).await?.into_operator(); - let router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client( - client.clone(), - router_mode, - ) - .await?; + let router = + PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold( + client.clone(), + router_mode, + busy_threshold, + ) + .await?; let service_backend = match router_mode { RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => { ServiceBackend::from_engine(Arc::new(router)) diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index
8aa545e7ce..de46c92699 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -66,6 +66,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul MODEL_ROOT_PATH, router_config.router_mode, Some(router_config.kv_router_config), + router_config.busy_threshold, Arc::new(http_service.clone()), ) .await?; @@ -109,14 +110,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul let chat_engine = entrypoint::build_routed_pipeline::< NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, - >(card, &client, router_mode, kv_chooser.clone()) + >(card, &client, router_mode, None, kv_chooser.clone()) .await?; manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; let completions_engine = entrypoint::build_routed_pipeline::< NvCreateCompletionRequest, NvCreateCompletionResponse, - >(card, &client, router_mode, kv_chooser) + >(card, &client, router_mode, None, kv_chooser) .await?; manager.add_completions_model(local_model.display_name(), completions_engine)?; @@ -188,6 +189,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul /// Spawns a task that watches for new models in etcd at network_prefix, /// and registers them with the ModelManager so that the HTTP service can use them. +#[allow(clippy::too_many_arguments)] async fn run_watcher( runtime: DistributedRuntime, model_manager: Arc<ModelManager>, @@ -195,9 +197,16 @@ async fn run_watcher( network_prefix: &str, router_mode: RouterMode, kv_router_config: Option<KvRouterConfig>, + busy_threshold: Option<f64>, http_service: Arc<service_v2::HttpService>, ) -> anyhow::Result<()> { - let mut watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config); + let mut watch_obj = ModelWatcher::new( + runtime, + model_manager, + router_mode, + kv_router_config, + busy_threshold, + ); tracing::info!("Watching for remote model at {network_prefix}"); let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let (_prefix, _watcher, receiver) = models_watcher.dissolve();
diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index d3f825ebd3..e41baad5f5 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -108,6 +108,24 @@ impl ErrorMessage { /// If successful, it will return the [`HttpError`] as an [`ErrorMessage::internal_server_error`] /// with the details of the error.
pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> ErrorResponse { + // First check for PipelineError::ServiceOverloaded + if let Some(pipeline_err) = + err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>() + { + if matches!( + pipeline_err, + dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_) + ) { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ErrorMessage { + error: pipeline_err.to_string(), + }), + ); + } + } + + // Then check for HttpError match err.downcast::<HttpError>() { Ok(http_error) => ErrorMessage::from_http_error(http_error), Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")), @@ -1150,6 +1168,22 @@ mod tests { ); } + #[test] + fn test_service_overloaded_error_response_from_anyhow() { + use dynamo_runtime::pipeline::error::PipelineError; + + let err: anyhow::Error = PipelineError::ServiceOverloaded( + "All workers are busy, please retry later".to_string(), + ) + .into(); + let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE); + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); + assert_eq!( + response.error, + "Service temporarily unavailable: All workers are busy, please retry later" + ); + } + #[test] fn test_validate_input_is_text_only_accepts_text() { let request = make_base_request();
diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 4bee158ea9..66adadb33c 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -29,13 +29,13 @@ pub mod scoring; pub mod sequence; use crate::{ + discovery::{ModelEntry, MODEL_ROOT_PATH}, kv_router::{ approx::ApproxKvIndexer, indexer::{ compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent, }, - metrics_aggregator::watch_model_runtime_configs, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, @@ -177,14 +177,25 @@ impl KvRouter { } }; - // Create runtime config watcher + // Create runtime config watcher using the generic etcd watcher // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality let etcd_client = component .drt() .etcd_client() .expect("Cannot KV route without etcd client"); - let runtime_configs_rx = - watch_model_runtime_configs(etcd_client, cancellation_token.clone()).await?; + + use dynamo_runtime::utils::typed_prefix_watcher::{ + key_extractors, watch_prefix_with_extraction, + }; + let runtime_configs_watcher = watch_prefix_with_extraction( + etcd_client, + MODEL_ROOT_PATH, + key_extractors::lease_id, + |model_entry: ModelEntry| model_entry.runtime_config, + cancellation_token.clone(), + ) + .await?; + let runtime_configs_rx = runtime_configs_watcher.receiver(); let indexer = if kv_router_config.use_kv_events { Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 9fa3f11865..7ab4e1372c 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -18,14 +18,10 @@ use std::sync::Once; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics}; use crate::kv_router::KV_METRICS_ENDPOINT; -use crate::discovery::{ModelEntry, MODEL_ROOT_PATH}; use crate::kv_router::scoring::Endpoint; use crate::kv_router::ProcessedEndpoints; -use crate::local_model::runtime_config::ModelRuntimeConfig; use
dynamo_runtime::component::Component; -use dynamo_runtime::transports::etcd::{Client as EtcdClient, WatchEvent}; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; -use std::collections::HashMap; use tokio::sync::watch; use tokio_util::sync::CancellationToken; @@ -212,71 +208,3 @@ pub async fn collect_endpoints_task( } } } - -pub async fn watch_model_runtime_configs( - etcd_client: EtcdClient, - cancellation_token: CancellationToken, -) -> Result<watch::Receiver<HashMap<i64, ModelRuntimeConfig>>> { - let (watch_tx, watch_rx) = watch::channel(HashMap::new()); - - let prefix_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?; - let (_prefix, _watcher, mut events_rx) = prefix_watcher.dissolve(); - - tokio::spawn(async move { - let mut runtime_configs: HashMap<i64, ModelRuntimeConfig> = HashMap::new(); - - loop { - tokio::select! { - _ = cancellation_token.cancelled() => { - tracing::debug!("Runtime config watcher cancelled"); - break; - } - event = events_rx.recv() => { - let Some(event) = event else { - tracing::debug!("Runtime config watch stream closed"); - break; - }; - - match event { - WatchEvent::Put(kv) => { - let Ok(model_entry) = serde_json::from_slice::<ModelEntry>(kv.value()) else { - tracing::warn!( - "Failed to parse ModelEntry from etcd. Key: {}", - kv.key_str().unwrap_or("") - ); - continue; - }; - - let lease_id = kv.lease(); - - if let Some(runtime_config) = model_entry.runtime_config { - runtime_configs.insert(lease_id, runtime_config); - tracing::trace!("Updated runtime config for lease_id: {}", lease_id); - } else { - runtime_configs.remove(&lease_id); - tracing::trace!("Removed runtime config (no config in ModelEntry)"); - } - - if watch_tx.send(runtime_configs.clone()).is_err() { - tracing::error!("Failed to send runtime configs update; receiver dropped"); - break; - } - } - WatchEvent::Delete(kv) => { - let lease_id = kv.lease(); - runtime_configs.remove(&lease_id); - tracing::trace!("Removed runtime config for deleted entry"); - - if watch_tx.send(runtime_configs.clone()).is_err() { - tracing::error!("Failed to send runtime configs update; receiver dropped"); - break; - } - } - } - } - } - } - }); - - Ok(watch_rx) -}
diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 22c88b08d0..ab1e04ff4b 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -503,14 +503,16 @@ impl WorkerMetricsPublisher { let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; - // let worker_id = component - // .drt() - // .primary_lease() - // .map(|lease| lease.id()) - // .unwrap_or_else(|| { - // tracing::warn!("Component is static, assuming worker_id of 0"); - // 0 - // }); + let worker_id = component + .drt() + .primary_lease() + .map(|lease| lease.id()) + .unwrap_or_else(|| { + tracing::warn!("Component is static, assuming worker_id of 0"); + 0 + }); + + self.start_nats_metrics_publishing(component.namespace().clone(), worker_id); component .endpoint(KV_METRICS_ENDPOINT)
diff --git a/lib/llm/src/mocker/engine.rs b/lib/llm/src/mocker/engine.rs index c7467daa86..bc769d2d37 100644 --- a/lib/llm/src/mocker/engine.rs +++ b/lib/llm/src/mocker/engine.rs @@ -42,8 +42,8 @@ use futures::StreamExt; use rand::Rng; use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; use tokio::sync::{mpsc, Mutex, OnceCell}; -use tokio::time::{interval, Duration}; use tokio_stream::wrappers::ReceiverStream; use uuid::Uuid; @@ -174,7 +174,7 @@ impl MockVllmEngine { (schedulers, kv_event_receivers) } - ///
Start background tasks to poll and publish metrics every second + /// Start background tasks to publish metrics on change async fn start_metrics_publishing( schedulers: &[Scheduler], component: Option<Component>, @@ -202,19 +202,18 @@ tracing::info!("Starting metrics background tasks"); for (dp_rank, scheduler) in schedulers.iter().enumerate() { - let scheduler = scheduler.clone(); + let mut metrics_rx = scheduler.metrics_receiver(); let publisher = metrics_publisher.clone(); let dp_rank = dp_rank as u32; let cancel_token = cancel_token.clone(); tokio::spawn(async move { - let mut interval = interval(Duration::from_millis(100)); - loop { tokio::select! { - _ = interval.tick() => { - // Get metrics from scheduler - let metrics = scheduler.get_forward_pass_metrics().await; + // Watch for metrics changes + Ok(_) = metrics_rx.changed() => { + // Get the latest metrics + let metrics = metrics_rx.borrow().clone(); // Publish metrics if let Err(e) = publisher.publish(Arc::new(metrics)) { @@ -568,7 +567,7 @@ mod integration_tests { let engine = MockVllmEngine::new(args); engine.start(test_component.clone()).await?; - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + tokio::time::sleep(Duration::from_millis(500)).await; let engine = Arc::new(engine); tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}"); @@ -598,7 +597,7 @@ mod integration_tests { tracing::info!("✓ Server started in background"); // Give server time to start - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + tokio::time::sleep(Duration::from_millis(500)).await; tracing::info!("✓ Server startup delay completed"); // Print all registered instances from etcd @@ -733,7 +732,7 @@ mod integration_tests { cancel_token, ) .await; - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + tokio::time::sleep(Duration::from_millis(500)).await; let processed_endpoints = metrics_aggregator.get_endpoints(); tracing::info!(
diff --git a/lib/llm/src/mocker/scheduler.rs b/lib/llm/src/mocker/scheduler.rs index 97ca4c4e13..063f41f7bd 100644 --- a/lib/llm/src/mocker/scheduler.rs +++ b/lib/llm/src/mocker/scheduler.rs @@ -250,11 +250,10 @@ impl SchedulerState { /// Manages scheduling of requests using KvManager resources #[derive(Clone)] pub struct Scheduler { - dp_rank: Option<u32>, state: Arc<Mutex<SchedulerState>>, kv_manager: Arc<Mutex<KvManager>>, request_tx: mpsc::UnboundedSender<DirectRequest>, - hit_rates: Arc<Mutex<VecDeque<f32>>>, + metrics_rx: tokio::sync::watch::Receiver<ForwardPassMetrics>, } impl Scheduler { @@ -292,13 +291,16 @@ impl Scheduler { // Create channel for request handling let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>(); + let mut initial_metrics = ForwardPassMetrics::default(); + initial_metrics.worker_stats.data_parallel_rank = dp_rank; + let (metrics_tx, metrics_rx) = + tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics); // Create a clone for the background task let state_clone = state.clone(); let kv_manager_clone = kv_manager.clone(); let output_tx_clone = output_tx.clone(); let cancel_token_clone = cancellation_token.unwrap_or_default().clone(); - let hit_rates_clone = hit_rates.clone(); // Spawn main background task with cancellation token tokio::spawn(async move { @@ -376,7 +378,7 @@ impl Scheduler { // Compute and store hit rate let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 }; { - let mut hit_rates_guard = hit_rates_clone.lock().await; + let mut hit_rates_guard = hit_rates.lock().await; hit_rates_guard.push_back(hit_rate); if hit_rates_guard.len() > 1000 {
hit_rates_guard.pop_front(); @@ -442,6 +444,17 @@ impl Scheduler { state_guard.reset_active_tokens(); + { + let hit_rates_guard = hit_rates.lock().await; + let metrics = get_fwd_pass_metrics( + &state_guard, + &kv_manager_guard, + &hit_rates_guard, + dp_rank, + ); + let _ = metrics_tx.send(metrics); + } + // Process decoding let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect(); if !uuids.is_empty() { @@ -495,6 +508,17 @@ impl Scheduler { } } + { + let hit_rates_guard = hit_rates.lock().await; + let metrics = get_fwd_pass_metrics( + &state_guard, + &kv_manager_guard, + &hit_rates_guard, + dp_rank, + ); + let _ = metrics_tx.send(metrics); + } + if send_failed || is_complete { state_guard.complete(&uuid); continue; @@ -513,11 +537,10 @@ impl Scheduler { }); Self { - dp_rank, state, kv_manager, request_tx, - hit_rates, + metrics_rx, } } @@ -555,56 +578,60 @@ impl Scheduler { kv_manager.current_capacity_perc() } - /// Returns forward pass metrics for monitoring purposes - pub async fn get_forward_pass_metrics(&self) -> ForwardPassMetrics { - // Acquire all locks in consistent order: state -> kv_manager -> hit_rates - let state = self.state.lock().await; - let kv_manager = self.kv_manager.lock().await; - let hit_rates_guard = self.hit_rates.lock().await; - - // Get state metrics - let request_active_slots = state.decode.len() as u64; - let num_requests_waiting = state.waiting.len() as u64; + /// Get a watch receiver for forward pass metrics + pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> { + self.metrics_rx.clone() + } +} - // Get KV manager metrics - let active_blocks_count = kv_manager.active_blocks().len() as u64; - let total_capacity = kv_manager.max_capacity() as u64; - let gpu_cache_usage_perc = if total_capacity > 0 { - active_blocks_count as f32 / total_capacity as f32 - } else { - 0.0 - }; +/// Calculate forward pass metrics from current state +fn get_fwd_pass_metrics( + state: &SchedulerState, + kv_manager: &KvManager, + hit_rates: &VecDeque<f32>, + dp_rank: Option<u32>, +) -> ForwardPassMetrics { + // Get state metrics + let request_active_slots = state.decode.len() as u64; + let num_requests_waiting = state.waiting.len() as u64; + + // Get KV manager metrics + let active_blocks_count = kv_manager.active_blocks().len() as u64; + let total_capacity = kv_manager.max_capacity() as u64; + let gpu_cache_usage_perc = if total_capacity > 0 { + active_blocks_count as f32 / total_capacity as f32 + } else { + 0.0 + }; - // Get hit rate metrics - let gpu_prefix_cache_hit_rate = if hit_rates_guard.is_empty() { - 0.0 - } else { - let sum: f32 = hit_rates_guard.iter().sum(); - sum / hit_rates_guard.len() as f32 - }; + // Get hit rate metrics + let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() { + 0.0 + } else { + let sum: f32 = hit_rates.iter().sum(); + sum / hit_rates.len() as f32 + }; - let worker_stats = WorkerStats { - data_parallel_rank: self.dp_rank, - request_active_slots, - request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128 - num_requests_waiting, - }; + let worker_stats = WorkerStats { + data_parallel_rank: dp_rank, + request_active_slots, + request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128 + num_requests_waiting, + }; - let kv_stats = KvStats { - kv_active_blocks: active_blocks_count, - kv_total_blocks: total_capacity, - gpu_cache_usage_perc, - gpu_prefix_cache_hit_rate, - }; + let kv_stats = KvStats { + kv_active_blocks: active_blocks_count, + kv_total_blocks:
total_capacity, + gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate, + }; - let spec_decode_stats = None; + let spec_decode_stats = None; - ForwardPassMetrics { - worker_stats, - kv_stats, - spec_decode_stats, - } - // Guards drop naturally here in reverse order (LIFO): hit_rates_guard, kv_manager, state + ForwardPassMetrics { + worker_stats, + kv_stats, + spec_decode_stats, } } @@ -761,6 +788,9 @@ mod tests { let timeout = tokio::time::sleep(Duration::from_secs(2)); tokio::pin!(timeout); + // Get metrics receiver + let metrics_rx = scheduler.metrics_receiver(); + // Set up debug ticker interval let mut debug_interval = interval(Duration::from_millis(500)); @@ -770,7 +800,7 @@ // Manual debug ticker that prints forward pass metrics _ = debug_interval.tick() => { - let _metrics = scheduler.get_forward_pass_metrics().await; + let _metrics = metrics_rx.borrow().clone(); println!("Forward Pass Metrics: {_metrics:#?}"); } @@ -862,6 +892,9 @@ mod tests { let timeout = tokio::time::sleep(Duration::from_millis(500)); tokio::pin!(timeout); + // Get metrics receiver + let metrics_rx = scheduler.metrics_receiver(); + // Set up debug ticker interval let mut debug_interval = interval(Duration::from_millis(500)); @@ -871,7 +904,7 @@ // Manual debug ticker that prints forward pass metrics _ = debug_interval.tick() => { - let _metrics = scheduler.get_forward_pass_metrics().await; + let _metrics = metrics_rx.borrow().clone(); println!("Forward Pass Metrics: {_metrics:#?}"); } @@ -888,8 +921,11 @@ } } + // Wait a bit for final metrics update + tokio::time::sleep(Duration::from_millis(100)).await; + // Verify forward pass metrics - let metrics = scheduler.get_forward_pass_metrics().await; + let metrics = metrics_rx.borrow().clone(); assert_eq!( metrics.worker_stats.num_requests_waiting, 0, @@ -958,7 +994,8 @@ tokio::time::sleep(Duration::from_secs(1)).await; // Check forward pass metrics - let metrics = scheduler.get_forward_pass_metrics().await; + let metrics_rx = scheduler.metrics_receiver(); + let metrics = metrics_rx.borrow().clone(); assert_eq!( metrics.kv_stats.gpu_cache_usage_perc,
diff --git a/lib/runtime/src/component/client.rs b/lib/runtime/src/component/client.rs index ab57a1e3a6..76336535a1 100644 --- a/lib/runtime/src/component/client.rs +++ b/lib/runtime/src/component/client.rs @@ -44,6 +44,8 @@ pub struct Client { pub instance_source: Arc<InstanceSource>, // These are the instance source ids less those reported as down from sending rpc instance_avail: Arc<ArcSwap<Vec<i64>>>, + // These are the instance source ids less those reported as busy (above threshold) + instance_free: Arc<ArcSwap<Vec<i64>>>, } #[derive(Clone, Debug)] @@ -59,6 +61,7 @@ impl Client { endpoint, instance_source: Arc::new(InstanceSource::Static), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))), + instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))), }) } @@ -76,8 +79,9 @@ let client = Client { endpoint, - instance_source, + instance_source: instance_source.clone(), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))), + instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))), }; client.monitor_instance_source(); Ok(client) } @@ -108,6 +112,10 @@ self.instance_avail.load() } + pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<i64>>> { + self.instance_free.load() + } + /// Wait for at least one Instance to be available for this Endpoint pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> { let mut instances: Vec<Instance> = vec![]; @@ -142,6 +150,16 @@ tracing::debug!("inhibiting
instance {instance_id}"); } + /// Update the set of free instances based on busy instance IDs + pub fn update_free_instances(&self, busy_instance_ids: &[i64]) { + let all_instance_ids = self.instance_ids(); + let free_ids: Vec = all_instance_ids + .into_iter() + .filter(|id| !busy_instance_ids.contains(id)) + .collect(); + self.instance_free.store(Arc::new(free_ids)); + } + /// Monitor the ETCD instance source and update instance_avail. fn monitor_instance_source(&self) { let cancel_token = self.endpoint.drt().primary_token(); @@ -160,7 +178,10 @@ impl Client { .iter() .map(|instance| instance.id()) .collect(); - client.instance_avail.store(Arc::new(instance_ids)); + + // TODO: this resets both tracked available and free instances + client.instance_avail.store(Arc::new(instance_ids.clone())); + client.instance_free.store(Arc::new(instance_ids)); tracing::debug!("instance source updated"); diff --git a/lib/runtime/src/pipeline/error.rs b/lib/runtime/src/pipeline/error.rs index 6f75fc86d3..794cf2228b 100644 --- a/lib/runtime/src/pipeline/error.rs +++ b/lib/runtime/src/pipeline/error.rs @@ -131,6 +131,10 @@ pub enum PipelineError { #[error("NATS KV Err: {0} for bucket '{1}")] KeyValueError(String, String), + + /// All instances are busy and cannot handle new requests + #[error("Service temporarily unavailable: {0}")] + ServiceOverloaded(String), } #[derive(Debug, thiserror::Error)] diff --git a/lib/runtime/src/pipeline/network/egress/push_router.rs b/lib/runtime/src/pipeline/network/egress/push_router.rs index a55365eac7..4700f838f9 100644 --- a/lib/runtime/src/pipeline/network/egress/push_router.rs +++ b/lib/runtime/src/pipeline/network/egress/push_router.rs @@ -2,11 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 use super::{AsyncEngineContextProvider, ResponseStream}; +use crate::utils::worker_monitor::WorkerMonitor; use crate::{ component::{Client, Endpoint, InstanceSource}, engine::{AsyncEngine, Data}, pipeline::{ - error::PipelineErrorExt, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn, + error::{PipelineError, PipelineErrorExt}, + AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn, }, protocols::maybe_error::MaybeError, traits::DistributedRuntimeProvider, @@ -52,6 +54,13 @@ where /// addresses it, then passes it to AddressedPushRouter which does the network traffic. addressed: Arc, + /// Worker monitor for tracking KV cache usage + worker_monitor: Option>, + + /// Threshold for determining when a worker is busy (0.0 to 1.0) + /// If None, busy detection is disabled + busy_threshold: Option, + /// An internal Rust type. This says that PushRouter is generic over the T and U types, /// which are the input and output types of it's `generate` function. It allows the /// compiler to specialize us at compile time. 
@@ -86,15 +95,43 @@ where T: Data + Serialize, U: Data + for<'de> Deserialize<'de> + MaybeError, { + /// Create a new PushRouter without busy threshold (no busy detection) pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> { + Self::from_client_with_threshold(client, router_mode, None).await + } + + /// Create a new PushRouter with optional busy threshold + pub async fn from_client_with_threshold( + client: Client, + router_mode: RouterMode, + busy_threshold: Option<f64>, + ) -> anyhow::Result<Self> { let addressed = addressed_router(&client.endpoint).await?; - Ok(PushRouter { - client, + + // Create worker monitor only if we have a threshold and are in dynamic mode + let worker_monitor = match (busy_threshold, client.instance_source.as_ref()) { + (Some(threshold), InstanceSource::Dynamic(_)) => { + let monitor = Arc::new(WorkerMonitor::new_with_threshold( + Arc::new(client.clone()), + threshold, + )); + monitor.start_monitoring().await?; + Some(monitor) + } + _ => None, + }; + + let router = PushRouter { + client: client.clone(), addressed, router_mode, round_robin_counter: Arc::new(AtomicU64::new(0)), + worker_monitor, + busy_threshold, _phantom: PhantomData, - }) + }; + + Ok(router) } /// Issue a request to the next available instance in a round-robin fashion @@ -170,6 +207,21 @@ instance_id: i64, request: SingleIn<T>, ) -> anyhow::Result<ManyOut<U>> { + // Check if all workers are busy (only if busy threshold is set) + if self.busy_threshold.is_some() { + let free_instances = self.client.instance_ids_free(); + if free_instances.is_empty() { + // Check if we actually have any instances at all + let all_instances = self.client.instance_ids(); + if !all_instances.is_empty() { + return Err(PipelineError::ServiceOverloaded( + "All workers are busy, please retry later".to_string(), + ) + .into()); + } + } + } + let subject = self.client.endpoint.subject_to(instance_id); let request = request.map(|req| AddressedRequest::new(req, subject));
diff --git a/lib/runtime/src/utils.rs b/lib/runtime/src/utils.rs index 432764f5c1..7f7a56dae6 100644 --- a/lib/runtime/src/utils.rs +++ b/lib/runtime/src/utils.rs @@ -19,3 +19,5 @@ pub mod leader_worker_barrier; pub mod pool; pub mod stream; pub mod task; +pub mod typed_prefix_watcher; +pub mod worker_monitor;
diff --git a/lib/runtime/src/utils/typed_prefix_watcher.rs b/lib/runtime/src/utils/typed_prefix_watcher.rs new file mode 100644 index 0000000000..24802c19bc --- /dev/null +++ b/lib/runtime/src/utils/typed_prefix_watcher.rs @@ -0,0 +1,229 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Generic etcd watcher utilities for maintaining collated state from etcd prefixes. +//! +//! This module provides reusable patterns for watching etcd prefixes and maintaining +//! HashMap-based state that automatically updates based on etcd events. + +use crate::transports::etcd::{Client as EtcdClient, WatchEvent}; +use crate::Result; +use etcd_client::KeyValue; +use serde::de::DeserializeOwned; +use std::collections::HashMap; +use std::fmt::Debug; +use tokio::sync::watch; +use tokio_util::sync::CancellationToken; + +/// A generic etcd prefix watcher that maintains a HashMap of deserialized values. +/// +/// This struct watches an etcd prefix and maintains a HashMap where: +/// - Keys are extracted from the etcd KeyValue (e.g., lease_id, key string, etc.)
+/// - Values are extracted from the deserialized type using a value extractor +/// +/// # Type Parameters +/// - `K`: The key type for the HashMap (must be hashable) +/// - `V`: The value type stored in the HashMap +pub struct TypedPrefixWatcher<K, V> +where + K: Clone + Eq + std::hash::Hash + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, +{ + rx: watch::Receiver<HashMap<K, V>>, +} + +impl<K, V> TypedPrefixWatcher<K, V> +where + K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static, + V: Clone + Send + Sync + 'static, +{ + /// Get a receiver for the current state + pub fn receiver(&self) -> watch::Receiver<HashMap<K, V>> { + self.rx.clone() + } + + /// Get the current state + pub fn current(&self) -> HashMap<K, V> { + self.rx.borrow().clone() + } +} + +/// Watch an etcd prefix and maintain a HashMap of values with field extraction +/// +/// This function watches an etcd prefix and maintains a HashMap where values are +/// extracted from a deserialized type using a value extractor function. +/// +/// # Type Parameters +/// - `K`: The key type for the HashMap +/// - `V`: The value type stored in the HashMap +/// - `T`: The type to deserialize from etcd +/// +/// # Arguments +/// - `client`: The etcd client to use +/// - `prefix`: The prefix to watch in etcd +/// - `key_extractor`: Function to extract the key from a KeyValue +/// - `value_extractor`: Function to extract the value from the deserialized type +/// - `cancellation_token`: Token to stop the watcher +/// +/// # Example +/// ```ignore +/// // Watch for ModelEntry objects and extract runtime_config field +/// let watcher = watch_prefix_with_extraction( +/// etcd_client, +/// "models/", +/// |kv| Some(kv.lease()), // Use lease_id as key +/// |entry: ModelEntry| entry.runtime_config, // Extract runtime_config field +/// cancellation_token, +/// ).await?; +/// ``` +pub async fn watch_prefix_with_extraction<K, V, T>( + client: EtcdClient, + prefix: impl Into<String>, + key_extractor: impl Fn(&KeyValue) -> Option<K> + Send + 'static, + value_extractor: impl Fn(T) -> Option<V> + Send + 'static, + cancellation_token: CancellationToken, +) -> Result<TypedPrefixWatcher<K, V>> +where + K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static, + V: Clone + Send + Sync + 'static, + T: DeserializeOwned + Send + 'static, +{ + let (watch_tx, watch_rx) = watch::channel(HashMap::new()); + let prefix = prefix.into(); + + let prefix_watcher = client.kv_get_and_watch_prefix(&prefix).await?; + let (prefix_str, _watcher, mut events_rx) = prefix_watcher.dissolve(); + + tokio::spawn(async move { + let mut state: HashMap<K, V> = HashMap::new(); + + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + tracing::debug!("TypedPrefixWatcher for prefix '{}' cancelled", prefix_str); + break; + } + event = events_rx.recv() => { + let Some(event) = event else { + tracing::debug!("TypedPrefixWatcher watch stream closed for prefix '{}'", prefix_str); + break; + }; + + match event { + WatchEvent::Put(kv) => { + // Extract the key + let Some(key) = key_extractor(&kv) else { + tracing::trace!("Skipping entry - key extractor returned None"); + continue; + }; + + // Deserialize the value + let deserialized = match serde_json::from_slice::<T>(kv.value()) { + Ok(val) => val, + Err(e) => { + tracing::warn!( + "Failed to deserialize value from etcd.
Key: {}, Error: {}", + kv.key_str().unwrap_or(""), + e + ); + continue; + } + }; + + // Extract the value + match value_extractor(deserialized) { + Some(v) => { + state.insert(key.clone(), v); + tracing::trace!("Updated entry for key {:?}", key); + } + None => { + state.remove(&key); + tracing::trace!("Removed entry for key {:?} (extractor returned None)", key); + } + } + + if watch_tx.send(state.clone()).is_err() { + tracing::error!("Failed to send update; receiver dropped"); + break; + } + } + WatchEvent::Delete(kv) => { + if let Some(key) = key_extractor(&kv) { + state.remove(&key); + tracing::trace!("Removed entry for deleted key {:?}", key); + + if watch_tx.send(state.clone()).is_err() { + tracing::error!("Failed to send update; receiver dropped"); + break; + } + } + } + } + } + } + } + + tracing::info!("TypedPrefixWatcher for prefix '{}' stopped", prefix_str); + }); + + Ok(TypedPrefixWatcher { rx: watch_rx }) +} + +/// Watch an etcd prefix and maintain a HashMap of values without field extraction +/// +/// This is a simpler version when you want to store the entire deserialized value. +/// +/// # Example +/// ```ignore +/// // Watch for TestConfig objects directly +/// let watcher = watch_prefix( +/// etcd_client, +/// "configs/", +/// |kv| Some(kv.lease()), // Use lease_id as key +/// cancellation_token, +/// ).await?; +/// ``` +pub async fn watch_prefix<K, V>( + client: EtcdClient, + prefix: impl Into<String>, + key_extractor: impl Fn(&KeyValue) -> Option<K> + Send + 'static, + cancellation_token: CancellationToken, +) -> Result<TypedPrefixWatcher<K, V>> +where + K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static, + V: Clone + DeserializeOwned + Send + Sync + 'static, +{ + watch_prefix_with_extraction( + client, + prefix, + key_extractor, + |v: V| Some(v), // Identity function - just return the value + cancellation_token, + ) + .await +} + +/// Common key extractors for convenience +pub mod key_extractors { + use etcd_client::KeyValue; + + /// Extract the lease ID as the key + pub fn lease_id(kv: &KeyValue) -> Option<i64> { + Some(kv.lease()) + } + + /// Extract the key as a string (without prefix) + pub fn key_string(prefix: &str) -> impl Fn(&KeyValue) -> Option<String> { + let prefix = prefix.to_string(); + move |kv: &KeyValue| { + kv.key_str() + .ok() + .map(|k| k.strip_prefix(&prefix).unwrap_or(k).to_string()) + } + } + + /// Extract the full key as a string + pub fn full_key_string(kv: &KeyValue) -> Option<String> { + kv.key_str().ok().map(|s| s.to_string()) + } +}
diff --git a/lib/runtime/src/utils/worker_monitor.rs b/lib/runtime/src/utils/worker_monitor.rs new file mode 100644 index 0000000000..ed3ce34d74 --- /dev/null +++ b/lib/runtime/src/utils/worker_monitor.rs @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// TODO: Make load comparisons and runtime metrics a generic trait so this monitoring +// system is not tied to KV cache concepts, which are LLM-specific. This would allow +// different types of workers to define their own load metrics and busy thresholds.
+ +use crate::component::{Client, InstanceSource}; +use crate::traits::events::EventSubscriber; +use crate::traits::DistributedRuntimeProvider; +use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use tokio::sync::watch; +use tokio_stream::StreamExt; + +// Constants for monitoring configuration +const KV_METRICS_SUBJECT: &str = "kv_metrics"; +const MODEL_ROOT_PATH: &str = "models"; + +// Internal structs for deserializing metrics events +#[derive(serde::Deserialize)] +struct LoadEvent { + worker_id: i64, + data: ForwardPassMetrics, +} + +#[derive(serde::Deserialize)] +struct ForwardPassMetrics { + kv_stats: KvStats, +} + +#[derive(serde::Deserialize)] +struct KvStats { + kv_active_blocks: u64, +} + +#[derive(serde::Deserialize)] +struct ModelEntry { + runtime_config: Option<RuntimeConfig>, +} + +#[derive(serde::Deserialize)] +struct RuntimeConfig { + total_kv_blocks: Option<u64>, +} + +/// Worker load monitoring state +#[derive(Clone, Debug)] +pub struct WorkerLoadState { + pub kv_active_blocks: Option<u64>, + pub kv_total_blocks: Option<u64>, +} + +impl WorkerLoadState { + pub fn is_busy(&self, threshold: f64) -> bool { + match (self.kv_active_blocks, self.kv_total_blocks) { + (Some(active), Some(total)) if total > 0 => { + (active as f64) > (threshold * total as f64) + } + _ => false, + } + } +} + +/// Worker monitor for tracking KV cache usage and busy states +pub struct WorkerMonitor { + client: Arc<Client>, + worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>, + busy_threshold: f64, +} + +impl WorkerMonitor { + /// Create a new worker monitor with custom threshold + pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self { + Self { + client, + worker_load_states: Arc::new(RwLock::new(HashMap::new())), + busy_threshold, + } + } + + /// Get the worker load states for external access + pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> { + self.worker_load_states.clone() + } + + /// Start background monitoring of worker KV cache usage + pub async fn start_monitoring(&self) -> anyhow::Result<()> { + let endpoint = &self.client.endpoint; + let component = endpoint.component(); + + let Some(etcd_client) = component.drt().etcd_client() else { + // Static mode, no monitoring needed + return Ok(()); + }; + + let runtime_configs_watcher = watch_prefix_with_extraction( + etcd_client, + MODEL_ROOT_PATH, + key_extractors::lease_id, + |entry: ModelEntry| entry.runtime_config.and_then(|rc| rc.total_kv_blocks), + component.drt().child_token(), + ) + .await?; + let mut config_events_rx = runtime_configs_watcher.receiver(); + + // Subscribe to KV metrics events + let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?; + + let worker_load_states = self.worker_load_states.clone(); + let client = self.client.clone(); + let cancellation_token = component.drt().child_token(); + let busy_threshold = self.busy_threshold; // Capture threshold for the closure + + // Spawn background monitoring task + tokio::spawn(async move { + let mut previous_busy_instances = Vec::new(); // Track previous state + + loop { + tokio::select!
{ + _ = cancellation_token.cancelled() => { + tracing::debug!("Worker monitoring cancelled"); + break; + } + + // Handle runtime config updates - now receives full HashMap + _ = config_events_rx.changed() => { + let runtime_configs = config_events_rx.borrow().clone(); + + let mut states = worker_load_states.write().unwrap(); + states.retain(|lease_id, _| runtime_configs.contains_key(lease_id)); + + // Update worker load states with total blocks + for (lease_id, total_blocks) in runtime_configs.iter() { + let state = states.entry(*lease_id).or_insert(WorkerLoadState { + kv_active_blocks: None, + kv_total_blocks: None, + }); + state.kv_total_blocks = Some(*total_blocks); + } + } + + // Handle KV metrics updates + kv_event = kv_metrics_rx.next() => { + let Some(event) = kv_event else { + tracing::debug!("KV metrics stream closed"); + break; + }; + + if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) { + let worker_id = load_event.worker_id; + let active_blocks = load_event.data.kv_stats.kv_active_blocks; + + // Update worker load state + let mut states = worker_load_states.write().unwrap(); + let state = states.entry(worker_id).or_insert(WorkerLoadState { + kv_active_blocks: None, + kv_total_blocks: None, + }); + state.kv_active_blocks = Some(active_blocks); + drop(states); + + // Recalculate all busy instances and update + let states = worker_load_states.read().unwrap(); + let busy_instances: Vec<i64> = states + .iter() + .filter_map(|(&id, state)| { + state.is_busy(busy_threshold).then_some(id) + }) + .collect(); + drop(states); + + // Only update if busy_instances has changed + if busy_instances != previous_busy_instances { + tracing::debug!("Busy instances changed: {:?}", busy_instances); + client.update_free_instances(&busy_instances); + previous_busy_instances = busy_instances; + } + } + } + } + } + + tracing::info!("Worker monitoring task exiting"); + }); + + Ok(()) + } +}
diff --git a/tests/router/test_router_e2e_with_mockers.py b/tests/router/test_router_e2e_with_mockers.py index 92e5b4d6c0..4036a974bc 100644 --- a/tests/router/test_router_e2e_with_mockers.py +++ b/tests/router/test_router_e2e_with_mockers.py @@ -5,6 +5,8 @@ import json import logging import os +import random +from typing import Any, Dict import aiohttp import pytest @@ -22,6 +24,19 @@ NUM_REQUESTS = 100 PORT = 8090 # Starting port for mocker instances +# Shared test payload for all tests +TEST_PAYLOAD: Dict[str, Any] = { + "model": MODEL_NAME, + "messages": [ + { + "role": "user", + "content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby.
Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.", + } + ], + "stream": True, + "max_tokens": 10, +} + class MockerProcess(ManagedProcess): """Manages a single mocker engine instance""" @@ -88,6 +103,89 @@ def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) +async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4): + """Send a single request with exponential backoff retry""" + wait_time = 1 # Start with 1 second + + for attempt in range(max_retries + 1): + await asyncio.sleep(wait_time) + try: + async with aiohttp.ClientSession() as session: + async with session.post(url, json=payload) as response: + if response.status == 200: + # Read the response to ensure it's valid + async for _ in response.content: + pass + logger.info(f"First request succeeded on attempt {attempt + 1}") + return True + else: + logger.warning( + f"Attempt {attempt + 1} failed with status {response.status}" + ) + except Exception as e: + logger.warning(f"Attempt {attempt + 1} failed with error: {e}") + + if attempt < max_retries: + wait_time *= 2 # Double the wait time + + return False + + +async def send_concurrent_requests(urls: list, payload: dict, num_requests: int): + """Send multiple requests concurrently, alternating between URLs if multiple provided""" + + # First, send test requests with retry to ensure all systems are ready + for i, url in enumerate(urls): + logger.info(f"Sending initial test request to URL {i} ({url}) with retry...") + if not await send_request_with_retry(url, payload): + raise RuntimeError(f"Failed to connect to URL {i} after multiple retries") + + async def send_single_request(session: aiohttp.ClientSession, request_id: int): + # Alternate between URLs based on request_id + url = urls[request_id % len(urls)] + url_index = request_id % len(urls) + + try: + async with session.post(url, json=payload) as response: + if response.status != 200: + logger.error( + f"Request {request_id} to URL {url_index} failed with status {response.status}" + ) + return False + + # For streaming responses, read the entire stream + chunks = [] + async for line in response.content: + if line: + chunks.append(line) + + logger.debug( + f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks" + ) + return True + + except Exception as e: + logger.error( + f"Request {request_id} to URL {url_index} failed with error: {e}" + ) + return False + + # Send all requests at once + async with aiohttp.ClientSession() as session: + tasks = [send_single_request(session, i) for i in range(num_requests)] + results = await asyncio.gather(*tasks) + + successful = sum(1 for r in results if r) + failed = sum(1 for r in results if not r) + + logger.info(f"Completed all requests: {successful} successful, {failed} failed") + + assert ( + successful == num_requests + ), f"Expected {num_requests} successful requests, got {successful}" + logger.info(f"All {num_requests} requests completed successfully") + + @pytest.mark.pre_merge def test_mocker_kv_router(request, runtime_services): """ @@ -128,26 +226,13 @@ def test_mocker_kv_router(request, runtime_services): for mocker in mocker_processes: mocker.__enter__() - # Send test requests - test_payload = { - "model": MODEL_NAME, - "messages": [ - { - "role": "user", - "content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a 
gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.", - } - ], - "stream": True, - "max_tokens": 10, - } - # Use async to send requests concurrently for better performance asyncio.run( send_concurrent_requests( [ f"http://localhost:{frontend_port}/v1/chat/completions" ], # Pass as list - test_payload, + TEST_PAYLOAD, NUM_REQUESTS, ) ) @@ -209,19 +294,6 @@ def test_mocker_two_kv_router(request, runtime_services): for mocker in mocker_processes: mocker.__enter__() - # Send test requests - test_payload = { - "model": MODEL_NAME, - "messages": [ - { - "role": "user", - "content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.", - } - ], - "stream": True, - "max_tokens": 10, - } - # Build URLs for both routers router_urls = [ f"http://localhost:{port}/v1/chat/completions" for port in router_ports @@ -231,7 +303,7 @@ def test_mocker_two_kv_router(request, runtime_services): asyncio.run( send_concurrent_requests( router_urls, - test_payload, + TEST_PAYLOAD, NUM_REQUESTS, ) ) @@ -253,84 +325,177 @@ def test_mocker_two_kv_router(request, runtime_services): os.unlink(mocker_args_file) -async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4): - """Send a single request with exponential backoff retry""" - wait_time = 1 # Start with 1 second +@pytest.mark.pre_merge +@pytest.mark.skip(reason="Flaky, temporarily disabled") +def test_mocker_kv_router_overload_503(request, runtime_services): + """ + Test that KV router returns 503 when all workers are busy. + This test uses limited resources to intentionally trigger the overload condition. 
+ """ - for attempt in range(max_retries + 1): - await asyncio.sleep(wait_time) - try: - async with aiohttp.ClientSession() as session: - async with session.post(url, json=payload) as response: - if response.status == 200: - # Read the response to ensure it's valid - async for _ in response.content: - pass - logger.info(f"First request succeeded on attempt {attempt + 1}") - return True - else: - logger.warning( - f"Attempt {attempt + 1} failed with status {response.status}" - ) - except Exception as e: - logger.warning(f"Attempt {attempt + 1} failed with error: {e}") + # runtime_services starts etcd and nats + logger.info("Starting mocker KV router overload test for 503 status") - if attempt < max_retries: - wait_time *= 2 # Double the wait time + # Create mocker args file with limited resources + mocker_args = { + "speedup_ratio": 10, + "block_size": 4, # Smaller block size + "num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly + } - return False + mocker_args_file = os.path.join(request.node.name, "mocker_args_overload.json") + with open(mocker_args_file, "w") as f: + json.dump(mocker_args, f) + try: + # Start KV router (frontend) with limited block size + frontend_port = PORT + 10 # Use different port to avoid conflicts + logger.info( + f"Starting KV router frontend on port {frontend_port} with limited resources" + ) -async def send_concurrent_requests(urls: list, payload: dict, num_requests: int): - """Send multiple requests concurrently, alternating between URLs if multiple provided""" + # Custom command for router with limited block size + command = [ + "python", + "-m", + "dynamo.frontend", + "--busy-threshold", + "0.2", + "--kv-cache-block-size", + "4", # Match the mocker's block size + "--router-mode", + "kv", + "--http-port", + str(frontend_port), + ] - # First, send test requests with retry to ensure all systems are ready - for i, url in enumerate(urls): - logger.info(f"Sending initial test request to URL {i} ({url}) with retry...") - if not await send_request_with_retry(url, payload): - raise RuntimeError(f"Failed to connect to URL {i} after multiple retries") + kv_router = ManagedProcess( + command=command, + timeout=60, + display_output=True, + health_check_ports=[frontend_port], + health_check_urls=[ + ( + f"http://localhost:{frontend_port}/v1/models", + lambda r: r.status_code == 200, + ) + ], + log_dir=request.node.name, + terminate_existing=False, + ) + kv_router.__enter__() - async def send_single_request(session: aiohttp.ClientSession, request_id: int): - # Alternate between URLs based on request_id - url = urls[request_id % len(urls)] - url_index = request_id % len(urls) + # Start single mocker instance with limited resources + endpoint = "dyn://test-namespace.mocker.generate" + logger.info( + f"Starting single mocker instance with limited resources on endpoint {endpoint}" + ) - try: - async with session.post(url, json=payload) as response: - if response.status != 200: - logger.error( - f"Request {request_id} to URL {url_index} failed with status {response.status}" - ) - return False + mocker = MockerProcess(request, endpoint, mocker_args_file) + mocker.__enter__() - # For streaming responses, read the entire stream - chunks = [] - async for line in response.content: - if line: - chunks.append(line) + url = f"http://localhost:{frontend_port}/v1/chat/completions" - logger.debug( - f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks" - ) - return True + # Custom payload for 503 test with more tokens to consume resources + 
test_payload_503 = { + **TEST_PAYLOAD, + "max_tokens": 50, # Longer output to consume more blocks + } - except Exception as e: - logger.error( - f"Request {request_id} to URL {url_index} failed with error: {e}" - ) - return False + # First, send one request with retry to ensure system is ready + logger.info("Sending initial request to ensure system is ready...") + asyncio.run(send_concurrent_requests([url], test_payload_503, 1)) - # Send all requests at once - async with aiohttp.ClientSession() as session: - tasks = [send_single_request(session, i) for i in range(num_requests)] - results = await asyncio.gather(*tasks) + # Now send 50 concurrent requests to exhaust resources, then verify 503 + logger.info("Sending 50 concurrent requests to exhaust resources...") - successful = sum(1 for r in results if r) - failed = sum(1 for r in results if not r) + async def exhaust_resources_and_verify_503(): + async with aiohttp.ClientSession() as session: + # Start 50 long-running requests concurrently + tasks = [] + for i in range(50): + # Create unique shuffled content for each request + content_words = TEST_PAYLOAD["messages"][0]["content"].split() + random.shuffle(content_words) + shuffled_content = " ".join(content_words) + + # Create unique payload for this request + unique_payload = { + **TEST_PAYLOAD, + "max_tokens": 50, + "messages": [ + {**TEST_PAYLOAD["messages"][0], "content": shuffled_content} + ], + } + + async def send_long_request(req_id, payload): + try: + async with session.post(url, json=payload) as response: + if response.status == 200: + # Don't read the response fully, just hold the connection + await asyncio.sleep( + 10 + ) # Hold connection for 10 seconds + return True + else: + logger.info( + f"Request {req_id} got status {response.status}" + ) + return False + except Exception as e: + logger.info(f"Request {req_id} failed: {e}") + return False + + tasks.append( + asyncio.create_task(send_long_request(i, unique_payload)) + ) - logger.info(f"Completed all requests: {successful} successful, {failed} failed") + # Wait briefly to ensure requests are in-flight + await asyncio.sleep(0.2) + + # Now send one more request that should get 503 + logger.info("Sending additional request that should receive 503...") + try: + async with session.post(url, json=test_payload_503) as response: + status_code = response.status + if status_code == 503: + body = await response.json() + logger.info(f"Got expected 503 response: {body}") + assert "Service temporarily unavailable" in body.get( + "error", "" + ) or "All workers are busy" in body.get( + "error", "" + ), f"Expected service overload error message, got: {body}" + return True + else: + logger.error(f"Expected 503 but got {status_code}") + if status_code == 200: + logger.error( + "Request unexpectedly succeeded when it should have been rejected" + ) + return False + except Exception as e: + logger.error(f"Failed to send overload test request: {e}") + return False + finally: + # Cancel all background tasks + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) - assert ( - successful == num_requests - ), f"Expected {num_requests} successful requests, got {successful}" - logger.info(f"All {num_requests} requests completed successfully") + # Run the test + success = asyncio.run(exhaust_resources_and_verify_503()) + assert success, "Failed to verify 503 response when resources are exhausted" + + logger.info("Successfully verified 503 response when all workers are busy") + + finally: + # Clean up + if "kv_router" 
in locals(): + kv_router.__exit__(None, None, None) + + if "mocker" in locals(): + mocker.__exit__(None, None, None) + + if os.path.exists(mocker_args_file): + os.unlink(mocker_args_file)
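
Operationally, the busy-detection path introduced by this diff is opt-in: a deployment enables it by passing the new --busy-threshold flag to the frontend, and the value flows through the Python RouterConfig binding (RouterConfig(mode, config, busy_threshold)) into the Rust RouterConfig, ModelWatcher, and ultimately PushRouter. A minimal launch sketch mirroring the overload test's configuration follows; the port, block size, and threshold values are illustrative, not recommendations.

import subprocess

# Start the OpenAI-compatible frontend with busy detection enabled.
# With --busy-threshold 0.2, a worker counts as busy once its
# kv_active_blocks exceed 20% of its advertised total_kv_blocks, and the
# frontend rejects new requests with HTTP 503 while every worker is busy.
frontend = subprocess.Popen(
    [
        "python", "-m", "dynamo.frontend",
        "--router-mode", "kv",
        "--busy-threshold", "0.2",
        "--kv-cache-block-size", "4",
        "--http-port", "8090",
    ]
)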
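For clients, the visible effect is that a saturated deployment now fails fast with HTTP 503 (from PipelineError::ServiceOverloaded) instead of queueing, so callers should retry with backoff. A minimal client-side sketch of that loop, under stated assumptions: the URL and payload are hypothetical, the payload is assumed non-streaming so the body can be read with json(), and the backoff parameters are arbitrary.

import asyncio
import logging

import aiohttp


async def post_with_backoff(url: str, payload: dict, max_retries: int = 4) -> dict:
    """POST to a Dynamo frontend, backing off while it reports 503 (all workers busy)."""
    wait_time = 1.0
    async with aiohttp.ClientSession() as session:
        for attempt in range(max_retries + 1):
            async with session.post(url, json=payload) as response:
                if response.status == 200:
                    return await response.json()
                if response.status == 503:
                    # Returned while every worker is above the busy threshold;
                    # the JSON body carries the ServiceOverloaded message.
                    body = await response.json()
                    logging.warning("attempt %d: %s", attempt + 1, body.get("error"))
                else:
                    response.raise_for_status()
            await asyncio.sleep(wait_time)
            wait_time *= 2  # exponential backoff, mirroring send_request_with_retry
    raise RuntimeError(f"service still overloaded after {max_retries + 1} attempts")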