diff --git a/Cargo.lock b/Cargo.lock index a5e4b8e2d3..635eebe076 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1845,6 +1845,7 @@ dependencies = [ "chrono", "criterion", "cudarc", + "dashmap", "derive-getters", "derive_builder", "dialoguer", diff --git a/components/frontend/src/dynamo/frontend/main.py b/components/frontend/src/dynamo/frontend/main.py index f0f430d0b1..06594443bc 100644 --- a/components/frontend/src/dynamo/frontend/main.py +++ b/components/frontend/src/dynamo/frontend/main.py @@ -112,6 +112,12 @@ def parse_args(): help=" KV Router. Disable KV events.", ) parser.set_defaults(use_kv_events=True) + parser.add_argument( + "--router-replica-sync", + action="store_true", + default=False, + help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.", + ) parser.add_argument( "--static-endpoint", type=validate_static_endpoint, @@ -148,6 +154,7 @@ async def async_main(): overlap_score_weight=flags.kv_overlap_score_weight, router_temperature=flags.router_temperature, use_kv_events=flags.use_kv_events, + router_replica_sync=flags.router_replica_sync, ) elif flags.router_mode == "random": router_mode = RouterMode.Random diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index 80b2f163dd..bf925b96e3 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -84,7 +84,7 @@ use std::net::SocketAddr; use std::time::Duration as StdDuration; use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics}; -use dynamo_llm::kv_router::scheduler::Endpoint; +use dynamo_llm::kv_router::scoring::Endpoint; use dynamo_llm::kv_router::scoring::ProcessedEndpoints; use dynamo_runtime::{ diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 5f741338fd..1ee7fb64fd 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> { let selector = Box::new(CustomWorkerSelector::default()); - let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?; + let router = KvRouter::new(component.clone(), args.block_size, Some(selector), None).await?; let router = Ingress::for_engine(Arc::new(router))?; component diff --git a/docs/architecture/kv_cache_routing.md b/docs/architecture/kv_cache_routing.md index a78feef9f5..6ca0181186 100644 --- a/docs/architecture/kv_cache_routing.md +++ b/docs/architecture/kv_cache_routing.md @@ -17,12 +17,13 @@ For performance testing, compare a typical workload with `--router-mode random|r The KV-aware routing arguments: -- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks). +- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks). Defaults to 1. -- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. 
Setting it to 0 recovers the deterministic behavior where the min logit is picked. +- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 (default) recovers the deterministic behavior where the min logit is picked. -- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events. +- `--use-kv-events`/`--no-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true (default), then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events. +- `--router-replica-sync`: Enables state synchronization between multiple router replicas via NATS. Disabled by default; pass the flag to enable it. When enabled, router replicas share their view of KV cache distribution and active sequences, allowing all routers to make optimal routing decisions even when requests are distributed across multiple router instances. This improves fault tolerance and routing accuracy in multi-router deployments. ## Architecture @@ -45,6 +46,22 @@ We can then use the default routing methods exposed by the client class to send KV Cache routing uses direct routing with a special worker selection algorithm. +## Serving Two Router Replicas + +For improved fault tolerance, you can launch two frontend + router replicas. Since the frontend and router are currently tied together, you'll need to use a different HTTP port for each instance. + +To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend: + +```bash +# Router replica 1 +python -m dynamo.frontend --router-mode kv --port 8000 --router-replica-sync + +# Router replica 2 +python -m dynamo.frontend --router-mode kv --port 8001 --router-replica-sync +``` + +When `--router-replica-sync` is enabled, the router replicas will communicate with each other via NATS to maintain consistent state across instances. This allows both routers to have a complete view of the KV cache distribution and make optimal routing decisions, even when requests are distributed across multiple router instances. + ## Understanding KV Cache The leading Large Language Models (LLMs) today are auto-regressive and based off of the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for the future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching).
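Editor's note: the routing flags above correspond one-to-one to the `KvRouterConfig` constructor arguments added to the Python bindings later in this diff (`lib/bindings/python/rust/llm/entrypoint.rs`). A minimal sketch of constructing the config programmatically; the import path is an assumption, since the diff only shows the Rust-side pyo3 class:

```python
# Sketch only: the import path below is assumed, not shown in this diff.
from dynamo.llm import KvRouterConfig

# Mirrors the pyo3 signature added in entrypoint.rs:
# KvRouterConfig(overlap_score_weight=1.0, router_temperature=0.0,
#                use_kv_events=True, router_replica_sync=False)
kv_router_config = KvRouterConfig(
    overlap_score_weight=1.0,   # weight of prefix-cache overlap (prefill cost) in the cost function
    router_temperature=0.0,     # 0 = deterministic pick of the min-cost worker
    use_kv_events=True,         # KvIndexer driven by KV events (vs. ApproxKvIndexer)
    router_replica_sync=True,   # share router state across replicas via NATS
)
```

This mirrors how `components/frontend/src/dynamo/frontend/main.py` forwards `--kv-overlap-score-weight`, `--router-temperature`, `--use-kv-events`, and `--router-replica-sync` into the router configuration.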
@@ -88,30 +105,46 @@ Further details can be found for: [TRT-LLM](https://developer.nvidia.com/blog/in | +------------------+------------------+ | | | - | KV match: 15% | KV match: 50% | KV match: 75% + | Cached: 2 blocks | Cached: 5 blocks | Cached: 8 blocks + | Prefill: 8 blks | Prefill: 5 blks | Prefill: 2 blks + | Decode: 10 blks | Decode: 7 blks | Decode: 9 blks v v v +----------------+ +----------------+ +----------------+ | Worker 1 | | Worker 2 | | Worker 3 | - | (Load: 30%) | | (Load: 50%) | | (Load: 80%) | +----------------+ +----------------+ +----------------+ ``` Load balancing in LLM serving becomes complex when enabling KV Cache reuse. While KV Cache reuse can save significant computation, if the routing strategy is not aware of the unique KV states of each worker we can: -- miss opportunities for KV Cache reuse if routing to the “wrong” node +- miss opportunities for KV Cache reuse if routing to the "wrong" node - get into an imbalanced state where a few workers are processing many requests, lowering throughput of entire system -The best way to solve these issues is for the router to have a global view of KV Cache and load. With this view, the router can use a cost function to score the workers and make decisions to maximize cache hits while keeping the system balanced and throughput high. +The router uses a cost function that considers both the prefill cost (influenced by cached blocks) and the decode load to make optimal routing decisions: + +### Cost Calculation + +1. **Prefill blocks**: The number of tokens that need to be processed during prefill is predicted based on the request's input tokens and the cached blocks available on each worker. This is divided by the block size to get the effective "prefill blocks". This prediction is updated when the first output token is produced, signaling prefill completion. -In the above image, our cost function is (KV match - Load) so we select Worker 2 even though Worker 3 would offer the best KV match. -- Worker 1 = (0.15 - 0.30) = -0.15 -- **Worker 2 = (0.50 - 0.50) = 0** -- Worker 3 = (0.75 - 0.80) = -0.05 +2. **Decode blocks**: The number of blocks needed during the decode phase is predicted based on the request's input tokens and the current active sequences on each worker. This is updated when the request is freed (blocks are dereferenced or freed). + +3. **Cost formula**: `cost = overlap_score_weight * prefill_blocks + decode_blocks` + - Lower cost is better + - The `overlap_score_weight` parameter controls the importance of cache hits vs. load balancing + - A higher weight prioritizes cache reuse (better TTFT) while a lower weight prioritizes load distribution (better ITL) + +### Worker Selection + +The router selects the worker with the lowest cost. When `router_temperature` is set to a non-zero value, the router uses softmax sampling on the normalized cost logits to introduce randomness in the selection, which can help with load distribution. + +Example calculation with `overlap_score_weight = 1.0`, using the blocks shown in the diagram above: +- Worker 1: cost = 1.0 * 8 + 10 = 18 +- Worker 2: cost = 1.0 * 5 + 7 = 12 +- **Worker 3: cost = 1.0 * 2 + 9 = 11** (selected - lowest cost) ## Events -In Dynamo, we want to support KV Cache Routing and load balancing for many backends that have different implementations of KV Cache and record different metrics. To that end, we built a KVPublisher that can be plugged into any framework to publish KV Events and a WorkerMetricsPublisher that can publish Metric Events.
+In Dynamo, we support KV Cache Routing for many backends that have different implementations of KV Cache. To enable this, we built a KVPublisher that can be plugged into any framework to publish KV Events. -On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree and a KvMetricsAggregator which aggregates metric events by worker. +On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree for tracking cached blocks across all workers. ```text +----------------+ +-----------------+ @@ -121,13 +154,8 @@ On the receiving side we have a KVIndexer which accepts events from the KVPublis | +------------+ | remove_kv_block() | | KVIndexer | | | |KVPublisher | |------------------------>| +-------------+ | | +------------+ | | | -| | num_request_waiting | +--------------+| -| +------------+ | gpu_cache_usage_perc | |KvMetricsAggre|| -| |KvMetrics | |------------------------>| | gator || -| |Publisher | | ... | +--------------+| -| +------------+ | +-----------------+ -+----------------+ - +| | | | ++----------------+ +-----------------+ ``` ### KVPublisher @@ -144,18 +172,15 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr The KVIndexer has a method `find_matches_for_request`, which takes in tokens and returns a dictionary with keys of worker id and values of the number of matched KV Blocks. -### WorkerMetricsPublisher -We added a KvMetrics Publisher which sends the following metrics to the KvMetricsAggregator: -- num_requests_waiting -- gpu_cache_usage_perc -- gpu_prefix_cache_hit_rate -- request_active_slots -- request_total_slots -- kv_active_blocks -- kv_total_blocks +### Inter-Router Communication + +In multi-router deployments, each router only observes a subset of requests. To maintain a consistent global view of active sequences and KV cache states, routers broadcast their local actions to other replicas through three synchronization events: + +1. **AddRequest**: Published when assigning a request to a worker, containing the request ID, worker ID, token sequence blocks, and overlap score. This updates other routers' tracking of which blocks are in use. + +2. **MarkPrefillCompleted**: Published when a request transitions from prefill to decode phase, signaling that prefill tokens should no longer count toward the worker's active prefill load. -Currently, the WorkerMetricsPublisher exists as a Python binding. +3. **Free**: Published when a request completes and its resources are released, allowing other routers to update their block reference counts. -### KvMetricsAggregator -The KvMetricsAggregator receives these metrics and aggregates them. It has a method `get_metrics` which returns an object of `AggregatedMetrics`. +Each event includes a unique router ID to prevent processing of self-generated events. This asynchronous communication ensures all routers maintain synchronized KV cache state for optimal routing decisions despite handling different request streams. diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs index 717f03e809..a3ba159b7a 100644 --- a/launch/dynamo-run/src/flags.rs +++ b/launch/dynamo-run/src/flags.rs @@ -96,6 +96,12 @@ pub struct Flags { #[arg(long)] pub use_kv_events: Option, + /// KV Router: Whether to enable replica synchronization across multiple router instances. + /// When true, routers will publish and subscribe to events to maintain consistent state. 
+ /// Default: false + #[arg(long)] + pub router_replica_sync: Option, + /// Max model context length. Reduce this if you don't have enough VRAM for the full model /// context length (e.g. Llama 4). /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. @@ -223,6 +229,7 @@ impl Flags { self.kv_overlap_score_weight, self.router_temperature, self.use_kv_events, + self.router_replica_sync, self.max_num_batched_tokens, ), ) diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index d2bea055c1..a884332381 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -1140,6 +1140,7 @@ dependencies = [ "candle-core", "chrono", "cudarc", + "dashmap", "derive-getters", "derive_builder", "dialoguer", diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index a98c89a7c3..c2843bac01 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -35,13 +35,19 @@ pub struct KvRouterConfig { #[pymethods] impl KvRouterConfig { #[new] - #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true))] - fn new(overlap_score_weight: f64, router_temperature: f64, use_kv_events: bool) -> Self { + #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false))] + fn new( + overlap_score_weight: f64, + router_temperature: f64, + use_kv_events: bool, + router_replica_sync: bool, + ) -> Self { KvRouterConfig { inner: RsKvRouterConfig { overlap_score_weight, router_temperature, use_kv_events, + router_replica_sync, ..Default::default() }, } diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index edb9a82cc6..368e7072c3 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -85,6 +85,7 @@ derive-getters = "0.5" offset-allocator = "0.2" regex = "1" rayon = "1" +dashmap = { version = "5.5.3" } # input/text dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs index caa105e921..6773c2e98e 100644 --- a/lib/llm/src/discovery/model_manager.rs +++ b/lib/llm/src/discovery/model_manager.rs @@ -217,7 +217,7 @@ impl ModelManager { component.clone(), kv_cache_block_size, Some(selector), - kv_router_config.unwrap_or_default().use_kv_events, + kv_router_config, ) .await?; let new_kv_chooser = Arc::new(chooser); diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 2e5d02fe82..1e488872e1 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -15,11 +15,11 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; -use tokio::sync::Mutex; pub mod approx; pub mod indexer; pub mod metrics_aggregator; +pub mod prefill_counter; pub mod protocols; pub mod publisher; pub mod recorder; @@ -48,9 +48,18 @@ use dynamo_runtime::traits::events::EventSubscriber; // [gluo TODO] shouldn't need to be public // this should be discovered from the component + +// for metric scraping (pull-based) +pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; + +// for metric publishing (push-based) pub const KV_EVENT_SUBJECT: &str = "kv_events"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; -pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; +pub const KV_METRICS_SUBJECT: &str = "kv_metrics"; + +// for inter-router comms +pub const PREFILL_SUBJECT: &str = "prefill_events"; +pub const 
ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; /// A trait that users can implement to define custom selection logic pub trait WorkerSelector { @@ -71,6 +80,8 @@ pub struct KvRouterConfig { pub use_kv_events: bool, + pub router_replica_sync: bool, + // TODO: this is not actually used for now // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting pub max_num_batched_tokens: u32, @@ -82,6 +93,7 @@ impl Default for KvRouterConfig { overlap_score_weight: 1.0, router_temperature: 0.0, use_kv_events: true, + router_replica_sync: false, max_num_batched_tokens: 8192, } } @@ -94,6 +106,7 @@ impl KvRouterConfig { overlap_score_weight: Option, temperature: Option, use_kv_events: Option, + replica_sync: Option, max_num_batched_tokens: Option, ) -> Self { let default = Self::default(); @@ -101,6 +114,7 @@ impl KvRouterConfig { overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight), router_temperature: temperature.unwrap_or(default.router_temperature), use_kv_events: use_kv_events.unwrap_or(default.use_kv_events), + router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync), max_num_batched_tokens: max_num_batched_tokens .unwrap_or(default.max_num_batched_tokens), } @@ -135,10 +149,6 @@ pub struct KvRouter { scheduler: KvScheduler, block_size: u32, - - // To ensure blocking reads / writes - // TODO: benchmark tradeoffs - find_best_match_mutex: Mutex<()>, } impl KvRouter { @@ -146,8 +156,10 @@ impl KvRouter { component: Component, block_size: u32, selector: Option>, - use_kv_events: bool, + kv_router_config: Option, ) -> Result { + let kv_router_config = kv_router_config.unwrap_or_default(); + let cancellation_token = component .drt() .primary_lease() @@ -164,7 +176,7 @@ impl KvRouter { } }; - let indexer = if use_kv_events { + let indexer = if kv_router_config.use_kv_events { Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size)) } else { // hard code 120 seconds for now @@ -176,10 +188,11 @@ impl KvRouter { }; let scheduler = KvScheduler::start( - component.namespace().clone(), + component.clone(), block_size, instances_rx, selector, + kv_router_config.router_replica_sync, ) .await?; @@ -215,7 +228,6 @@ impl KvRouter { indexer, scheduler, block_size, - find_best_match_mutex: Mutex::new(()), // Add this }) } @@ -227,10 +239,6 @@ impl KvRouter { context_id: &str, tokens: &[u32], ) -> anyhow::Result<(i64, u32)> { - // Acquire mutex to serialize access - // TODO: may as well make all the subroutines synchronous if benchmarking favors this - let _guard = self.find_best_match_mutex.lock().await; - let isl_tokens = tokens.len(); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); @@ -263,17 +271,14 @@ impl KvRouter { Ok((best_worker_id, overlap_amount)) } - /// Free all blocks associated with a request - pub async fn mark_prefill_completed(&self, request_id: &String) { + pub async fn mark_prefill_completed(&self, request_id: &str) { self.scheduler.mark_prefill_completed(request_id).await } - /// Free all blocks associated with a request - pub async fn free(&self, request_id: &String) { + pub async fn free(&self, request_id: &str) { self.scheduler.free(request_id).await } - /// Get the block size this router was configured with pub fn block_size(&self) -> u32 { self.block_size } diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 018f6cf84a..7ab4e1372c 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ 
b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -18,7 +18,7 @@ use std::sync::Once; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics}; use crate::kv_router::KV_METRICS_ENDPOINT; -use crate::kv_router::scheduler::Endpoint; +use crate::kv_router::scoring::Endpoint; use crate::kv_router::ProcessedEndpoints; use dynamo_runtime::component::Component; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; diff --git a/lib/llm/src/kv_router/prefill_counter.rs b/lib/llm/src/kv_router/prefill_counter.rs new file mode 100644 index 0000000000..5052faf752 --- /dev/null +++ b/lib/llm/src/kv_router/prefill_counter.rs @@ -0,0 +1,545 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::Result; +use dynamo_runtime::component::Component; +use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; +use futures::StreamExt; +use std::sync::Arc; +use uuid::Uuid; + +use super::protocols::{PrefillEvent, PrefillEventData}; +use crate::kv_router::PREFILL_SUBJECT; +use dashmap::DashMap; +use std::collections::HashMap; +use std::hash::Hash; + +pub fn get_snapshot(state: &DashMap) -> HashMap +where + K: Clone + Hash + Eq, + V: Copy, +{ + state + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect() +} + +#[derive(Default)] +struct PrefillCounterState { + tokens_map: HashMap, // Plain HashMap + running_sum: usize, // Plain usize +} + +impl PrefillCounterState { + fn insert(&mut self, key: String, value: usize) -> Option { + // Takes &mut self + let old_value = self.tokens_map.insert(key, value); + + if let Some(old) = old_value { + self.running_sum -= old; + self.running_sum += value; + } else { + self.running_sum += value; + } + + old_value + } + + fn remove(&mut self, key: &str) -> Option { + // Takes &mut self + let removed = self.tokens_map.remove(key); + + if let Some(value) = removed { + self.running_sum -= value; + } + + removed + } + + fn running_sum(&self) -> usize { + self.running_sum + } +} + +/// A counter that tracks pending prefill tokens for each request. +/// +/// This struct maintains a local hashmap of request_id to token count, +/// and a running sum of all tokens. It no longer handles its own subscriptions. 
+#[derive(Default)] // Removed Clone +pub struct PrefillCounter { + state: PrefillCounterState, // No Arc, direct ownership +} + +impl PrefillCounter { + // Internal methods for direct state manipulation (no publishing) + fn insert_direct(&mut self, request_id: String, tokens: usize) -> Option { + // Takes &mut self + self.state.insert(request_id, tokens) + } + + fn remove_direct(&mut self, request_id: &str) -> Option { + // Takes &mut self + self.state.remove(request_id) + } + + #[allow(dead_code)] + fn update_direct(&mut self, request_id: String, new_tokens: usize) { + // Takes &mut self + if let Some(old_tokens) = self.state.tokens_map.get(&request_id).copied() { + let delta = new_tokens as isize - old_tokens as isize; + self.state.running_sum = (self.state.running_sum as isize + delta) as usize; + self.state.tokens_map.insert(request_id, new_tokens); + } + } + + pub fn get(&self, request_id: &str) -> Option { + self.state.tokens_map.get(request_id).copied() + } + + pub fn running_sum(&self) -> usize { + self.state.running_sum() + } + + pub fn len(&self) -> usize { + self.state.tokens_map.len() + } + + pub fn is_empty(&self) -> bool { + self.state.tokens_map.is_empty() + } +} + +/// A collection of PrefillCounters for multiple workers with centralized event handling +pub struct PrefillCountersMultiWorker { + pub counters: Arc>, + pub request_to_workers: Arc>, + component: Component, + router_id: Uuid, +} + +impl PrefillCountersMultiWorker { + // Helper function to handle new prefill logic + fn handle_new_prefill( + counters: &Arc>, + request_to_workers: &Arc>, + request_id: &str, + worker_id: i64, + tokens: usize, + ) { + // Check if request already exists + if let Some(existing_worker_id) = request_to_workers.get(request_id) { + tracing::warn!( + "Request {} already exists for worker {}, but trying to add to worker {}", + request_id, + *existing_worker_id, + worker_id + ); + } + + // Update mapping + request_to_workers.insert(request_id.to_string(), worker_id); + + // Get or create counter and insert using get_mut + if let Some(mut counter) = counters.get_mut(&worker_id) { + counter.insert_direct(request_id.to_string(), tokens); + } else { + tracing::warn!( + "Worker {} does not exist, creating new PrefillCounter", + worker_id + ); + let mut new_counter = PrefillCounter::default(); + new_counter.insert_direct(request_id.to_string(), tokens); + counters.insert(worker_id, new_counter); + }; + } + + // Helper function to handle complete prefill logic + fn handle_complete_prefill( + counters: &Arc>, + request_to_workers: &Arc>, + request_id: &str, + ) -> Option { + // Remove from request_to_workers and get the worker_id + let Some((_, worker_id)) = request_to_workers.remove(request_id) else { + tracing::warn!("Request {} not found in request_to_workers", request_id); + return None; + }; + + // Use the worker_id from request_to_workers with get_mut + let Some(mut counter) = counters.get_mut(&worker_id) else { + tracing::warn!( + "No counter found for worker {} for request {}", + worker_id, + request_id + ); + return None; + }; + + let removed_tokens = counter.remove_direct(request_id); + if removed_tokens.is_none() { + tracing::warn!("Attempted to remove non-existent request: {}", request_id); + } + + removed_tokens + } + + pub fn new(component: Component) -> Self { + let counters = Arc::new(DashMap::new()); + let request_to_workers = Arc::new(DashMap::new()); + let router_id = Uuid::new_v4(); + + let multi_worker = Self { + counters: counters.clone(), + request_to_workers: 
request_to_workers.clone(), + component: component.clone(), + router_id, + }; + + // Start the subscription loop + let counters_clone = counters.clone(); + let request_to_workers_clone = request_to_workers.clone(); + let component_clone = component.clone(); + let router_id_clone = router_id; + + tokio::spawn(async move { + if let Err(e) = Self::subscribe_to_events( + counters_clone, + request_to_workers_clone, + component_clone, + router_id_clone, + ) + .await + { + tracing::error!("Error in prefill events subscription: {}", e); + } + }); + + multi_worker + } + + /// Background task to subscribe to prefill events and update all counters + async fn subscribe_to_events( + counters: Arc>, + request_to_workers: Arc>, + component: Component, + router_id: Uuid, + ) -> Result<()> { + let mut subscriber = component + .subscribe_with_type::(PREFILL_SUBJECT) + .await?; + + while let Some(result) = subscriber.next().await { + let Ok(event) = result else { + tracing::error!("Error receiving prefill event: {}", result.unwrap_err()); + continue; + }; + + // Skip events emitted by itself + if event.router_id == router_id { + continue; + } + + match event.data { + PrefillEventData::NewPrefill(tokens) => { + Self::handle_new_prefill( + &counters, + &request_to_workers, + &event.request_id, + event.worker_id, + tokens, + ); + } + PrefillEventData::UpdatePrefill(_) => { + // Do nothing for now + continue; + } + PrefillEventData::CompletePrefill => { + Self::handle_complete_prefill( + &counters, + &request_to_workers, + &event.request_id, + ); + } + } + } + + Ok(()) + } + + pub async fn add_prefill( + &self, + worker_id: i64, + request_id: String, + new_tokens: usize, + ) -> Result<()> { + let event = PrefillEvent { + request_id: request_id.clone(), + worker_id, + data: PrefillEventData::NewPrefill(new_tokens), + router_id: self.router_id, + }; + self.component.publish(PREFILL_SUBJECT, &event).await?; + + // Use the helper function + Self::handle_new_prefill( + &self.counters, + &self.request_to_workers, + &request_id, + worker_id, + new_tokens, + ); + + Ok(()) + } + + pub async fn remove_prefill(&self, request_id: &str) -> Result> { + // Send the event first with dummy worker_id + let event = PrefillEvent { + request_id: request_id.to_string(), + worker_id: 0, // Dummy worker_id + data: PrefillEventData::CompletePrefill, + router_id: self.router_id, + }; + self.component.publish(PREFILL_SUBJECT, &event).await?; + + // Use the helper function + Ok(Self::handle_complete_prefill( + &self.counters, + &self.request_to_workers, + request_id, + )) + } + + /// Get the running sums for all workers as a HashMap + pub async fn running_sums(&self) -> HashMap { + self.counters + .iter() + .map(|entry| (*entry.key(), entry.value().running_sum())) + .collect() + } + + /// Get a specific counter's running sum + pub async fn get_worker_sum(&self, worker_id: i64) -> Option { + self.counters.get(&worker_id).map(|c| c.running_sum()) + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use dynamo_runtime::{DistributedRuntime, Runtime}; + use std::sync::{Arc, Mutex}; + use std::thread; + use tokio::time::Duration; + + #[test] + #[ignore] + fn test_prefill_counter_multiworker_synchronization() -> Result<()> { + // Initialize logging once + dynamo_runtime::logging::init(); + + let worker_id_1 = 1; + let worker_id_2 = 2; + let tokens_per_request = 100; + let requests_per_worker = 10; + + // Shared state for collecting results from both threads + let results1 = Arc::new(Mutex::new(None)); + let results2 = 
Arc::new(Mutex::new(None)); + let final_results1 = Arc::new(Mutex::new(None)); + let final_results2 = Arc::new(Mutex::new(None)); + + let results1_clone = results1.clone(); + let results2_clone = results2.clone(); + let final_results1_clone = final_results1.clone(); + let final_results2_clone = final_results2.clone(); + + // Thread 1: First distributed runtime with multi_worker1 + let handle1 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and components with same names + let namespace = distributed.namespace("test_prefill_multiworker")?; + let component = namespace + .component("counters")? + .service_builder() + .create() + .await?; + + // Create first PrefillCountersMultiWorker instance + let multi_worker1 = PrefillCountersMultiWorker::new(component); + + // Give some time for subscribers to initialize + tokio::time::sleep(Duration::from_millis(3000)).await; + + // Send requests to multi_worker1's worker + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1 + .add_prefill(worker_id_1, request_id, tokens_per_request) + .await?; + } + + // Wait for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get running sums after additions + let sums1 = multi_worker1.running_sums().await; + *results1_clone.lock().unwrap() = Some(sums1); + + // Wait for other thread to add its requests + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Remove all requests from multi_worker1 + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1.remove_prefill(&request_id).await?; + } + + // Wait for removal synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get final running sums + let final_sums1 = multi_worker1.running_sums().await; + *final_results1_clone.lock().unwrap() = Some(final_sums1); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Thread 2: Second distributed runtime with multi_worker2 + let handle2 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and components with same names + let namespace = distributed.namespace("test_prefill_multiworker")?; + let component = namespace + .component("counters")? 
+ .service_builder() + .create() + .await?; + + // Create second PrefillCountersMultiWorker instance + let multi_worker2 = PrefillCountersMultiWorker::new(component); + + // Give some time for subscribers to initialize + tokio::time::sleep(Duration::from_millis(3000)).await; + + // Wait a bit to ensure multi_worker1 has started + tokio::time::sleep(Duration::from_millis(500)).await; + + // Send requests to multi_worker2's worker + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2 + .add_prefill(worker_id_2, request_id, tokens_per_request) + .await?; + } + + // Wait for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get running sums after additions + let sums2 = multi_worker2.running_sums().await; + *results2_clone.lock().unwrap() = Some(sums2); + + // Wait for other thread to remove its requests + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Remove all requests from multi_worker2 + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2.remove_prefill(&request_id).await?; + } + + // Wait for removal synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get final running sums + let final_sums2 = multi_worker2.running_sums().await; + *final_results2_clone.lock().unwrap() = Some(final_sums2); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Wait for both threads to complete + handle1.join().unwrap()?; + handle2.join().unwrap()?; + + // Extract results + let sums1 = results1.lock().unwrap().take().unwrap(); + let sums2 = results2.lock().unwrap().take().unwrap(); + let final_sums1 = final_results1.lock().unwrap().take().unwrap(); + let final_sums2 = final_results2.lock().unwrap().take().unwrap(); + + // Verify both multi-workers see all requests + assert_eq!( + sums1.get(&worker_id_1), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker1 should see worker 1's requests" + ); + assert_eq!( + sums1.get(&worker_id_2), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker1 should see worker 2's requests" + ); + assert_eq!( + sums2.get(&worker_id_1), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker2 should see worker 1's requests" + ); + assert_eq!( + sums2.get(&worker_id_2), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker2 should see worker 2's requests" + ); + + // Verify both multi-workers show zero sums after removal + assert_eq!( + final_sums1.get(&worker_id_1).copied().unwrap_or(0), + 0, + "MultiWorker1 should show zero for worker 1" + ); + assert_eq!( + final_sums1.get(&worker_id_2).copied().unwrap_or(0), + 0, + "MultiWorker1 should show zero for worker 2" + ); + assert_eq!( + final_sums2.get(&worker_id_1).copied().unwrap_or(0), + 0, + "MultiWorker2 should show zero for worker 1" + ); + assert_eq!( + final_sums2.get(&worker_id_2).copied().unwrap_or(0), + 0, + "MultiWorker2 should show zero for worker 2" + ); + + Ok(()) + } +} diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index e6429f3909..24f33bd0d1 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -1,20 +1,9 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::tokens::Token; + +use crate::tokens::{SequenceHash, Token}; use serde::{Deserialize, Serialize}; +use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterRequest { @@ -128,6 +117,56 @@ impl From for ExternalSequenceBlockHash { } } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PrefillEvent { + pub request_id: String, + pub worker_id: i64, + pub data: PrefillEventData, + pub router_id: Uuid, +} + +/// Represents the different stages of prefilling tokens for a request. +/// +/// Each variant contains a `usize` representing the number of tokens +/// that are pending prefill in the request. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum PrefillEventData { + NewPrefill(usize), + UpdatePrefill(usize), + CompletePrefill, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActiveSequenceEvent { + pub request_id: String, + pub worker_id: i64, + pub data: ActiveSequenceEventData, + pub router_id: Uuid, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum ActiveSequenceEventData { + AddRequest { + token_sequence: Vec, + isl: usize, + overlap: u32, + }, + Free, + MarkPrefillCompleted, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActiveBlockEvent { + pub request_id: String, + pub data: ActiveBlockEventData, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum ActiveBlockEventData { + NewBlock(Vec), + FreeBlock, +} + /// Represents a collection of cache events and a shutdown flag. 
#[derive(Serialize, Deserialize, Debug, Clone)] pub struct KvCacheEvents { diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 2f75b59c06..22c88b08d0 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -16,12 +16,13 @@ use crate::kv_router::{ indexer::{compute_block_hash_for_seq, RouterEvent}, protocols::*, - KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, + scoring::LoadEvent, + KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT, }; use async_trait::async_trait; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; use dynamo_runtime::{ - component::Component, + component::{Component, Namespace}, pipeline::{ network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, @@ -499,9 +500,18 @@ impl WorkerMetricsPublisher { pub async fn create_endpoint(&self, component: Component) -> Result<()> { let mut metrics_rx = self.rx.clone(); - let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); + let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; + // let worker_id = component + // .drt() + // .primary_lease() + // .map(|lease| lease.id()) + // .unwrap_or_else(|| { + // tracing::warn!("Component is static, assuming worker_id of 0"); + // 0 + // }); + component .endpoint(KV_METRICS_ENDPOINT) .endpoint_builder() @@ -513,13 +523,90 @@ impl WorkerMetricsPublisher { .start() .await } + + /// Starts a background task to publish metrics over NATS + /// + /// This task monitors metric changes (specifically kv_active_blocks and num_requests_waiting) + /// and publishes stable metrics to NATS after they've been unchanged for 1ms. + #[allow(dead_code)] + fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: i64) { + let nats_rx = self.rx.clone(); + + tokio::spawn(async move { + let mut rx = nats_rx; + let mut last_kv_active_blocks: Option = None; + let mut last_num_requests_waiting: Option = None; + let mut pending_publish: Option> = None; + let mut publish_timer = + Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(0))); + publish_timer.as_mut().reset(tokio::time::Instant::now()); // Complete immediately + + loop { + tokio::select! 
{ + // Handle metrics changes + result = rx.changed() => { + if result.is_err() { + tracing::debug!( + "Metrics publisher sender dropped, stopping NATS background task" + ); + break; + } + + let metrics = rx.borrow_and_update().clone(); + + // Extract the values we care about + let current_kv_active_blocks = metrics.kv_stats.kv_active_blocks; + let current_num_requests_waiting = + metrics.worker_stats.num_requests_waiting; + + // Check if these specific metrics have changed + let has_changed = match (last_kv_active_blocks, last_num_requests_waiting) { + (Some(last_kv), Some(last_requests)) => { + last_kv != current_kv_active_blocks + || last_requests != current_num_requests_waiting + } + _ => true, // First time, consider it changed + }; + + // If load metrics changed, schedule a publish + if has_changed { + pending_publish = Some(metrics.clone()); + last_kv_active_blocks = Some(current_kv_active_blocks); + last_num_requests_waiting = Some(current_num_requests_waiting); + + // Start the 1ms timer + publish_timer.as_mut().reset( + tokio::time::Instant::now() + tokio::time::Duration::from_millis(1) + ); + } + } + // Timer expired - publish if we have pending metrics + _ = &mut publish_timer => { + if let Some(metrics) = pending_publish.take() { + // Create LoadEvent wrapping the metrics + let load_event = LoadEvent { + worker_id, + data: (*metrics).clone(), + }; + + if let Err(e) = + namespace.publish(KV_METRICS_SUBJECT, &load_event).await + { + tracing::warn!("Failed to publish metrics over NATS: {}", e); + } + } + } + } + } + }); + } } -struct KvLoadEndpoingHander { +struct KvLoadEndpointHandler { metrics_rx: tokio::sync::watch::Receiver>, } -impl KvLoadEndpoingHander { +impl KvLoadEndpointHandler { pub fn new(metrics_rx: tokio::sync::watch::Receiver>) -> Self { Self { metrics_rx } } @@ -527,7 +614,7 @@ impl KvLoadEndpoingHander { #[async_trait] impl AsyncEngine, ManyOut>, Error> - for KvLoadEndpoingHander + for KvLoadEndpointHandler { async fn generate( &self, @@ -880,3 +967,116 @@ mod test_exponential_backoff { assert!(max_calculated <= MAX_BACKOFF_MS); } } + +#[cfg(test)] +mod test_worker_metrics_publisher { + use super::*; + use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats}; + use dynamo_runtime::traits::events::EventSubscriber; // Add this import + use dynamo_runtime::{DistributedRuntime, Runtime}; + use futures::StreamExt; + + #[tokio::test] + #[ignore] // Mark as ignored as requested + async fn test_metrics_publishing_behavior() -> Result<()> { + // Set up runtime and namespace + let rt = Runtime::from_current().unwrap(); + let drt = DistributedRuntime::from_settings(rt.clone()).await?; + let namespace = drt.namespace("test".to_string())?; + + // Create a subscriber for the metrics events using subscribe_with_type + let mut subscriber = namespace + .subscribe_with_type::(KV_METRICS_SUBJECT) + .await + .unwrap(); + + // Create WorkerMetricsPublisher + let publisher = WorkerMetricsPublisher::new().unwrap(); + let worker_id = 1234; + + // Start NATS metrics publishing + publisher.start_nats_metrics_publishing(namespace.clone(), worker_id); + + // Allow some time for the background task to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Test 1: Publish 10 different metrics with 0.5ms intervals + // Only the last one should be published after 1ms of stability + for i in 0..10 { + let metrics = Arc::new(ForwardPassMetrics { + kv_stats: KvStats { + kv_active_blocks: (i * 100) as u64, // Changing load metric + kv_total_blocks: 1000, 
+ gpu_cache_usage_perc: 0.5, + gpu_prefix_cache_hit_rate: 0.8, + }, + worker_stats: WorkerStats { + num_requests_waiting: (i * 10) as u64, // Changing load metric + data_parallel_rank: None, + request_active_slots: 50, + request_total_slots: 100, + }, + spec_decode_stats: None, + }); + + publisher.publish(metrics).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; + } + + // Wait a bit more than 1ms to ensure the last metric is published + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Verify we receive exactly one event with the last metric values + let result = + tokio::time::timeout(tokio::time::Duration::from_millis(500), subscriber.next()) + .await + .unwrap(); + + let event = result.unwrap().unwrap(); // Unwrap the Option and the Result + assert_eq!(event.worker_id, worker_id); + assert_eq!(event.data.kv_stats.kv_active_blocks, 900); // Last value: 9 * 100 + assert_eq!(event.data.worker_stats.num_requests_waiting, 90); // Last value: 9 * 10 + + // Ensure no more events are waiting + let no_msg = + tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await; + assert!(no_msg.is_err(), "Expected no more messages, but found one"); + + // Test 2: Publish 10 more metrics where everything changes EXCEPT the load metrics + for i in 0..10 { + let metrics = Arc::new(ForwardPassMetrics { + kv_stats: KvStats { + kv_active_blocks: 900, // Keep same as last published + kv_total_blocks: 1000 + (i * 100) as u64, // Change other metrics + gpu_cache_usage_perc: 0.3 + (i as f32 * 0.05), // Change other metrics + gpu_prefix_cache_hit_rate: 0.7 + (i as f32 * 0.01), // Change other metrics + }, + worker_stats: WorkerStats { + num_requests_waiting: 90, // Keep same as last published + data_parallel_rank: None, + request_active_slots: 40 + (i * 5) as u64, // Change other metrics + request_total_slots: 100 + (i * 10) as u64, // Change other metrics + }, + spec_decode_stats: None, + }); + + publisher.publish(metrics).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; + } + + // Wait to ensure no events are published + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Verify no events are received + let no_msg = + tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await; + assert!( + no_msg.is_err(), + "Expected no messages when load metrics don't change" + ); + + rt.shutdown(); + + Ok(()) + } +} diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 797f738187..6603b0e906 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -1,36 +1,22 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use dynamo_runtime::component::Namespace; + +use dynamo_runtime::component::{Component, Instance}; use dynamo_runtime::traits::events::EventPublisher; use rand::Rng; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use tokio::sync::Mutex; +use super::indexer::OverlapScores; use super::protocols::WorkerSelectionResult; +use super::sequence::ActiveSequencesMultiWorker; +use super::KvRouterConfig; use super::WorkerSelector; -use crate::kv_router::indexer::OverlapScores; -use crate::kv_router::protocols::LoadMetrics; -use crate::kv_router::sequence::ActiveSequencesMultiWorker; -use crate::kv_router::KvRouterConfig; -use crate::kv_router::KV_HIT_RATE_SUBJECT; +use super::KV_HIT_RATE_SUBJECT; + use crate::tokens::SequenceHash; -use dynamo_runtime::component::Instance; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KVHitRateEvent { @@ -51,155 +37,161 @@ pub enum KvSchedulerError { SubscriberShutdown, } -/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' -/// is cleaned (not optional) -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct Endpoint { - pub name: String, - pub subject: String, - pub data: LoadMetrics, -} - -impl Endpoint { - pub fn worker_id(&self) -> i64 { - i64::from_str_radix( - self.subject - .split("-") - .last() - .expect("invalid subject") - .to_string() - .as_str(), - 16, - ) - .expect("invalid worker id") - } -} - #[derive(Debug)] pub struct SchedulingResponse { pub best_worker_id: i64, - pub overlap_blocks: u32, // Add this field - pub endpoints_changed: Option>, + pub overlap_blocks: u32, } pub struct SchedulingRequest { + pub request_id: String, + pub token_seq: Vec, pub isl_tokens: usize, pub overlaps: OverlapScores, - pub potential_blocks: HashMap, - pub potential_tokens: HashMap, - resp_tx: tokio::sync::oneshot::Sender, + pub decode_blocks: HashMap, + pub prefill_tokens: HashMap, + // Option to take it out to send the response without moving the struct + resp_tx: Option>, } impl SchedulingRequest { - pub fn respond(self, response: SchedulingResponse) { - if self.resp_tx.send(response).is_err() { - tracing::error!("failed to send response to requestor"); + pub fn respond(&mut self, response: SchedulingResponse) { + // Changed to &mut self + if let Some(tx) = self.resp_tx.take() { + // Use take() to extract the sender + if tx.send(response).is_err() { + tracing::error!("failed to send response to requestor"); + } + } else { + tracing::error!("respond called multiple times on same request"); } } } pub struct KvScheduler { request_tx: tokio::sync::mpsc::Sender, - sequences: Arc>, + slots: Arc, } impl KvScheduler { pub async fn start( - ns: Namespace, + component: Component, block_size: u32, mut instances_rx: tokio::sync::watch::Receiver>, // Changed from ProcessedEndpoints selector: Option>, + replica_sync: bool, ) -> Result { let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let mut instances: Vec = instances_rx.borrow_and_update().clone(); - // Get worker IDs from instances - let worker_ids: Vec = instances.iter().map(|i| i.instance_id).collect(); - let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + let ns_clone = component.namespace().clone(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { - if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await { + if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await { tracing::warn!("Failed to 
publish KV hit rate event: {:?}", e); } } }); - let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new( + let worker_ids: Vec = instances + .iter() + .map(|instance| instance.instance_id) + .collect(); + let slots = Arc::new(ActiveSequencesMultiWorker::new( + component, block_size as usize, worker_ids, - ))); + replica_sync, + )); - // Channel to accept new scheduling requests + let slots_clone = slots.clone(); let (request_tx, request_rx) = tokio::sync::mpsc::channel::(1024); // Background task to handle scheduling requests tokio::spawn(async move { - let mut request: SchedulingRequest; let mut request_rx = request_rx; - let mut pending_endpoint_update: Option> = None; tracing::trace!("scheduler background task started"); - 'outer: loop { - request = tokio::select! { - biased; - - _ = instances_rx.changed() => { + loop { + // First, check for instance updates (non-blocking) + match instances_rx.has_changed() { + Ok(true) => { instances = instances_rx.borrow_and_update().clone(); - let worker_ids: Vec = instances.iter().map(|i| i.instance_id).collect(); - pending_endpoint_update = Some(worker_ids); - continue 'outer; + let worker_ids: Vec = instances + .iter() + .map(|instance| instance.instance_id) + .collect(); + slots_clone.update_workers(worker_ids); } - - maybe_new_request = request_rx.recv() => { - let Some(new_request) = maybe_new_request else { - tracing::warn!("scheduler shutdown"); - break 'outer; - }; - tracing::trace!("received request to be scheduled"); - new_request + Ok(false) => { + // No changes, continue. This is the happy path. } - }; + Err(_) => { + tracing::warn!("endpoint watch sender shutdown"); + break; + } + } - loop { - // When calling selector.select_worker, we need to adapt - match selector.select_worker(&instances, &request, block_size) { - Ok(selection) => { - if let Err(e) = event_tx.send(KVHitRateEvent { - worker_id: selection.worker_id, - isl_blocks: selection.required_blocks as usize, - overlap_blocks: selection.overlap_blocks, - }) { - tracing::warn!("Failed to send KV hit rate event: {:?}", e); - } - - let response = SchedulingResponse { - best_worker_id: selection.worker_id, - overlap_blocks: selection.overlap_blocks, - endpoints_changed: pending_endpoint_update.take(), - }; - request.respond(response); - continue 'outer; - } - Err(KvSchedulerError::NoEndpoints) => { - tracing::trace!("no endpoints available; waiting for endpoints update"); - instances_rx.changed().await.ok(); - instances = instances_rx.borrow_and_update().clone(); - let worker_ids: Vec = - instances.iter().map(|i| i.instance_id).collect(); - pending_endpoint_update = Some(worker_ids); - continue; - } - // TODO: this is not actually hooked up - Err(KvSchedulerError::AllWorkersBusy) => { - tracing::trace!("all workers busy; waiting for more capacity"); - tokio::time::sleep(Duration::from_millis(5)).await; - continue; - } - Err(e) => { - tracing::error!("error scheduling request: {:?}", e); - break 'outer; + // Then, wait for a new request + let Some(mut request) = request_rx.recv().await else { + tracing::warn!("scheduler shutdown"); + break; + }; + tracing::trace!("received request to be scheduled"); + + let (decode_blocks, prefill_tokens) = slots_clone + .potential_blocks_and_tokens( + request.token_seq.clone(), + request.isl_tokens, + request.overlaps.clone(), + ) + .await; + request.decode_blocks = decode_blocks; + request.prefill_tokens = prefill_tokens; + + match selector.select_worker(&instances, &request, block_size) { + Ok(selection) => { + if let Err(e) = 
event_tx.send(KVHitRateEvent { + worker_id: selection.worker_id, + isl_blocks: selection.required_blocks as usize, + overlap_blocks: selection.overlap_blocks, + }) { + tracing::warn!("Failed to send KV hit rate event: {:?}", e); } + + let response = SchedulingResponse { + best_worker_id: selection.worker_id, + overlap_blocks: selection.overlap_blocks, + }; + request.respond(response); + + let _ = slots_clone + .add_request( + request.request_id, + request.token_seq, + request.isl_tokens, + selection.overlap_blocks, + selection.worker_id, + ) + .await; + + continue; + } + Err(KvSchedulerError::NoEndpoints) => { + tracing::trace!("no endpoints available; waiting for endpoints update"); + tokio::time::sleep(Duration::from_millis(5)).await; + continue; + } + // TODO: this is not actually hooked up + Err(KvSchedulerError::AllWorkersBusy) => { + tracing::trace!("all workers busy; waiting for more capacity"); + tokio::time::sleep(Duration::from_millis(5)).await; + continue; + } + Err(e) => { + tracing::error!("error scheduling request: {:?}", e); + break; } } } @@ -207,10 +199,7 @@ impl KvScheduler { tracing::trace!("background endpoint subscriber shutting down"); }); - Ok(KvScheduler { - request_tx, - sequences, - }) + Ok(KvScheduler { request_tx, slots }) } pub async fn schedule( @@ -220,19 +209,17 @@ impl KvScheduler { token_seq: Vec, overlaps: OverlapScores, ) -> Result { - let mut sequences = self.sequences.lock().await; - - let (potential_blocks, potential_tokens) = - sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone()); - let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { + request_id, + token_seq, isl_tokens, overlaps, - potential_blocks, - potential_tokens, - resp_tx, + decode_blocks: HashMap::new(), + prefill_tokens: HashMap::new(), + resp_tx: Some(resp_tx), // Wrap in Some() }; + self.request_tx .send(request) .await @@ -241,30 +228,19 @@ impl KvScheduler { .await .map_err(|_| KvSchedulerError::SubscriberShutdown)?; - if let Some(new_worker_ids) = response.endpoints_changed { - sequences.update_workers(new_worker_ids); - } - - sequences.add_request( - request_id, - token_seq, - isl_tokens, - response.overlap_blocks, - response.best_worker_id, - ); - - Ok(response.best_worker_id) + let best_worker_id = response.best_worker_id; + Ok(best_worker_id) } - pub async fn mark_prefill_completed(&self, request_id: &String) { - let mut sequences = self.sequences.lock().await; - sequences.mark_prefill_completed(request_id) + pub async fn mark_prefill_completed(&self, request_id: &str) { + let _ = self + .slots + .mark_prefill_completed(&request_id.to_string()) + .await; } - /// Free all blocks associated with a request - pub async fn free(&self, request_id: &String) { - let mut sequences = self.sequences.lock().await; - sequences.free(request_id) + pub async fn free(&self, request_id: &str) { + let _ = self.slots.free(&request_id.to_string()).await; } } @@ -307,8 +283,9 @@ fn softmax_sample(logits: &HashMap, temperature: f64) -> i64 { let normalized: Vec<_> = values .iter() .map(|&v| { - let norm = v / (max_val - min_val); // Lower is better, so negate + // Note we don't need to do actual min-max norm here, just off by an offset + let norm = v / (max_val - min_val); -norm }) .collect(); @@ -370,10 +347,8 @@ impl WorkerSelector for DefaultWorkerSelector { let request_blocks = isl.div_ceil(block_size as usize); let overlaps = &request.overlaps.scores; - // active blocks for decoding - let potential_active_blocks = 
&request.potential_blocks; - // active tokens in the batch (processed by the linear layers), mostly prefill tokens - let potential_active_tokens = &request.potential_tokens; + let decode_blocks = &request.decode_blocks; + let prefill_tokens = &request.prefill_tokens; let mut worker_logits = HashMap::new(); let mut max_logit = f64::NEG_INFINITY; @@ -381,52 +356,40 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate logits for each worker for instance in workers.iter() { let worker_id = instance.instance_id; - // this is the number of tokens each worker would have if the request were scheduled there - let potential_tokens = *potential_active_tokens.get(&worker_id).unwrap_or_else(|| { - tracing::warn!( - "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet" - ); - &isl - }) as f64; - - // this is the number of blocks each worker would have if the request were scheduled there - let potential_blocks = *potential_active_blocks.get(&worker_id).unwrap_or_else(|| - {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet"); - &request_blocks - }) as f64; - - let potential_prefill_blocks = potential_tokens / (block_size as f64); + let overlap = *overlaps.get(&worker_id).unwrap_or(&0); + + // this is the number of prefill tokens the worker would have if the request were scheduled there + let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&isl); + let potential_prefill_block = (prefill_token as f64) / (block_size as f64); + + // this is the number of decode blocks the worker would have if the request were scheduled there + let decode_block = *decode_blocks + .get(&worker_id) + .unwrap_or(&(potential_prefill_block.floor() as usize)) + as f64; // Calculate logit (lower is better) - let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks - + potential_blocks; + let logit = + self.kv_router_config.overlap_score_weight * potential_prefill_block + decode_block; max_logit = max_logit.max(logit); worker_logits.insert(worker_id, logit); + let overlap_weight = self.kv_router_config.overlap_score_weight; tracing::info!( - "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})", - self.kv_router_config.overlap_score_weight, - overlaps.get(&worker_id).unwrap_or(&0), + "Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \ + = {overlap_weight:.1} * prefill_blocks + decode_blocks \ + = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}" ); } - // Normalize by dividing by max value - if max_logit > 0.0 { - for logit in worker_logits.values_mut() { - *logit /= max_logit; - } - } - // Use softmax sampling to select worker let temperature = self.kv_router_config.router_temperature; let best_worker_id = softmax_sample(&worker_logits, temperature); - - let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0); let best_logit = worker_logits[&best_worker_id]; tracing::info!( - "Selected worker: {}, normalized logit: {:.3}", + "Selected worker: {}, logit: {:.3}", best_worker_id, best_logit ); @@ -434,7 +397,7 @@ impl WorkerSelector for DefaultWorkerSelector { Ok(WorkerSelectionResult { worker_id: best_worker_id, required_blocks: request_blocks as u64, - overlap_blocks, + overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0), }) } } diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index c065a613ba..5934716c7c 100644 --- 
a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -15,10 +15,39 @@ //! Scoring functions for the KV router. +use super::protocols::{ForwardPassMetrics, LoadMetrics}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::kv_router::scheduler::Endpoint; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct LoadEvent { + pub worker_id: i64, + pub data: ForwardPassMetrics, +} + +/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' +/// is cleaned (not optional) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Endpoint { + pub name: String, + pub subject: String, + pub data: LoadMetrics, +} + +impl Endpoint { + pub fn worker_id(&self) -> i64 { + i64::from_str_radix( + self.subject + .split("-") + .last() + .expect("invalid subject") + .to_string() + .as_str(), + 16, + ) + .expect("invalid worker id") + } +} #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] pub struct ProcessedEndpoints { diff --git a/lib/llm/src/kv_router/sequence.rs b/lib/llm/src/kv_router/sequence.rs index 060eaf1ed3..86d3ca66a5 100644 --- a/lib/llm/src/kv_router/sequence.rs +++ b/lib/llm/src/kv_router/sequence.rs @@ -1,17 +1,5 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. //! KV Cache Sequence Management for LLM Inference //! 
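
The sequence.rs hunks that follow publish and subscribe to per-request events on `ACTIVE_SEQUENCES_SUBJECT` so that router replicas can mirror each other's active-sequence bookkeeping. For orientation, here is a hedged reconstruction of the event payload, inferred only from how this diff builds and matches on it; the authoritative definitions live in `kv_router/protocols.rs`, and the `RequestId`, `WorkerId`, and `SequenceHash` aliases shown here are assumptions:

```rust
use serde::{Deserialize, Serialize};
use uuid::Uuid;

// Assumed aliases; the real ones come from crate::tokens and kv_router::indexer.
type RequestId = String;
type WorkerId = i64;
type SequenceHash = u64;

/// One replica-sync event, published on ACTIVE_SEQUENCES_SUBJECT (sketch, not the PR's definition).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActiveSequenceEvent {
    pub request_id: RequestId,
    pub worker_id: WorkerId,
    pub data: ActiveSequenceEventData,
    /// Identifies the emitting router so a replica can skip events it published itself.
    pub router_id: Uuid,
}

/// The three state transitions a router shares with its peers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActiveSequenceEventData {
    AddRequest {
        token_sequence: Vec<SequenceHash>,
        isl: usize,
        overlap: u32,
    },
    MarkPrefillCompleted,
    Free,
}
```

Each replica applies `AddRequest`, `MarkPrefillCompleted`, and `Free` from its peers to the same per-worker `ActiveSequences` state it keeps for its own traffic, which is why `subscribe_to_events` below filters out events whose `router_id` matches its own.
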
@@ -37,11 +25,20 @@ use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::WorkerId; use crate::tokens::SequenceHash; +use anyhow::Result; +use dashmap::DashMap; use derive_getters::Getters; +use dynamo_runtime::component::Component; +use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; +use dynamo_runtime::traits::DistributedRuntimeProvider; +use futures::StreamExt; use std::collections::{HashMap, HashSet}; -use std::sync::{mpsc, Arc}; -use std::thread; -use std::time::Duration; +use std::sync::Arc; +use uuid::Uuid; + +use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData}; +use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT; +use dynamo_runtime::CancellationToken; // TODO: use the common request_id if it exists in the repo pub type RequestId = String; @@ -134,7 +131,7 @@ impl ActiveSequences { if let Some(tokens) = self.prefill_tokens.remove(request_id) { self.active_tokens = self .active_tokens - .checked_sub(tokens.saturating_sub(1)) // Keep 1 token for decoding + .checked_sub(tokens) .expect("active_tokens underflow"); } } @@ -171,11 +168,7 @@ impl ActiveSequences { /// Free all blocks associated with a request pub fn free(&mut self, request_id: &RequestId) -> usize { - // decoding has one active token - self.active_tokens = self - .active_tokens - .checked_sub(self.prefill_tokens.remove(request_id).unwrap_or(1)) - .expect("active_tokens < 0"); + self.mark_prefill_completed(request_id); let Some(token_seq) = self.active_seqs.get(request_id) else { tracing::warn!("Trying to free free non-existent request {request_id}"); @@ -207,127 +200,261 @@ enum UpdateSequences { }, NewBlocks { token_sequence: Arc>, - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, PotentialBlocks { token_sequence: Arc>, - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, PotentialBlocksAndTokens { token_sequence: Arc>, isl: usize, overlap: u32, - resp_tx: mpsc::SyncSender<(usize, usize)>, + resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>, }, ActiveBlocks { - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, ActiveTokens { - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, Shutdown, } /// Multi-worker extension of ActiveSequences that distributes requests across multiple threads pub struct ActiveSequencesMultiWorker { - senders: HashMap>, - request_to_worker: HashMap, - handles: HashMap>, + senders: Arc>>, + request_to_worker: Arc>, + handles: Arc>>, block_size: usize, + component: Component, + router_id: Uuid, + replica_sync: bool, } impl ActiveSequencesMultiWorker { - pub fn new(block_size: usize, worker_ids: Vec) -> Self { + pub fn new( + component: Component, + block_size: usize, + worker_ids: Vec, + replica_sync: bool, + ) -> Self { assert!(block_size > 1, "block_size must be greater than 1"); - let mut senders = HashMap::new(); - let mut handles = HashMap::new(); + let senders = Arc::new(DashMap::new()); + let handles = Arc::new(DashMap::new()); + let request_to_worker = Arc::new(DashMap::new()); + let router_id = Uuid::new_v4(); for worker_id in worker_ids { - let (sender, handle) = Self::start_worker(block_size); + // Create a child cancellation token from the component's runtime + let cancel_token = component.drt().runtime().child_token(); + let (sender, handle) = Self::start_worker(block_size, cancel_token); senders.insert(worker_id, sender); handles.insert(worker_id, handle); } - Self { - senders, - request_to_worker: HashMap::new(), + let multi_worker = Self { + 
senders: senders.clone(), + request_to_worker: request_to_worker.clone(), handles, block_size, + component: component.clone(), + router_id, + replica_sync, + }; + + // Start the subscription loop only if replica_sync is enabled + if replica_sync { + let senders_clone = senders.clone(); + let request_to_worker_clone = request_to_worker.clone(); + let component_clone = component.clone(); + let router_id_clone = router_id; + + tokio::spawn(async move { + if let Err(e) = Self::subscribe_to_events( + senders_clone, + request_to_worker_clone, + component_clone, + router_id_clone, + ) + .await + { + tracing::error!("Error in active sequences events subscription: {}", e); + } + }); } + + multi_worker } - /// Helper method to start a worker thread - fn start_worker(block_size: usize) -> (mpsc::Sender, thread::JoinHandle<()>) { - let (request_tx, request_rx) = mpsc::channel::(); + /// Helper method to start a worker task + fn start_worker( + block_size: usize, + cancel_token: CancellationToken, // Add cancellation token parameter + ) -> ( + tokio::sync::mpsc::UnboundedSender, + tokio::task::JoinHandle<()>, + ) { + let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel(); - let handle = thread::spawn(move || { + let handle = tokio::spawn(async move { let mut active_sequences = ActiveSequences::new(block_size); - while let Ok(command) = request_rx.recv() { - match command { - UpdateSequences::AddRequest { - request_id, - token_sequence, - isl, - overlap, - } => { - active_sequences.add_request(request_id, token_sequence, isl, overlap); - } - UpdateSequences::Free { request_id } => { - active_sequences.free(&request_id); - } - UpdateSequences::MarkPrefillCompleted { request_id } => { - active_sequences.mark_prefill_completed(&request_id); + loop { + tokio::select! 
{ + // Handle incoming commands + command = request_rx.recv() => { + match command { + Some(command) => { + match command { + UpdateSequences::AddRequest { + request_id, + token_sequence, + isl, + overlap, + } => { + active_sequences.add_request(request_id, token_sequence, isl, overlap); + } + UpdateSequences::Free { request_id } => { + active_sequences.free(&request_id); + } + UpdateSequences::MarkPrefillCompleted { request_id } => { + active_sequences.mark_prefill_completed(&request_id); + } + UpdateSequences::NewBlocks { + token_sequence, + resp_tx, + } => { + let new_blocks = active_sequences.new_blocks(&token_sequence); + let _ = resp_tx.send(new_blocks); + } + UpdateSequences::PotentialBlocks { + token_sequence, + resp_tx, + } => { + let potential_blocks = active_sequences.potential_blocks(&token_sequence); + let _ = resp_tx.send(potential_blocks); + } + UpdateSequences::PotentialBlocksAndTokens { + token_sequence, + isl, + overlap, + resp_tx, + } => { + let potential_tokens = active_sequences.potential_blocks_and_tokens( + &token_sequence, + isl, + overlap, + ); + let _ = resp_tx.send(potential_tokens); + } + UpdateSequences::ActiveBlocks { resp_tx } => { + let active_blocks = active_sequences.active_blocks(); + let _ = resp_tx.send(active_blocks); + } + UpdateSequences::ActiveTokens { resp_tx } => { + let active_tokens = active_sequences.active_tokens(); + let _ = resp_tx.send(active_tokens); + } + UpdateSequences::Shutdown => { + break; + } + } + } + None => { + // Channel closed, exit + break; + } + } } - UpdateSequences::NewBlocks { - token_sequence, - resp_tx, - } => { - let new_blocks = active_sequences.new_blocks(&token_sequence); - let _ = resp_tx.send(new_blocks); - } - UpdateSequences::PotentialBlocks { - token_sequence, - resp_tx, - } => { - let potential_blocks = active_sequences.potential_blocks(&token_sequence); - let _ = resp_tx.send(potential_blocks); + // Handle cancellation + _ = cancel_token.cancelled() => { + tracing::debug!("Worker task cancelled"); + break; } - UpdateSequences::PotentialBlocksAndTokens { - token_sequence, - isl, - overlap, - resp_tx, - } => { - let potential_tokens = active_sequences.potential_blocks_and_tokens( - &token_sequence, - isl, - overlap, + } + } + }); + + (request_tx, handle) + } + + /// Background task to subscribe to active sequence events and update all workers + async fn subscribe_to_events( + senders: Arc>>, + request_to_worker: Arc>, + component: Component, + router_id: Uuid, + ) -> Result<()> { + let mut subscriber = component + .subscribe_with_type::(ACTIVE_SEQUENCES_SUBJECT) + .await?; + + while let Some(result) = subscriber.next().await { + let Ok(event) = result else { + tracing::error!( + "Error receiving active sequence event: {}", + result.unwrap_err() + ); + continue; + }; + + // Skip events emitted by itself + if event.router_id == router_id { + continue; + } + + match &event.data { + ActiveSequenceEventData::AddRequest { + token_sequence, + isl, + overlap, + } => { + request_to_worker.insert(event.request_id.clone(), event.worker_id); + + if let Some(sender) = senders.get(&event.worker_id) { + let _ = sender.send(UpdateSequences::AddRequest { + request_id: event.request_id.clone(), + token_sequence: token_sequence.clone(), + isl: *isl, + overlap: *overlap, + }); + } else { + tracing::warn!( + "Worker {} not found, cannot process AddRequest", + event.worker_id ); - let _ = resp_tx.send(potential_tokens); } - UpdateSequences::ActiveBlocks { resp_tx } => { - let active_blocks = active_sequences.active_blocks(); - let _ 
= resp_tx.send(active_blocks); - } - UpdateSequences::ActiveTokens { resp_tx } => { - let active_tokens = active_sequences.active_tokens(); - let _ = resp_tx.send(active_tokens); + } + ActiveSequenceEventData::Free => { + if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id) { + if let Some(sender) = senders.get(&worker_id) { + let _ = sender.send(UpdateSequences::Free { + request_id: event.request_id.clone(), + }); + } } - UpdateSequences::Shutdown => { - break; + } + ActiveSequenceEventData::MarkPrefillCompleted => { + if let Some(worker_id) = request_to_worker.get(&event.request_id) { + if let Some(sender) = senders.get(&*worker_id) { + let _ = sender.send(UpdateSequences::MarkPrefillCompleted { + request_id: event.request_id.clone(), + }); + } } } } - }); + } - (request_tx, handle) + Ok(()) } /// Update the set of workers, adding and removing as needed - pub fn update_workers(&mut self, new_worker_ids: Vec) -> HashMap { - let current_workers: HashSet = self.senders.keys().copied().collect(); + pub fn update_workers(&self, new_worker_ids: Vec) { + let current_workers: HashSet = + self.senders.iter().map(|entry| *entry.key()).collect(); let new_workers: HashSet = new_worker_ids.into_iter().collect(); let workers_to_remove: Vec = @@ -340,11 +467,11 @@ impl ActiveSequencesMultiWorker { tracing::warn!("Removing worker {}", worker_id); // Send shutdown command to the worker - if let Some(sender) = self.senders.remove(worker_id) { + if let Some((_, sender)) = self.senders.remove(worker_id) { let _ = sender.send(UpdateSequences::Shutdown); } - if let Some(handle) = self.handles.remove(worker_id) { - let _ = handle.join(); + if let Some((_, handle)) = self.handles.remove(worker_id) { + handle.abort(); } } @@ -352,68 +479,128 @@ impl ActiveSequencesMultiWorker { for worker_id in &workers_to_add { tracing::warn!("Adding worker {}", worker_id); - let (sender, handle) = Self::start_worker(self.block_size); + let (sender, handle) = Self::start_worker( + self.block_size, + self.component.drt().runtime().child_token(), + ); self.senders.insert(*worker_id, sender); self.handles.insert(*worker_id, handle); } - - // Return active blocks for all workers - self.active_blocks() } - pub fn add_request( - &mut self, + pub async fn add_request( + &self, request_id: RequestId, token_sequence: Vec, isl: usize, overlap: u32, worker_id: WorkerId, - ) { + ) -> Result<()> { if !self.senders.contains_key(&worker_id) { - panic!("Worker ID {worker_id} not found"); + return Err(anyhow::anyhow!("Worker ID {worker_id} not found")); + } + + // Publish event only if replica_sync is enabled + if self.replica_sync { + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::AddRequest { + token_sequence: token_sequence.clone(), + isl, + overlap, + }, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; } + // Update local state self.request_to_worker.insert(request_id.clone(), worker_id); - self.senders[&worker_id] + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::AddRequest { request_id, token_sequence, isl, overlap, }) - .expect("Failed to send add_request command to worker"); + .map_err(|_| anyhow::anyhow!("Failed to send add_request command to worker"))?; + + Ok(()) } - pub fn free(&mut self, request_id: &RequestId) { + pub async fn free(&self, request_id: &RequestId) -> Result<()> { let worker_id = self .request_to_worker .get(request_id) - .copied() - .expect("Request ID 
not found in request_to_worker mapping"); + .map(|entry| *entry) + .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?; - self.senders[&worker_id] + // Publish event only if replica_sync is enabled + if self.replica_sync { + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::Free, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + } + + // Update local state + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::Free { request_id: request_id.clone(), }) - .expect("Failed to send free command to worker"); + .map_err(|_| anyhow::anyhow!("Failed to send free command to worker"))?; self.request_to_worker.remove(request_id); + + Ok(()) } /// Mark prefill as completed for a request - pub fn mark_prefill_completed(&mut self, request_id: &RequestId) { + pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> { let worker_id = self .request_to_worker .get(request_id) - .copied() - .expect("Request ID not found in request_to_worker mapping"); + .map(|entry| *entry) + .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?; + + // Publish event only if replica_sync is enabled + if self.replica_sync { + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::MarkPrefillCompleted, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + } - self.senders[&worker_id] + // Update local state + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::MarkPrefillCompleted { request_id: request_id.clone(), }) - .expect("Failed to send mark_prefill_completed command to worker"); + .map_err(|_| { + anyhow::anyhow!("Failed to send mark_prefill_completed command to worker") + })?; + + Ok(()) } /// Get the number of workers @@ -422,37 +609,49 @@ impl ActiveSequencesMultiWorker { } /// Generic method to query all workers with a given command - fn query_workers( + async fn query_workers( &self, token_sequence: Option>, - command_fn: impl Fn(Option>>, mpsc::SyncSender) -> UpdateSequences, - ) -> HashMap { + command_fn: impl Fn( + Option>>, + tokio::sync::oneshot::Sender, + ) -> UpdateSequences, + ) -> HashMap { let mut results = HashMap::new(); let token_sequence_shared = token_sequence.map(Arc::new); let mut receivers = Vec::new(); // Send queries to all workers in parallel - for (worker_id, sender) in &self.senders { - let (resp_tx, resp_rx) = mpsc::sync_channel(0); + for entry in self.senders.iter() { + let worker_id = *entry.key(); + let sender = entry.value(); + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); receivers.push((worker_id, resp_rx)); - sender - .send(command_fn(token_sequence_shared.clone(), resp_tx)) - .expect("Failed to send command to worker"); + if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) { + tracing::error!("Failed to send command to worker {}: {}", worker_id, e); + } } // Collect results from all workers for (worker_id, receiver) in receivers { - let result = receiver - .recv_timeout(Duration::from_secs(1)) - .expect("Failed to receive response from worker"); - results.insert(*worker_id, result); + match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { + Ok(Ok(result)) => { + results.insert(worker_id, result); + } + Ok(Err(_)) => { + tracing::error!("Worker {} dropped response 
channel", worker_id); + } + Err(_) => { + tracing::error!("Timeout waiting for response from worker {}", worker_id); + } + } } results } /// Query all workers for the number of new blocks that would be added by a token sequence - pub fn new_blocks(&self, token_sequence: Vec) -> HashMap { + pub async fn new_blocks(&self, token_sequence: Vec) -> HashMap { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { Some(ts) => UpdateSequences::NewBlocks { token_sequence: ts, @@ -460,10 +659,14 @@ impl ActiveSequencesMultiWorker { }, None => unreachable!("token_sequence should always be Some for new_blocks"), }) + .await } /// Query all workers for the total number of blocks (new + active) that would be used by a token sequence - pub fn potential_blocks(&self, token_sequence: Vec) -> HashMap { + pub async fn potential_blocks( + &self, + token_sequence: Vec, + ) -> HashMap { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { Some(ts) => UpdateSequences::PotentialBlocks { token_sequence: ts, @@ -471,10 +674,11 @@ impl ActiveSequencesMultiWorker { }, None => unreachable!("token_sequence should always be Some for potential_blocks"), }) + .await } /// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap - pub fn potential_blocks_and_tokens( + pub async fn potential_blocks_and_tokens( &self, token_sequence: Vec, isl: usize, @@ -486,53 +690,68 @@ impl ActiveSequencesMultiWorker { let mut receivers = Vec::new(); // Send queries to all workers in parallel - for (worker_id, sender) in &self.senders { - let (resp_tx, resp_rx) = mpsc::sync_channel(0); + for entry in self.senders.iter() { + let worker_id = *entry.key(); + let sender = entry.value(); + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); receivers.push((worker_id, resp_rx)); - sender - .send(UpdateSequences::PotentialBlocksAndTokens { - token_sequence: token_sequence_shared.clone(), - isl, - overlap: overlaps.scores.get(worker_id).copied().unwrap_or(0), - resp_tx, - }) - .expect("Failed to send potential_tokens command to worker"); + if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens { + token_sequence: token_sequence_shared.clone(), + isl, + overlap: overlaps.scores.get(&worker_id).copied().unwrap_or(0), + resp_tx, + }) { + tracing::error!( + "Failed to send potential_tokens command to worker {}: {}", + worker_id, + e + ); + } } // Collect results from all workers for (worker_id, receiver) in receivers { - let (blocks, tokens) = receiver - .recv_timeout(Duration::from_secs(1)) - .expect("Failed to receive response from worker"); - potential_blocks.insert(*worker_id, blocks); - potential_tokens.insert(*worker_id, tokens); + match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { + Ok(Ok((blocks, tokens))) => { + potential_blocks.insert(worker_id, blocks); + potential_tokens.insert(worker_id, tokens); + } + Ok(Err(_)) => { + tracing::error!("Worker {} dropped response channel", worker_id); + } + Err(_) => { + tracing::error!("Timeout waiting for response from worker {}", worker_id); + } + } } (potential_blocks, potential_tokens) } /// Query all workers for their current number of active blocks - pub fn active_blocks(&self) -> HashMap { + pub async fn active_blocks(&self) -> HashMap { self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx }) + .await } /// Query all workers for their current number of active tokens - pub fn active_tokens(&self) -> HashMap { + pub async fn active_tokens(&self) 
-> HashMap { self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx }) + .await } } impl Drop for ActiveSequencesMultiWorker { fn drop(&mut self) { - // Send shutdown command to all workers - for sender in self.senders.values() { - let _ = sender.send(UpdateSequences::Shutdown); + // Send shutdown to all workers + for entry in self.senders.iter() { + let _ = entry.value().send(UpdateSequences::Shutdown); } - // Wait for all threads to finish - for (_, handle) in self.handles.drain() { - let _ = handle.join(); + // Abort all tasks + for entry in self.handles.iter() { + entry.value().abort(); } } } @@ -540,44 +759,248 @@ impl Drop for ActiveSequencesMultiWorker { #[cfg(test)] mod tests { use super::*; + use dynamo_runtime::{DistributedRuntime, Runtime}; + use std::sync::{Arc, Mutex}; + use std::thread; #[test] - fn test_multi_worker_block_sharing() { - // Create multi-worker sequence manager with 3 workers + #[ignore] + fn test_multi_worker_block_sharing() -> Result<()> { + // Initialize logging once + dynamo_runtime::logging::init(); + let block_size = 4; // arbitrary block size - let worker_ids = vec![0, 1, 2]; - let mut seq_manager = ActiveSequencesMultiWorker::new(block_size, worker_ids); - - // Add requests to each worker - // Worker 0: sequence [0, 1, 2] - seq_manager.add_request( - "request_0".to_string(), - vec![0, 1, 2], - 12, // ISL (3 blocks * 4 block_size) - 0, // no overlap - 0, // worker_id - ); - // Worker 1: sequence [3, 4] - seq_manager.add_request( - "request_1".to_string(), - vec![3, 4], - 8, // ISL (2 blocks * 4 block_size) - 0, // no overlap - 1, // worker_id - ); + // Shared state for collecting results from both threads + let active_tokens_after_add = Arc::new(Mutex::new(HashMap::new())); + let potential_blocks_result = Arc::new(Mutex::new(HashMap::new())); + let active_blocks_after_free = Arc::new(Mutex::new(HashMap::new())); + let active_tokens_after_free = Arc::new(Mutex::new(HashMap::new())); + + let active_tokens_after_add_clone = active_tokens_after_add.clone(); + let potential_blocks_result_clone = potential_blocks_result.clone(); + let active_blocks_after_free_clone = active_blocks_after_free.clone(); + let active_tokens_after_free_clone = active_tokens_after_free.clone(); + + // Clone again for the second thread + let active_tokens_after_add_clone2 = active_tokens_after_add.clone(); + let potential_blocks_result_clone2 = potential_blocks_result.clone(); + let active_blocks_after_free_clone2 = active_blocks_after_free.clone(); + let active_tokens_after_free_clone2 = active_tokens_after_free.clone(); + + // Thread 1: First runtime with workers 0 and 1 + let handle1 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and component with same names as thread 2 + let namespace = distributed.namespace("test_multiworker_sequences")?; + let component = namespace + .component("sequences")? 
+ .service_builder() + .create() + .await?; + + // Create multi-worker sequence manager with workers 0 and 1 + let worker_ids = vec![0, 1]; + let seq_manager = + ActiveSequencesMultiWorker::new(component, block_size, worker_ids, true); + + // Give some time for the subscription loop to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Add requests to workers + // Worker 0: sequence [0, 1, 2] + seq_manager + .add_request( + "request_0".to_string(), + vec![0, 1, 2], + 12, // ISL (3 blocks * 4 block_size) + 0, // no overlap + 0, // worker_id + ) + .await?; + + // Worker 1: sequence [3, 4] + seq_manager + .add_request( + "request_1".to_string(), + vec![3, 4], + 8, // ISL (2 blocks * 4 block_size) + 0, // no overlap + 1, // worker_id + ) + .await?; + + // Give some time for the commands to be processed and synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get active tokens from workers 0 and 1 + let tokens = seq_manager.active_tokens().await; + active_tokens_after_add_clone + .lock() + .unwrap() + .insert(0, tokens.get(&0).copied().unwrap_or(0)); + active_tokens_after_add_clone + .lock() + .unwrap() + .insert(1, tokens.get(&1).copied().unwrap_or(0)); + + // Test potential blocks for sequence [0, 1] + let potential = seq_manager.potential_blocks(vec![0, 1]).await; + potential_blocks_result_clone + .lock() + .unwrap() + .insert(0, potential.get(&0).copied().unwrap_or(0)); + potential_blocks_result_clone + .lock() + .unwrap() + .insert(1, potential.get(&1).copied().unwrap_or(0)); + + // Wait for second thread to process its requests + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Free requests from workers 0 and 1 + seq_manager.free(&"request_0".to_string()).await?; + seq_manager.free(&"request_1".to_string()).await?; + + // Give some time for the commands to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get final active blocks and tokens + let blocks = seq_manager.active_blocks().await; + let tokens = seq_manager.active_tokens().await; + + active_blocks_after_free_clone + .lock() + .unwrap() + .insert(0, blocks.get(&0).copied().unwrap_or(0)); + active_blocks_after_free_clone + .lock() + .unwrap() + .insert(1, blocks.get(&1).copied().unwrap_or(0)); + active_tokens_after_free_clone + .lock() + .unwrap() + .insert(0, tokens.get(&0).copied().unwrap_or(0)); + active_tokens_after_free_clone + .lock() + .unwrap() + .insert(1, tokens.get(&1).copied().unwrap_or(0)); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); - // Worker 2: sequence [0, 1, 2, 3] - seq_manager.add_request( - "request_2".to_string(), - vec![0, 1, 2, 3], - 16, // ISL (4 blocks * 4 block_size) - 0, // no overlap - 2, // worker_id - ); + // Thread 2: Second runtime with worker 2 + let handle2 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and component with same names as thread 1 + let namespace = distributed.namespace("test_multiworker_sequences")?; + let component = namespace + .component("sequences")? 
+ .service_builder() + .create() + .await?; + + // Create multi-worker sequence manager with worker 2 + let worker_ids = vec![2]; + let seq_manager = + ActiveSequencesMultiWorker::new(component, block_size, worker_ids, true); + + // Give some time for the subscription loop to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Wait a bit to ensure thread 1 has started + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Worker 2: sequence [0, 1, 2, 3] + seq_manager + .add_request( + "request_2".to_string(), + vec![0, 1, 2, 3], + 16, // ISL (4 blocks * 4 block_size) + 0, // no overlap + 2, // worker_id + ) + .await?; + + // Give some time for the commands to be processed and synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get active tokens from worker 2 + let tokens = seq_manager.active_tokens().await; + active_tokens_after_add_clone2 + .lock() + .unwrap() + .insert(2, tokens.get(&2).copied().unwrap_or(0)); + + // Test potential blocks for sequence [0, 1] + let potential = seq_manager.potential_blocks(vec![0, 1]).await; + potential_blocks_result_clone2 + .lock() + .unwrap() + .insert(2, potential.get(&2).copied().unwrap_or(0)); + + // Wait for first thread to free its requests + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Free request from worker 2 + seq_manager.free(&"request_2".to_string()).await?; + + // Give some time for the commands to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get final active blocks and tokens + let blocks = seq_manager.active_blocks().await; + let tokens = seq_manager.active_tokens().await; + + active_blocks_after_free_clone2 + .lock() + .unwrap() + .insert(2, blocks.get(&2).copied().unwrap_or(0)); + active_tokens_after_free_clone2 + .lock() + .unwrap() + .insert(2, tokens.get(&2).copied().unwrap_or(0)); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Wait for both threads to complete + handle1.join().unwrap()?; + handle2.join().unwrap()?; + + // Extract results + let tokens_after_add = active_tokens_after_add.lock().unwrap(); + let potential_blocks = potential_blocks_result.lock().unwrap(); + let blocks_after_free = active_blocks_after_free.lock().unwrap(); + let tokens_after_free = active_tokens_after_free.lock().unwrap(); // Verify active tokens after adding requests - let tokens_after_add = seq_manager.active_tokens(); assert_eq!( tokens_after_add[&0], 12, "Worker 0 should have 12 active tokens" @@ -592,8 +1015,6 @@ mod tests { ); // Test potential blocks for sequence [0, 1] - let potential_blocks = seq_manager.potential_blocks(vec![0, 1]); - // Worker 0 should return 3 (already has blocks 0, 1, 2, so no new blocks needed for [0, 1]) assert_eq!( potential_blocks[&0], 3, @@ -612,30 +1033,34 @@ mod tests { "Worker 2 should have 4 potential blocks" ); - // Free all original requests - seq_manager.free(&"request_0".to_string()); - seq_manager.free(&"request_1".to_string()); - seq_manager.free(&"request_2".to_string()); - // Verify active blocks are zero for all workers - let active_blocks = seq_manager.active_blocks(); - assert_eq!(active_blocks[&0], 0, "Worker 0 should have 0 active blocks"); - assert_eq!(active_blocks[&1], 0, "Worker 1 should have 0 active blocks"); - assert_eq!(active_blocks[&2], 0, "Worker 2 should 
have 0 active blocks"); + assert_eq!( + blocks_after_free[&0], 0, + "Worker 0 should have 0 active blocks" + ); + assert_eq!( + blocks_after_free[&1], 0, + "Worker 1 should have 0 active blocks" + ); + assert_eq!( + blocks_after_free[&2], 0, + "Worker 2 should have 0 active blocks" + ); // Verify active tokens are zero for all workers - let final_tokens = seq_manager.active_tokens(); assert_eq!( - final_tokens[&0], 0, + tokens_after_free[&0], 0, "Worker 0 should have 0 active tokens after freeing all" ); assert_eq!( - final_tokens[&1], 0, + tokens_after_free[&1], 0, "Worker 1 should have 0 active tokens after freeing all" ); assert_eq!( - final_tokens[&2], 0, + tokens_after_free[&2], 0, "Worker 2 should have 0 active tokens after freeing all" ); + + Ok(()) } }
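
As a closing reference for the selector change earlier in this diff: the per-worker cost is now `logit = overlap_score_weight * potential_prefill_blocks + decode_blocks`, lower being better; at temperature 0 the router takes the arg-min, otherwise it softmax-samples over the negated logits. Below is a minimal standalone sketch of the temperature-0 path with made-up numbers; it is not the PR's `DefaultWorkerSelector` code.

```rust
use std::collections::HashMap;

/// Pick the worker with the lowest cost logit (the deterministic, temperature-0 path).
/// `prefill_blocks` and `decode_blocks` are the per-worker "what if we routed here"
/// estimates that the scheduler obtains via potential_blocks_and_tokens.
fn pick_worker_deterministic(
    prefill_blocks: &HashMap<i64, f64>,
    decode_blocks: &HashMap<i64, f64>,
    overlap_score_weight: f64,
) -> i64 {
    prefill_blocks
        .iter()
        .map(|(worker, prefill)| {
            // Fall back to the prefill estimate if a worker has no decode entry yet.
            let decode = decode_blocks.get(worker).copied().unwrap_or(*prefill);
            (*worker, overlap_score_weight * prefill + decode)
        })
        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(worker, _)| worker)
        .expect("at least one worker")
}

fn main() {
    // Worker 0 has most of the prompt cached (1 prefill block left); worker 1 has none (4 blocks).
    let prefill = HashMap::from([(0_i64, 1.0), (1_i64, 4.0)]);
    let decode = HashMap::from([(0_i64, 5.0), (1_i64, 3.0)]);
    // Worker 0: 1.0 * 1.0 + 5.0 = 6.0; worker 1: 1.0 * 4.0 + 3.0 = 7.0 -> worker 0 wins.
    assert_eq!(pick_worker_deterministic(&prefill, &decode, 1.0), 0);
}
```
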