From 78823a8c69312b2d9eb204c9e5a58b846f706200 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Sun, 3 Aug 2025 13:33:46 -0700 Subject: [PATCH 01/27] background process sending load metrics over nats --- lib/llm/src/kv_router.rs | 1 + lib/llm/src/kv_router/publisher.rs | 81 ++++++++++++++++++++++++++++-- 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 2e5d02fe82..09fad92371 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -51,6 +51,7 @@ use dynamo_runtime::traits::events::EventSubscriber; pub const KV_EVENT_SUBJECT: &str = "kv_events"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; +pub const KV_METRICS_SUBJECT: &str = "kv_metrics"; /// A trait that users can implement to define custom selection logic pub trait WorkerSelector { diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 2f75b59c06..b0bb532e4f 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -16,7 +16,7 @@ use crate::kv_router::{ indexer::{compute_block_hash_for_seq, RouterEvent}, protocols::*, - KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, + KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT, }; use async_trait::async_trait; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; @@ -499,9 +499,11 @@ impl WorkerMetricsPublisher { pub async fn create_endpoint(&self, component: Component) -> Result<()> { let mut metrics_rx = self.rx.clone(); - let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); + let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; + self.start_nats_metrics_publishing(component.clone()); + component .endpoint(KV_METRICS_ENDPOINT) .endpoint_builder() @@ -513,13 +515,82 @@ impl WorkerMetricsPublisher { .start() .await } + + /// Starts a background task to publish metrics over NATS + /// + /// This task monitors metric changes (specifically kv_active_blocks and num_requests_waiting) + /// and publishes stable metrics to NATS after they've been unchanged for 1ms. 
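+    ///
+    /// Conceptually this is a debounce: bursts of rapid updates collapse into a
+    /// single publish once the tracked values settle. A minimal sketch of the
+    /// idea (illustrative only; `rx` and `publish` are assumed names, and the
+    /// real task below also filters on which fields changed):
+    ///
+    /// ```ignore
+    /// while rx.changed().await.is_ok() {
+    ///     let snapshot = rx.borrow_and_update().clone();
+    ///     tokio::time::sleep(Duration::from_millis(1)).await;
+    ///     if !rx.has_changed().unwrap_or(false) {
+    ///         publish(snapshot).await; // survived the 1ms quiet window
+    ///     }
+    /// }
+    /// ```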
+ fn start_nats_metrics_publishing(&self, component: Component) { + let nats_rx = self.rx.clone(); + + tokio::spawn(async move { + let mut rx = nats_rx; + let mut last_kv_active_blocks: Option = None; + let mut last_num_requests_waiting: Option = None; + let mut stable_metrics: Option> = None; + let mut stable_since: Option = None; + + loop { + match rx.changed().await { + Ok(_) => { + let metrics = rx.borrow_and_update().clone(); + + // Extract the values we care about + let current_kv_active_blocks = metrics.kv_stats.kv_active_blocks; + let current_num_requests_waiting = + metrics.worker_stats.num_requests_waiting; + + // Check if these specific metrics have changed + let has_changed = match (last_kv_active_blocks, last_num_requests_waiting) { + (Some(last_kv), Some(last_requests)) => { + last_kv != current_kv_active_blocks + || last_requests != current_num_requests_waiting + } + _ => true, // First time, consider it changed + }; + + // Metrics changed, reset stability tracking + if has_changed { + stable_metrics = Some(metrics.clone()); + stable_since = Some(tokio::time::Instant::now()); + last_kv_active_blocks = Some(current_kv_active_blocks); + last_num_requests_waiting = Some(current_num_requests_waiting); + } + + // Check if metrics have been stable for 1ms (independent of whether they just changed) + if let (Some(stable_start), Some(ref stable_m)) = + (stable_since, &stable_metrics) + { + // Stable for 1ms, publish + if stable_start.elapsed() >= tokio::time::Duration::from_millis(1) { + if let Err(e) = + component.publish(KV_METRICS_SUBJECT, &**stable_m).await + { + tracing::warn!("Failed to publish metrics over NATS: {}", e); + } + // Reset stability tracking after publishing + stable_metrics = None; + stable_since = None; + } + } + } + Err(_) => { + tracing::debug!( + "Metrics publisher sender dropped, stopping NATS background task" + ); + break; + } + } + } + }); + } } -struct KvLoadEndpoingHander { +struct KvLoadEndpointHandler { metrics_rx: tokio::sync::watch::Receiver>, } -impl KvLoadEndpoingHander { +impl KvLoadEndpointHandler { pub fn new(metrics_rx: tokio::sync::watch::Receiver>) -> Self { Self { metrics_rx } } @@ -527,7 +598,7 @@ impl KvLoadEndpoingHander { #[async_trait] impl AsyncEngine, ManyOut>, Error> - for KvLoadEndpoingHander + for KvLoadEndpointHandler { async fn generate( &self, From 5e4ba553a5929f315888ceaff561144b9f737bc5 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Sun, 3 Aug 2025 15:53:38 -0700 Subject: [PATCH 02/27] distributed prefill counter (tests passing) --- lib/llm/src/kv_router.rs | 2 + lib/llm/src/kv_router/prefill_counter.rs | 357 +++++++++++++++++++++++ lib/llm/src/kv_router/protocols.rs | 29 +- 3 files changed, 376 insertions(+), 12 deletions(-) create mode 100644 lib/llm/src/kv_router/prefill_counter.rs diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 09fad92371..3400affac4 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -20,6 +20,7 @@ use tokio::sync::Mutex; pub mod approx; pub mod indexer; pub mod metrics_aggregator; +pub mod prefill_counter; pub mod protocols; pub mod publisher; pub mod recorder; @@ -52,6 +53,7 @@ pub const KV_EVENT_SUBJECT: &str = "kv_events"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; pub const KV_METRICS_SUBJECT: &str = "kv_metrics"; +pub const PREFILL_SUBJECT: &str = "prefill_events"; /// A trait that users can implement to define custom selection logic pub trait WorkerSelector { diff --git 
a/lib/llm/src/kv_router/prefill_counter.rs b/lib/llm/src/kv_router/prefill_counter.rs new file mode 100644 index 0000000000..4c1b83f573 --- /dev/null +++ b/lib/llm/src/kv_router/prefill_counter.rs @@ -0,0 +1,357 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::Result; +use dynamo_runtime::component::{Component, Namespace}; +use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; +use futures::StreamExt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::RwLock; + +use super::protocols::{PrefillEvent, PrefillEventData}; +use crate::kv_router::PREFILL_SUBJECT; +use dashmap::DashMap; + +/// A counter that tracks pending prefill tokens for each request. +/// +/// This struct maintains a local hashmap of request_id to token count, +/// a running sum of all tokens, and subscribes to prefill events over NATS +/// to keep the counts synchronized across components. +pub struct PrefillCounter { + state: Arc>, + namespace: Namespace, +} + +struct PrefillCounterState { + tokens_map: DashMap, + running_sum: AtomicUsize, +} + +impl PrefillCounterState { + fn new() -> Self { + Self { + tokens_map: DashMap::new(), + running_sum: AtomicUsize::new(0), + } + } + + fn contains_key(&self, key: &str) -> bool { + self.tokens_map.contains_key(key) + } + + fn insert(&self, key: String, value: usize) -> Option { + let old_value = self.tokens_map.insert(key, value); + + if let Some(old) = old_value { + self.running_sum.fetch_sub(old, Ordering::SeqCst); + self.running_sum.fetch_add(value, Ordering::SeqCst); + } else { + self.running_sum.fetch_add(value, Ordering::SeqCst); + } + + old_value + } + + fn remove(&self, key: &str) -> Option { + let removed = self.tokens_map.remove(key).map(|(_, v)| v); + + if let Some(value) = removed { + self.running_sum.fetch_sub(value, Ordering::SeqCst); + } + + removed + } + + fn running_sum(&self) -> usize { + self.running_sum.load(Ordering::SeqCst) + } +} + +impl PrefillCounter { + /// Create a new PrefillCounter with the given component. + /// + /// This will start a background task that subscribes to PREFILL_SUBJECT + /// and updates the internal state based on received events. 
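+    ///
+    /// A hypothetical usage sketch (the request id and token count are made up):
+    ///
+    /// ```ignore
+    /// let counter = PrefillCounter::new(component);
+    /// counter.insert("req-1".to_string(), 512).await?; // publishes NewPrefill
+    /// assert_eq!(counter.running_sum().await, 512);
+    /// counter.remove("req-1").await?;                  // publishes CompletePrefill
+    /// ```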
+ pub fn new(component: Component) -> Self { + let state = Arc::new(RwLock::new(PrefillCounterState::new())); + + let counter = Self { + state: state.clone(), + namespace: component.namespace().clone(), + }; + + let state_clone = state.clone(); + let namespace_clone = counter.namespace.clone(); + + tokio::spawn(async move { + if let Err(e) = Self::subscribe_to_events(state_clone, namespace_clone).await { + tracing::error!("Error in prefill events subscription: {}", e); + } + }); + + counter + } + + /// Background task to subscribe to prefill events and update internal state + /// TODO: somehow try to block events that are sent by itself + async fn subscribe_to_events( + state: Arc>, + namespace: Namespace, + ) -> Result<()> { + let mut subscriber = namespace + .subscribe_with_type::(PREFILL_SUBJECT) + .await?; + + while let Some(result) = subscriber.next().await { + let Ok(event) = result else { + tracing::error!("Error receiving prefill event: {}", result.unwrap_err()); + continue; + }; + + match event.data { + PrefillEventData::NewPrefill(tokens) => { + let state_read = state.read().await; + if state_read.contains_key(&event.request_id) { + continue; + } + drop(state_read); + + let state_write = state.write().await; + state_write.insert(event.request_id.clone(), tokens); + } + PrefillEventData::UpdatePrefill(new_tokens) => { + let state_write = state.write().await; + let Some(old_tokens_ref) = state_write.tokens_map.get(&event.request_id) else { + continue; + }; + let old_tokens = *old_tokens_ref; + + let delta = new_tokens as isize - old_tokens as isize; + state_write + .running_sum + .fetch_add(delta as usize, Ordering::SeqCst); + state_write + .tokens_map + .insert(event.request_id.clone(), new_tokens); + } + PrefillEventData::CompletePrefill(_) => { + let state_read = state.read().await; + if !state_read.contains_key(&event.request_id) { + continue; + } + drop(state_read); + + let state_write = state.write().await; + state_write.remove(&event.request_id); + } + } + } + + Ok(()) + } + + pub async fn insert(&self, request_id: String, tokens: usize) -> Result> { + let state = self.state.write().await; + let old_value = state.insert(request_id.clone(), tokens); + + // Send appropriate event based on whether this is a new prefill or an update + let event = PrefillEvent { + request_id, + data: if old_value.is_some() { + PrefillEventData::UpdatePrefill(tokens) + } else { + PrefillEventData::NewPrefill(tokens) + }, + }; + self.namespace.publish(PREFILL_SUBJECT, &event).await?; + + Ok(old_value) + } + + pub async fn remove(&self, request_id: &str) -> Result> { + let state = self.state.write().await; + let removed_tokens = state.remove(request_id); + + if let Some(tokens) = removed_tokens { + let event = PrefillEvent { + request_id: request_id.to_string(), + data: PrefillEventData::CompletePrefill(tokens), + }; + self.namespace.publish(PREFILL_SUBJECT, &event).await?; + } + + Ok(removed_tokens) + } + + pub async fn get(&self, request_id: &str) -> Option { + let state = self.state.read().await; + state.tokens_map.get(request_id).map(|entry| *entry) + } + + pub async fn running_sum(&self) -> usize { + let state = self.state.read().await; + state.running_sum() + } + + pub async fn len(&self) -> usize { + let state = self.state.read().await; + state.tokens_map.len() + } + + pub async fn is_empty(&self) -> bool { + let state = self.state.read().await; + state.tokens_map.is_empty() + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use dynamo_runtime::{DistributedRuntime, Runtime}; + 
use std::collections::HashMap; + use tokio::time::Duration; + + #[tokio::test] + #[ignore] + async fn test_prefill_counter_synchronization() -> Result<()> { + // Initialize logging + dynamo_runtime::logging::init(); + + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and components for two counters + let namespace = distributed.namespace("test_prefill_counter")?; + let component1 = namespace + .component("counter1")? + .service_builder() + .create() + .await?; + let component2 = namespace + .component("counter2")? + .service_builder() + .create() + .await?; + + // Create two PrefillCounter instances + let counter1 = PrefillCounter::new(component1); + let counter2 = PrefillCounter::new(component2); + + // Give some time for subscribers to initialize + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Track all request_ids and their token counts for verification + let mut expected_tokens = HashMap::new(); + let tokens_per_request = 100; + let requests_per_counter = 50; + + // Send 50 requests to counter1 + for i in 0..requests_per_counter { + let request_id = format!("counter1_request_{}", i); + counter1 + .insert(request_id.clone(), tokens_per_request) + .await?; + expected_tokens.insert(request_id, tokens_per_request); + } + + // Send 50 requests to counter2 + for i in 0..requests_per_counter { + let request_id = format!("counter2_request_{}", i); + counter2 + .insert(request_id.clone(), tokens_per_request) + .await?; + expected_tokens.insert(request_id, tokens_per_request); + } + + // Wait for synchronization + tokio::time::sleep(Duration::from_millis(10)).await; + + // Verify both counters have the same running sum + let expected_sum = (requests_per_counter * 2) * tokens_per_request; + let sum1 = counter1.running_sum().await; + let sum2 = counter2.running_sum().await; + + assert_eq!( + sum1, expected_sum, + "Counter1 running sum mismatch. Expected: {}, Got: {}", + expected_sum, sum1 + ); + assert_eq!( + sum2, expected_sum, + "Counter2 running sum mismatch. 
Expected: {}, Got: {}", + expected_sum, sum2 + ); + + // Verify both counters have all 100 requests + let len1 = counter1.len().await; + let len2 = counter2.len().await; + assert_eq!( + len1, + requests_per_counter * 2, + "Counter1 should have {} requests", + requests_per_counter * 2 + ); + assert_eq!( + len2, + requests_per_counter * 2, + "Counter2 should have {} requests", + requests_per_counter * 2 + ); + + // Spot check some individual requests on both counters + for i in 0..5 { + let request_id = format!("counter1_request_{}", i); + let tokens1 = counter1.get(&request_id).await; + let tokens2 = counter2.get(&request_id).await; + assert_eq!( + tokens1, + Some(tokens_per_request), + "Counter1 missing request {}", + request_id + ); + assert_eq!( + tokens2, + Some(tokens_per_request), + "Counter2 missing request {}", + request_id + ); + } + + // Now remove all requests from both counters + for i in 0..requests_per_counter { + let request_id = format!("counter1_request_{}", i); + counter1.remove(&request_id).await?; + } + + for i in 0..requests_per_counter { + let request_id = format!("counter2_request_{}", i); + counter2.remove(&request_id).await?; + } + + // Wait for removal synchronization + tokio::time::sleep(Duration::from_millis(10)).await; + + // Verify both counters have zero running sum + let final_sum1 = counter1.running_sum().await; + let final_sum2 = counter2.running_sum().await; + assert_eq!( + final_sum1, 0, + "Counter1 should have zero running sum after removal" + ); + assert_eq!( + final_sum2, 0, + "Counter2 should have zero running sum after removal" + ); + + // Verify both counters are empty + assert!(counter1.is_empty().await, "Counter1 should be empty"); + assert!(counter2.is_empty().await, "Counter2 should be empty"); + + // Shutdown runtime + runtime.shutdown(); + + Ok(()) + } +} diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index e6429f3909..923ede9105 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -1,17 +1,5 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. use crate::tokens::Token; use serde::{Deserialize, Serialize}; @@ -128,6 +116,23 @@ impl From for ExternalSequenceBlockHash { } } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PrefillEvent { + pub request_id: String, + pub data: PrefillEventData, +} + +/// Represents the different stages of prefilling tokens for a request. +/// +/// Each variant contains a `usize` representing the number of tokens +/// that are pending prefill in the request. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum PrefillEventData { + NewPrefill(usize), + UpdatePrefill(usize), + CompletePrefill(usize), +} + /// Represents a collection of cache events and a shutdown flag. 
#[derive(Serialize, Deserialize, Debug, Clone)] pub struct KvCacheEvents { From e6b4c3b3f66235aa343a6794c8720a375f628c19 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Sun, 3 Aug 2025 15:54:14 -0700 Subject: [PATCH 03/27] add dashmap to Cargo.toml --- Cargo.lock | 17 ++++++++++++++++- lib/llm/Cargo.toml | 1 + 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 5c4b0afd0c..655112a2dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1556,6 +1556,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1899,6 +1913,7 @@ dependencies = [ "chrono", "criterion", "cudarc 0.16.2", + "dashmap 6.1.0", "derive-getters", "derive_builder", "dialoguer", @@ -8723,7 +8738,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap", + "dashmap 5.5.3", "futures-channel", "futures-io", "futures-task", diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index f36994144d..485455e39b 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -85,6 +85,7 @@ derive-getters = "0.5" offset-allocator = "0.2" regex = "1" rayon = "1" +dashmap = "6" # input/text dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } From 107ca095aec1db2a4d766a92d68a050bef239d61 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Sun, 3 Aug 2025 21:02:28 -0700 Subject: [PATCH 04/27] should be functional --- components/metrics/src/lib.rs | 2 +- lib/bindings/python/Cargo.lock | 17 +- lib/llm/src/kv_router.rs | 37 +-- lib/llm/src/kv_router/metrics_aggregator.rs | 2 +- lib/llm/src/kv_router/prefill_counter.rs | 143 ++++++++-- lib/llm/src/kv_router/protocols.rs | 16 +- lib/llm/src/kv_router/publisher.rs | 192 ++++++++++--- lib/llm/src/kv_router/scheduler.rs | 290 ++++++++------------ lib/llm/src/kv_router/scoring.rs | 31 ++- 9 files changed, 469 insertions(+), 261 deletions(-) diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index 7fb632b881..1ef523b9ee 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -84,7 +84,7 @@ use std::net::SocketAddr; use std::time::Duration as StdDuration; use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics}; -use dynamo_llm::kv_router::scheduler::Endpoint; +use dynamo_llm::kv_router::scoring::Endpoint; use dynamo_llm::kv_router::scoring::ProcessedEndpoints; use dynamo_runtime::{ diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index d677224444..514bf90361 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -978,6 +978,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1187,6 +1201,7 @@ dependencies = [ "candle-core", "chrono", "cudarc", + "dashmap 6.1.0", "derive-getters", "derive_builder", "dialoguer", @@ -5841,7 +5856,7 @@ dependencies = [ 
"asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap", + "dashmap 5.5.3", "futures-channel", "futures-io", "futures-task", diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 3400affac4..2e58f83cb6 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -15,7 +15,6 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; -use tokio::sync::Mutex; pub mod approx; pub mod indexer; @@ -138,10 +137,6 @@ pub struct KvRouter { scheduler: KvScheduler, block_size: u32, - - // To ensure blocking reads / writes - // TODO: benchmark tradeoffs - find_best_match_mutex: Mutex<()>, } impl KvRouter { @@ -178,13 +173,8 @@ impl KvRouter { )) }; - let scheduler = KvScheduler::start( - component.namespace().clone(), - block_size, - instances_rx, - selector, - ) - .await?; + let scheduler = + KvScheduler::start(component.clone(), block_size, instances_rx, selector).await?; // [gluo TODO] try subscribe_with_type::, // error checking below will be different. @@ -218,7 +208,6 @@ impl KvRouter { indexer, scheduler, block_size, - find_best_match_mutex: Mutex::new(()), // Add this }) } @@ -230,28 +219,19 @@ impl KvRouter { context_id: &str, tokens: &[u32], ) -> anyhow::Result<(i64, u32)> { - // Acquire mutex to serialize access - // TODO: may as well make all the subroutines synchronous if benchmarking favors this - let _guard = self.find_best_match_mutex.lock().await; - let isl_tokens = tokens.len(); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); - let seq_hashes = compute_seq_hash_for_block(&block_hashes); let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?; let best_worker_id = self .scheduler - .schedule( - context_id.to_string(), - isl_tokens, - seq_hashes.clone(), - overlap_scores.clone(), - ) + .schedule(context_id.to_string(), isl_tokens, overlap_scores.clone()) .await?; if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { + let seq_hashes = compute_seq_hash_for_block(&block_hashes); indexer .process_routing_decision(best_worker_id, block_hashes, seq_hashes) .await @@ -267,15 +247,10 @@ impl KvRouter { } /// Free all blocks associated with a request - pub async fn mark_prefill_completed(&self, request_id: &String) { + pub async fn mark_prefill_completed(&self, request_id: &str) { self.scheduler.mark_prefill_completed(request_id).await } - /// Free all blocks associated with a request - pub async fn free(&self, request_id: &String) { - self.scheduler.free(request_id).await - } - /// Get the block size this router was configured with pub fn block_size(&self) -> u32 { self.block_size @@ -362,8 +337,6 @@ impl AsyncEngine, ManyOut(state: &DashMap) -> HashMap +where + K: Clone + Hash + Eq, + V: Copy, +{ + state + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect() +} /// A counter that tracks pending prefill tokens for each request. /// /// This struct maintains a local hashmap of request_id to token count, /// a running sum of all tokens, and subscribes to prefill events over NATS /// to keep the counts synchronized across components. 
+#[derive(Clone)] pub struct PrefillCounter { state: Arc>, - namespace: Namespace, + component: Component, } struct PrefillCounterState { @@ -78,14 +94,14 @@ impl PrefillCounter { let counter = Self { state: state.clone(), - namespace: component.namespace().clone(), + component: component.clone(), }; let state_clone = state.clone(); - let namespace_clone = counter.namespace.clone(); + let component_clone = component.clone(); tokio::spawn(async move { - if let Err(e) = Self::subscribe_to_events(state_clone, namespace_clone).await { + if let Err(e) = Self::subscribe_to_events(state_clone, component_clone).await { tracing::error!("Error in prefill events subscription: {}", e); } }); @@ -97,9 +113,9 @@ impl PrefillCounter { /// TODO: somehow try to block events that are sent by itself async fn subscribe_to_events( state: Arc>, - namespace: Namespace, + component: Component, ) -> Result<()> { - let mut subscriber = namespace + let mut subscriber = component .subscribe_with_type::(PREFILL_SUBJECT) .await?; @@ -135,7 +151,7 @@ impl PrefillCounter { .tokens_map .insert(event.request_id.clone(), new_tokens); } - PrefillEventData::CompletePrefill(_) => { + PrefillEventData::CompletePrefill => { let state_read = state.read().await; if !state_read.contains_key(&event.request_id) { continue; @@ -164,7 +180,7 @@ impl PrefillCounter { PrefillEventData::NewPrefill(tokens) }, }; - self.namespace.publish(PREFILL_SUBJECT, &event).await?; + self.component.publish(PREFILL_SUBJECT, &event).await?; Ok(old_value) } @@ -173,12 +189,12 @@ impl PrefillCounter { let state = self.state.write().await; let removed_tokens = state.remove(request_id); - if let Some(tokens) = removed_tokens { + if removed_tokens.is_some() { let event = PrefillEvent { request_id: request_id.to_string(), - data: PrefillEventData::CompletePrefill(tokens), + data: PrefillEventData::CompletePrefill, }; - self.namespace.publish(PREFILL_SUBJECT, &event).await?; + self.component.publish(PREFILL_SUBJECT, &event).await?; } Ok(removed_tokens) @@ -203,6 +219,92 @@ impl PrefillCounter { let state = self.state.read().await; state.tokens_map.is_empty() } + + /// Returns a snapshot of the current state as a HashMap + pub async fn snapshot(&self) -> HashMap { + let state = self.state.read().await; + get_snapshot(&state.tokens_map) + } +} + +/// A collection of PrefillCounters for multiple workers +pub struct PrefillCountersMultiWorker { + pub counters: DashMap, + pub request_to_workers: DashMap, + component: Component, +} + +impl PrefillCountersMultiWorker { + pub fn new(component: Component) -> Self { + Self { + counters: DashMap::new(), + request_to_workers: DashMap::new(), + component, + } + } + + pub async fn add_prefill( + &self, + worker_id: i64, + request_id: String, + new_tokens: usize, + ) -> Result<()> { + if let Some(existing_worker_id) = self.request_to_workers.get(&request_id) { + tracing::warn!( + "Request {} already exists for worker {}, but trying to add to worker {}", + request_id, + *existing_worker_id, + worker_id + ); + } + self.request_to_workers + .insert(request_id.clone(), worker_id); + + if let Some(counter) = self.counters.get(&worker_id) { + counter.insert(request_id, new_tokens).await?; + } else { + tracing::warn!( + "Worker {} does not exist, creating new PrefillCounter", + worker_id + ); + let new_counter = PrefillCounter::new(self.component.clone()); + new_counter.insert(request_id, new_tokens).await?; + self.counters.insert(worker_id, new_counter); + } + + Ok(()) + } + + pub async fn remove_prefill(&self, request_id: &str) -> 
Result> { + let Some((_request_id, worker_id)) = self.request_to_workers.remove(request_id) else { + tracing::warn!("Request {} not found", request_id); + return Ok(None); + }; + + if let Some(counter) = self.counters.get(&worker_id) { + counter.remove(request_id).await + } else { + tracing::warn!( + "Worker {} not found in counters for request {}", + worker_id, + request_id + ); + Ok(None) + } + } + + /// Get the running sums for all workers as a HashMap + pub async fn running_sums(&self) -> HashMap { + let futures = FuturesUnordered::new(); + + for entry in self.counters.iter() { + let worker_id = *entry.key(); + let counter = entry.value().clone(); + futures.push(async move { (worker_id, counter.running_sum().await) }); + } + + futures.collect::>().await + } } #[cfg(test)] @@ -222,22 +324,17 @@ mod integration_tests { let runtime = Runtime::from_current()?; let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; - // Create namespace and components for two counters + // Create namespace and a single component let namespace = distributed.namespace("test_prefill_counter")?; - let component1 = namespace - .component("counter1")? - .service_builder() - .create() - .await?; - let component2 = namespace - .component("counter2")? + let component = namespace + .component("shared_counter")? .service_builder() .create() .await?; - // Create two PrefillCounter instances - let counter1 = PrefillCounter::new(component1); - let counter2 = PrefillCounter::new(component2); + // Create two PrefillCounter instances using the same component (cloned) + let counter1 = PrefillCounter::new(component.clone()); + let counter2 = PrefillCounter::new(component.clone()); // Give some time for subscribers to initialize tokio::time::sleep(Duration::from_millis(2000)).await; diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 923ede9105..a81341330d 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::tokens::Token; +use crate::tokens::{SequenceHash, Token}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -130,7 +130,19 @@ pub struct PrefillEvent { pub enum PrefillEventData { NewPrefill(usize), UpdatePrefill(usize), - CompletePrefill(usize), + CompletePrefill, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActiveBlockEvent { + pub request_id: String, + pub data: ActiveBlockEventData, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum ActiveBlockEventData { + NewBlock(Vec), + FreeBlock, } /// Represents a collection of cache events and a shutdown flag. 
diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index b0bb532e4f..4415738b6e 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -16,12 +16,13 @@ use crate::kv_router::{ indexer::{compute_block_hash_for_seq, RouterEvent}, protocols::*, + scoring::LoadEvent, KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT, }; use async_trait::async_trait; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; use dynamo_runtime::{ - component::Component, + component::{Component, Namespace}, pipeline::{ network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, @@ -502,7 +503,16 @@ impl WorkerMetricsPublisher { let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; - self.start_nats_metrics_publishing(component.clone()); + let worker_id = component + .drt() + .primary_lease() + .map(|lease| lease.id()) + .unwrap_or_else(|| { + tracing::warn!("Component is static, assuming worker_id of 0"); + 0 + }); + + self.start_nats_metrics_publishing(component.namespace().clone(), worker_id); component .endpoint(KV_METRICS_ENDPOINT) @@ -520,19 +530,29 @@ impl WorkerMetricsPublisher { /// /// This task monitors metric changes (specifically kv_active_blocks and num_requests_waiting) /// and publishes stable metrics to NATS after they've been unchanged for 1ms. - fn start_nats_metrics_publishing(&self, component: Component) { + fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: i64) { let nats_rx = self.rx.clone(); tokio::spawn(async move { let mut rx = nats_rx; let mut last_kv_active_blocks: Option = None; let mut last_num_requests_waiting: Option = None; - let mut stable_metrics: Option> = None; - let mut stable_since: Option = None; + let mut pending_publish: Option> = None; + let mut publish_timer = + Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(0))); + publish_timer.as_mut().reset(tokio::time::Instant::now()); // Complete immediately loop { - match rx.changed().await { - Ok(_) => { + tokio::select! 
{ + // Handle metrics changes + result = rx.changed() => { + if result.is_err() { + tracing::debug!( + "Metrics publisher sender dropped, stopping NATS background task" + ); + break; + } + let metrics = rx.borrow_and_update().clone(); // Extract the values we care about @@ -549,36 +569,33 @@ impl WorkerMetricsPublisher { _ => true, // First time, consider it changed }; - // Metrics changed, reset stability tracking + // If load metrics changed, schedule a publish if has_changed { - stable_metrics = Some(metrics.clone()); - stable_since = Some(tokio::time::Instant::now()); + pending_publish = Some(metrics.clone()); last_kv_active_blocks = Some(current_kv_active_blocks); last_num_requests_waiting = Some(current_num_requests_waiting); - } - // Check if metrics have been stable for 1ms (independent of whether they just changed) - if let (Some(stable_start), Some(ref stable_m)) = - (stable_since, &stable_metrics) - { - // Stable for 1ms, publish - if stable_start.elapsed() >= tokio::time::Duration::from_millis(1) { - if let Err(e) = - component.publish(KV_METRICS_SUBJECT, &**stable_m).await - { - tracing::warn!("Failed to publish metrics over NATS: {}", e); - } - // Reset stability tracking after publishing - stable_metrics = None; - stable_since = None; - } + // Start the 1ms timer + publish_timer.as_mut().reset( + tokio::time::Instant::now() + tokio::time::Duration::from_millis(1) + ); } } - Err(_) => { - tracing::debug!( - "Metrics publisher sender dropped, stopping NATS background task" - ); - break; + // Timer expired - publish if we have pending metrics + _ = &mut publish_timer => { + if let Some(metrics) = pending_publish.take() { + // Create LoadEvent wrapping the metrics + let load_event = LoadEvent { + worker_id, + data: (*metrics).clone(), + }; + + if let Err(e) = + namespace.publish(KV_METRICS_SUBJECT, &load_event).await + { + tracing::warn!("Failed to publish metrics over NATS: {}", e); + } + } } } } @@ -951,3 +968,116 @@ mod test_exponential_backoff { assert!(max_calculated <= MAX_BACKOFF_MS); } } + +#[cfg(test)] +mod test_worker_metrics_publisher { + use super::*; + use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats}; + use dynamo_runtime::traits::events::EventSubscriber; // Add this import + use dynamo_runtime::{DistributedRuntime, Runtime}; + use futures::StreamExt; + + #[tokio::test] + #[ignore] // Mark as ignored as requested + async fn test_metrics_publishing_behavior() -> Result<()> { + // Set up runtime and namespace + let rt = Runtime::from_current().unwrap(); + let drt = DistributedRuntime::from_settings(rt.clone()).await?; + let namespace = drt.namespace("test".to_string())?; + + // Create a subscriber for the metrics events using subscribe_with_type + let mut subscriber = namespace + .subscribe_with_type::(KV_METRICS_SUBJECT) + .await + .unwrap(); + + // Create WorkerMetricsPublisher + let publisher = WorkerMetricsPublisher::new().unwrap(); + let worker_id = 1234; + + // Start NATS metrics publishing + publisher.start_nats_metrics_publishing(namespace.clone(), worker_id); + + // Allow some time for the background task to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Test 1: Publish 10 different metrics with 0.5ms intervals + // Only the last one should be published after 1ms of stability + for i in 0..10 { + let metrics = Arc::new(ForwardPassMetrics { + kv_stats: KvStats { + kv_active_blocks: (i * 100) as u64, // Changing load metric + kv_total_blocks: 1000, + gpu_cache_usage_perc: 0.5, + 
gpu_prefix_cache_hit_rate: 0.8, + }, + worker_stats: WorkerStats { + num_requests_waiting: (i * 10) as u64, // Changing load metric + data_parallel_rank: None, + request_active_slots: 50, + request_total_slots: 100, + }, + spec_decode_stats: None, + }); + + publisher.publish(metrics).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; + } + + // Wait a bit more than 1ms to ensure the last metric is published + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Verify we receive exactly one event with the last metric values + let result = + tokio::time::timeout(tokio::time::Duration::from_millis(500), subscriber.next()) + .await + .unwrap(); + + let event = result.unwrap().unwrap(); // Unwrap the Option and the Result + assert_eq!(event.worker_id, worker_id); + assert_eq!(event.data.kv_stats.kv_active_blocks, 900); // Last value: 9 * 100 + assert_eq!(event.data.worker_stats.num_requests_waiting, 90); // Last value: 9 * 10 + + // Ensure no more events are waiting + let no_msg = + tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await; + assert!(no_msg.is_err(), "Expected no more messages, but found one"); + + // Test 2: Publish 10 more metrics where everything changes EXCEPT the load metrics + for i in 0..10 { + let metrics = Arc::new(ForwardPassMetrics { + kv_stats: KvStats { + kv_active_blocks: 900, // Keep same as last published + kv_total_blocks: 1000 + (i * 100) as u64, // Change other metrics + gpu_cache_usage_perc: 0.3 + (i as f32 * 0.05), // Change other metrics + gpu_prefix_cache_hit_rate: 0.7 + (i as f32 * 0.01), // Change other metrics + }, + worker_stats: WorkerStats { + num_requests_waiting: 90, // Keep same as last published + data_parallel_rank: None, + request_active_slots: 40 + (i * 5) as u64, // Change other metrics + request_total_slots: 100 + (i * 10) as u64, // Change other metrics + }, + spec_decode_stats: None, + }); + + publisher.publish(metrics).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; + } + + // Wait to ensure no events are published + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Verify no events are received + let no_msg = + tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await; + assert!( + no_msg.is_err(), + "Expected no messages when load metrics don't change" + ); + + rt.shutdown(); + + Ok(()) + } +} diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 797f738187..6b5a44b642 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -13,8 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use dynamo_runtime::component::Namespace; -use dynamo_runtime::traits::events::EventPublisher; +use dashmap::DashMap; +use dynamo_runtime::component::{Component, Instance}; +use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; +use futures::StreamExt; use rand::Rng; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -22,15 +24,13 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; +use super::indexer::OverlapScores; +use super::prefill_counter::{get_snapshot, PrefillCountersMultiWorker}; use super::protocols::WorkerSelectionResult; +use super::scoring::LoadEvent; +use super::KvRouterConfig; use super::WorkerSelector; -use crate::kv_router::indexer::OverlapScores; -use crate::kv_router::protocols::LoadMetrics; -use crate::kv_router::sequence::ActiveSequencesMultiWorker; -use crate::kv_router::KvRouterConfig; -use crate::kv_router::KV_HIT_RATE_SUBJECT; -use crate::tokens::SequenceHash; -use dynamo_runtime::component::Instance; +use super::{KV_HIT_RATE_SUBJECT, KV_METRICS_SUBJECT}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KVHitRateEvent { @@ -51,42 +51,17 @@ pub enum KvSchedulerError { SubscriberShutdown, } -/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' -/// is cleaned (not optional) -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct Endpoint { - pub name: String, - pub subject: String, - pub data: LoadMetrics, -} - -impl Endpoint { - pub fn worker_id(&self) -> i64 { - i64::from_str_radix( - self.subject - .split("-") - .last() - .expect("invalid subject") - .to_string() - .as_str(), - 16, - ) - .expect("invalid worker id") - } -} - #[derive(Debug)] pub struct SchedulingResponse { pub best_worker_id: i64, - pub overlap_blocks: u32, // Add this field - pub endpoints_changed: Option>, + pub new_tokens: usize, } pub struct SchedulingRequest { pub isl_tokens: usize, pub overlaps: OverlapScores, - pub potential_blocks: HashMap, - pub potential_tokens: HashMap, + pub decode_blocks: HashMap, + pub prefill_tokens: HashMap, resp_tx: tokio::sync::oneshot::Sender, } @@ -100,12 +75,13 @@ impl SchedulingRequest { pub struct KvScheduler { request_tx: tokio::sync::mpsc::Sender, - sequences: Arc>, + prefill_tokens: Arc, + barrier: Mutex<()>, } impl KvScheduler { pub async fn start( - ns: Namespace, + component: Component, block_size: u32, mut instances_rx: tokio::sync::watch::Receiver>, // Changed from ProcessedEndpoints selector: Option>, @@ -113,93 +89,98 @@ impl KvScheduler { let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let mut instances: Vec = instances_rx.borrow_and_update().clone(); - // Get worker IDs from instances - let worker_ids: Vec = instances.iter().map(|i| i.instance_id).collect(); - let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + let ns_clone = component.namespace().clone(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { - if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await { + if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await { tracing::warn!("Failed to publish KV hit rate event: {:?}", e); } } }); - let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new( - block_size as usize, - worker_ids, - ))); + // token and block states + let active_blocks = Arc::new(DashMap::::new()); // Create dashmap for worker_id -> active_blocks + let prefill_tokens = Arc::new(PrefillCountersMultiWorker::new(component.clone())); - // Channel to 
accept new scheduling requests + // listen on load metrics from backends to update active block states + let active_blocks_clone = active_blocks.clone(); + let ns_clone = component.namespace().clone(); + tokio::spawn(async move { + let mut load_subscriber = ns_clone + .subscribe_with_type::(KV_METRICS_SUBJECT) + .await + .expect("Cannot launch load subscriber"); + + while let Some(event_result) = load_subscriber.next().await { + let Ok(load_event) = event_result else { + tracing::warn!("Error receiving load event: {}", event_result.unwrap_err()); + continue; + }; + + active_blocks_clone.insert( + load_event.worker_id, + load_event.data.kv_stats.kv_active_blocks as usize, + ); + } + }); + + let prefill_tokens_clone = prefill_tokens.clone(); let (request_tx, request_rx) = tokio::sync::mpsc::channel::(1024); // Background task to handle scheduling requests tokio::spawn(async move { - let mut request: SchedulingRequest; let mut request_rx = request_rx; - let mut pending_endpoint_update: Option> = None; tracing::trace!("scheduler background task started"); - 'outer: loop { - request = tokio::select! { - biased; + loop { + // First, check for instance updates (non-blocking) + if instances_rx.has_changed().unwrap_or(false) { + instances = instances_rx.borrow_and_update().clone(); + } - _ = instances_rx.changed() => { - instances = instances_rx.borrow_and_update().clone(); - let worker_ids: Vec = instances.iter().map(|i| i.instance_id).collect(); - pending_endpoint_update = Some(worker_ids); - continue 'outer; - } + // Then, wait for a new request + let Some(mut request) = request_rx.recv().await else { + tracing::warn!("scheduler shutdown"); + break; + }; + tracing::trace!("received request to be scheduled"); + + request.prefill_tokens = prefill_tokens_clone.running_sums().await; + request.decode_blocks = get_snapshot(active_blocks.as_ref()); + + match selector.select_worker(&instances, &request, block_size) { + Ok(selection) => { + if let Err(e) = event_tx.send(KVHitRateEvent { + worker_id: selection.worker_id, + isl_blocks: selection.required_blocks as usize, + overlap_blocks: selection.overlap_blocks, + }) { + tracing::warn!("Failed to send KV hit rate event: {:?}", e); + } - maybe_new_request = request_rx.recv() => { - let Some(new_request) = maybe_new_request else { - tracing::warn!("scheduler shutdown"); - break 'outer; + let response = SchedulingResponse { + best_worker_id: selection.worker_id, + new_tokens: request.isl_tokens + - (selection.overlap_blocks * block_size) as usize, }; - tracing::trace!("received request to be scheduled"); - new_request + request.respond(response); + continue; } - }; - - loop { - // When calling selector.select_worker, we need to adapt - match selector.select_worker(&instances, &request, block_size) { - Ok(selection) => { - if let Err(e) = event_tx.send(KVHitRateEvent { - worker_id: selection.worker_id, - isl_blocks: selection.required_blocks as usize, - overlap_blocks: selection.overlap_blocks, - }) { - tracing::warn!("Failed to send KV hit rate event: {:?}", e); - } - - let response = SchedulingResponse { - best_worker_id: selection.worker_id, - overlap_blocks: selection.overlap_blocks, - endpoints_changed: pending_endpoint_update.take(), - }; - request.respond(response); - continue 'outer; - } - Err(KvSchedulerError::NoEndpoints) => { - tracing::trace!("no endpoints available; waiting for endpoints update"); - instances_rx.changed().await.ok(); - instances = instances_rx.borrow_and_update().clone(); - let worker_ids: Vec = - instances.iter().map(|i| 
i.instance_id).collect(); - pending_endpoint_update = Some(worker_ids); - continue; - } - // TODO: this is not actually hooked up - Err(KvSchedulerError::AllWorkersBusy) => { - tracing::trace!("all workers busy; waiting for more capacity"); - tokio::time::sleep(Duration::from_millis(5)).await; - continue; - } - Err(e) => { - tracing::error!("error scheduling request: {:?}", e); - break 'outer; - } + Err(KvSchedulerError::NoEndpoints) => { + tracing::trace!("no endpoints available; waiting for endpoints update"); + tokio::time::sleep(Duration::from_millis(5)).await; + continue; + } + // TODO: this is not actually hooked up + Err(KvSchedulerError::AllWorkersBusy) => { + tracing::trace!("all workers busy; waiting for more capacity"); + tokio::time::sleep(Duration::from_millis(5)).await; + continue; + } + Err(e) => { + tracing::error!("error scheduling request: {:?}", e); + break; } } } @@ -209,7 +190,8 @@ impl KvScheduler { Ok(KvScheduler { request_tx, - sequences, + prefill_tokens, + barrier: Mutex::new(()), }) } @@ -217,22 +199,21 @@ impl KvScheduler { &self, request_id: String, isl_tokens: usize, - token_seq: Vec, overlaps: OverlapScores, ) -> Result { - let mut sequences = self.sequences.lock().await; - - let (potential_blocks, potential_tokens) = - sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone()); + // TODO: this is temporary needed for now to ensure blocking of scheduling and updating + // Need to remove in future if we have better algo to truly enable async processing + let _guard = self.barrier.lock().await; let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { isl_tokens, overlaps, - potential_blocks, - potential_tokens, + decode_blocks: HashMap::new(), + prefill_tokens: HashMap::new(), resp_tx, }; + self.request_tx .send(request) .await @@ -241,30 +222,16 @@ impl KvScheduler { .await .map_err(|_| KvSchedulerError::SubscriberShutdown)?; - if let Some(new_worker_ids) = response.endpoints_changed { - sequences.update_workers(new_worker_ids); - } - - sequences.add_request( - request_id, - token_seq, - isl_tokens, - response.overlap_blocks, - response.best_worker_id, - ); - - Ok(response.best_worker_id) - } - - pub async fn mark_prefill_completed(&self, request_id: &String) { - let mut sequences = self.sequences.lock().await; - sequences.mark_prefill_completed(request_id) + let best_worker_id = response.best_worker_id; + let _ = self + .prefill_tokens + .add_prefill(best_worker_id, request_id, response.new_tokens) + .await; + Ok(best_worker_id) } - /// Free all blocks associated with a request - pub async fn free(&self, request_id: &String) { - let mut sequences = self.sequences.lock().await; - sequences.free(request_id) + pub async fn mark_prefill_completed(&self, request_id: &str) { + let _ = self.prefill_tokens.remove_prefill(request_id).await; } } @@ -307,8 +274,9 @@ fn softmax_sample(logits: &HashMap, temperature: f64) -> i64 { let normalized: Vec<_> = values .iter() .map(|&v| { - let norm = v / (max_val - min_val); // Lower is better, so negate + // Note we don't need to do actual min-max norm here, just off by an offset + let norm = v / (max_val - min_val); -norm }) .collect(); @@ -370,10 +338,8 @@ impl WorkerSelector for DefaultWorkerSelector { let request_blocks = isl.div_ceil(block_size as usize); let overlaps = &request.overlaps.scores; - // active blocks for decoding - let potential_active_blocks = &request.potential_blocks; - // active tokens in the batch (processed by the linear layers), mostly 
prefill tokens - let potential_active_tokens = &request.potential_tokens; + let decode_blocks = &request.decode_blocks; + let prefill_tokens = &request.prefill_tokens; let mut worker_logits = HashMap::new(); let mut max_logit = f64::NEG_INFINITY; @@ -381,52 +347,38 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate logits for each worker for instance in workers.iter() { let worker_id = instance.instance_id; - // this is the number of tokens each worker would have if the request were scheduled there - let potential_tokens = *potential_active_tokens.get(&worker_id).unwrap_or_else(|| { - tracing::warn!( - "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet" - ); - &isl - }) as f64; - // this is the number of blocks each worker would have if the request were scheduled there - let potential_blocks = *potential_active_blocks.get(&worker_id).unwrap_or_else(|| - {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet"); - &request_blocks - }) as f64; + // this is the number of prefill tokens the worker would have if the request were scheduled there + let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&0); + let overlap = *overlaps.get(&worker_id).unwrap_or(&0); + let new_token = isl - (overlap * block_size) as usize; + let potential_prefill_token = prefill_token + new_token; + let potential_prefill_block = (potential_prefill_token as f64) / (block_size as f64); - let potential_prefill_blocks = potential_tokens / (block_size as f64); + // this is the number of decode blocks currently on the worker + let decode_block = *decode_blocks.get(&worker_id).unwrap_or(&0) as f64; // Calculate logit (lower is better) - let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks - + potential_blocks; + let logit = + self.kv_router_config.overlap_score_weight * potential_prefill_block + decode_block; max_logit = max_logit.max(logit); worker_logits.insert(worker_id, logit); tracing::info!( - "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})", + "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_block:.3} + {decode_block:.3} (cached_blocks: {})", self.kv_router_config.overlap_score_weight, - overlaps.get(&worker_id).unwrap_or(&0), + overlap, ); } - // Normalize by dividing by max value - if max_logit > 0.0 { - for logit in worker_logits.values_mut() { - *logit /= max_logit; - } - } - // Use softmax sampling to select worker let temperature = self.kv_router_config.router_temperature; let best_worker_id = softmax_sample(&worker_logits, temperature); - - let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0); let best_logit = worker_logits[&best_worker_id]; tracing::info!( - "Selected worker: {}, normalized logit: {:.3}", + "Selected worker: {}, logit: {:.3}", best_worker_id, best_logit ); @@ -434,7 +386,7 @@ impl WorkerSelector for DefaultWorkerSelector { Ok(WorkerSelectionResult { worker_id: best_worker_id, required_blocks: request_blocks as u64, - overlap_blocks, + overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0), }) } } diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index c065a613ba..5934716c7c 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -15,10 +15,39 @@ //! Scoring functions for the KV router. 
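+//!
+//! The default selector in `scheduler.rs` scores each candidate worker as
+//! `logit = overlap_score_weight * potential_prefill_blocks + decode_blocks`
+//! and softmax-samples a worker id, with lower logits being better. For
+//! example, with a weight of 1.0, 2.0 potential prefill blocks, and 3 decode
+//! blocks, the logit is 5.0.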
+use super::protocols::{ForwardPassMetrics, LoadMetrics}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::kv_router::scheduler::Endpoint; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct LoadEvent { + pub worker_id: i64, + pub data: ForwardPassMetrics, +} + +/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' +/// is cleaned (not optional) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Endpoint { + pub name: String, + pub subject: String, + pub data: LoadMetrics, +} + +impl Endpoint { + pub fn worker_id(&self) -> i64 { + i64::from_str_radix( + self.subject + .split("-") + .last() + .expect("invalid subject") + .to_string() + .as_str(), + 16, + ) + .expect("invalid worker id") + } +} #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] pub struct ProcessedEndpoints { From cfba11d3c422c944fdae6557cccda31b8efd0776 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Sun, 3 Aug 2025 22:40:36 -0700 Subject: [PATCH 05/27] ignore events emitted by counter itself --- lib/llm/src/kv_router/prefill_counter.rs | 41 +++++++++++++----------- lib/llm/src/kv_router/protocols.rs | 2 ++ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/lib/llm/src/kv_router/prefill_counter.rs b/lib/llm/src/kv_router/prefill_counter.rs index f78d6d0d1a..09e2179287 100644 --- a/lib/llm/src/kv_router/prefill_counter.rs +++ b/lib/llm/src/kv_router/prefill_counter.rs @@ -9,6 +9,7 @@ use futures::StreamExt; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::RwLock; +use uuid::Uuid; // Remove the Mutex import since we're using DashMap use super::protocols::{PrefillEvent, PrefillEventData}; @@ -37,6 +38,7 @@ where pub struct PrefillCounter { state: Arc>, component: Component, + router_id: Uuid, } struct PrefillCounterState { @@ -52,10 +54,6 @@ impl PrefillCounterState { } } - fn contains_key(&self, key: &str) -> bool { - self.tokens_map.contains_key(key) - } - fn insert(&self, key: String, value: usize) -> Option { let old_value = self.tokens_map.insert(key, value); @@ -91,17 +89,22 @@ impl PrefillCounter { /// and updates the internal state based on received events. 
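+    ///
+    /// Each counter tags the events it publishes with a random `router_id` so
+    /// the subscription loop can recognize and skip its own events.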
pub fn new(component: Component) -> Self { let state = Arc::new(RwLock::new(PrefillCounterState::new())); + let router_id = Uuid::new_v4(); let counter = Self { state: state.clone(), component: component.clone(), + router_id, }; let state_clone = state.clone(); let component_clone = component.clone(); + let router_id_clone = router_id; tokio::spawn(async move { - if let Err(e) = Self::subscribe_to_events(state_clone, component_clone).await { + if let Err(e) = + Self::subscribe_to_events(state_clone, component_clone, router_id_clone).await + { tracing::error!("Error in prefill events subscription: {}", e); } }); @@ -110,10 +113,10 @@ impl PrefillCounter { } /// Background task to subscribe to prefill events and update internal state - /// TODO: somehow try to block events that are sent by itself async fn subscribe_to_events( state: Arc>, component: Component, + router_id: Uuid, ) -> Result<()> { let mut subscriber = component .subscribe_with_type::(PREFILL_SUBJECT) @@ -125,14 +128,13 @@ impl PrefillCounter { continue; }; + // Skip events emitted by itself + if event.router_id == router_id { + continue; + } + match event.data { PrefillEventData::NewPrefill(tokens) => { - let state_read = state.read().await; - if state_read.contains_key(&event.request_id) { - continue; - } - drop(state_read); - let state_write = state.write().await; state_write.insert(event.request_id.clone(), tokens); } @@ -152,14 +154,13 @@ impl PrefillCounter { .insert(event.request_id.clone(), new_tokens); } PrefillEventData::CompletePrefill => { - let state_read = state.read().await; - if !state_read.contains_key(&event.request_id) { - continue; - } - drop(state_read); - let state_write = state.write().await; - state_write.remove(&event.request_id); + if state_write.remove(&event.request_id).is_none() { + tracing::warn!( + "Attempted to remove non-existent request: {}", + event.request_id + ); + } } } } @@ -179,6 +180,7 @@ impl PrefillCounter { } else { PrefillEventData::NewPrefill(tokens) }, + router_id: self.router_id, }; self.component.publish(PREFILL_SUBJECT, &event).await?; @@ -193,6 +195,7 @@ impl PrefillCounter { let event = PrefillEvent { request_id: request_id.to_string(), data: PrefillEventData::CompletePrefill, + router_id: self.router_id, }; self.component.publish(PREFILL_SUBJECT, &event).await?; } diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index a81341330d..09a593a9e9 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -3,6 +3,7 @@ use crate::tokens::{SequenceHash, Token}; use serde::{Deserialize, Serialize}; +use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterRequest { @@ -120,6 +121,7 @@ impl From for ExternalSequenceBlockHash { pub struct PrefillEvent { pub request_id: String, pub data: PrefillEventData, + pub router_id: Uuid, } /// Represents the different stages of prefilling tokens for a request. From 7e60b8364ff78338aad871cd0e1b86bf0d4bd109 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Sun, 3 Aug 2025 22:41:33 -0700 Subject: [PATCH 06/27] rm license block --- lib/llm/src/kv_router/scheduler.rs | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 6b5a44b642..a1ad2a3053 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -1,17 +1,5 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
From 7e60b8364ff78338aad871cd0e1b86bf0d4bd109 Mon Sep 17 00:00:00 2001
From: PeaBrane
Date: Sun, 3 Aug 2025 22:41:33 -0700
Subject: [PATCH 06/27] rm license block

---
 lib/llm/src/kv_router/scheduler.rs | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs
index 6b5a44b642..a1ad2a3053 100644
--- a/lib/llm/src/kv_router/scheduler.rs
+++ b/lib/llm/src/kv_router/scheduler.rs
@@ -1,17 +1,5 @@
 // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
 
 use dashmap::DashMap;
 use dynamo_runtime::component::{Component, Instance};

From f6d07100afaf85455096db6e64409ec1c3c01fc3 Mon Sep 17 00:00:00 2001
From: PeaBrane
Date: Sun, 3 Aug 2025 23:09:20 -0700
Subject: [PATCH 07/27] move event loop to multi-worker counter

---
 lib/llm/src/kv_router/prefill_counter.rs | 454 ++++++++++++-----------
 1 file changed, 228 insertions(+), 226 deletions(-)

diff --git a/lib/llm/src/kv_router/prefill_counter.rs b/lib/llm/src/kv_router/prefill_counter.rs
index 09e2179287..10e3f6b9b1 100644
--- a/lib/llm/src/kv_router/prefill_counter.rs
+++ b/lib/llm/src/kv_router/prefill_counter.rs
@@ -4,11 +4,9 @@
 use anyhow::Result;
 use dynamo_runtime::component::Component;
 use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
-use futures::stream::FuturesUnordered;
 use futures::StreamExt;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
-use tokio::sync::RwLock;
 use uuid::Uuid;
 
 // Remove the Mutex import since we're using DashMap
 use super::protocols::{PrefillEvent, PrefillEventData};
@@ -29,31 +27,13 @@ where
         .collect()
 }
 
-/// A counter that tracks pending prefill tokens for each request.
-///
-/// This struct maintains a local hashmap of request_id to token count,
-/// a running sum of all tokens, and subscribes to prefill events over NATS
-/// to keep the counts synchronized across components.
-#[derive(Clone)]
-pub struct PrefillCounter {
-    state: Arc<RwLock<PrefillCounterState>>,
-    component: Component,
-    router_id: Uuid,
-}
-
+#[derive(Default)]
 struct PrefillCounterState {
     tokens_map: DashMap<String, usize>,
     running_sum: AtomicUsize,
 }
 
 impl PrefillCounterState {
-    fn new() -> Self {
-        Self {
-            tokens_map: DashMap::new(),
-            running_sum: AtomicUsize::new(0),
-        }
-    }
-
     fn insert(&self, key: String, value: usize) -> Option<usize> {
         let old_value = self.tokens_map.insert(key, value);
 
@@ -82,39 +62,105 @@ impl PrefillCounterState {
     }
 }
 
+/// A counter that tracks pending prefill tokens for each request.
+///
+/// This struct maintains a local hashmap of request_id to token count,
+/// and a running sum of all tokens. It no longer handles its own subscriptions.
+#[derive(Clone, Default)]
+pub struct PrefillCounter {
+    state: Arc<PrefillCounterState>,
+}
+
 impl PrefillCounter {
-    /// Create a new PrefillCounter with the given component.
-    ///
-    /// This will start a background task that subscribes to PREFILL_SUBJECT
-    /// and updates the internal state based on received events.
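+    // NOTE: the *_direct methods below mutate local state only and never
+    // touch NATS; PrefillCountersMultiWorker owns publishing and mirrors
+    // each received PrefillEvent back through these methods.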
+    // Internal methods for direct state manipulation (no publishing)
+    fn insert_direct(&self, request_id: String, tokens: usize) -> Option<usize> {
+        self.state.insert(request_id, tokens)
+    }
+
+    fn remove_direct(&self, request_id: &str) -> Option<usize> {
+        self.state.remove(request_id)
+    }
+
+    fn update_direct(&self, request_id: String, new_tokens: usize) {
+        if let Some(old_tokens_ref) = self.state.tokens_map.get(&request_id) {
+            let old_tokens = *old_tokens_ref;
+            let delta = new_tokens as isize - old_tokens as isize;
+            // Adding the two's-complement bit pattern of a possibly negative
+            // isize delta wraps to the correct result for the unsigned sum.
+            self.state
+                .running_sum
+                .fetch_add(delta as usize, Ordering::SeqCst);
+            self.state.tokens_map.insert(request_id, new_tokens);
+        }
+    }
+
+    pub fn get(&self, request_id: &str) -> Option<usize> {
+        self.state.tokens_map.get(request_id).map(|entry| *entry)
+    }
+
+    pub fn running_sum(&self) -> usize {
+        self.state.running_sum()
+    }
+
+    pub fn len(&self) -> usize {
+        self.state.tokens_map.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.state.tokens_map.is_empty()
+    }
+
+    /// Returns a snapshot of the current state as a HashMap
+    pub fn snapshot(&self) -> HashMap<String, usize> {
+        get_snapshot(&self.state.tokens_map)
+    }
+}
+
+/// A collection of PrefillCounters for multiple workers with centralized event handling
+pub struct PrefillCountersMultiWorker {
+    pub counters: Arc<DashMap<i64, PrefillCounter>>,
+    pub request_to_workers: Arc<DashMap<String, i64>>,
+    component: Component,
+    router_id: Uuid,
+}
+
+impl PrefillCountersMultiWorker {
     pub fn new(component: Component) -> Self {
-        let state = Arc::new(RwLock::new(PrefillCounterState::new()));
+        let counters = Arc::new(DashMap::new());
+        let request_to_workers = Arc::new(DashMap::new());
         let router_id = Uuid::new_v4();
 
-        let counter = Self {
-            state: state.clone(),
+        let multi_worker = Self {
+            counters: counters.clone(),
+            request_to_workers: request_to_workers.clone(),
             component: component.clone(),
             router_id,
         };
 
-        let state_clone = state.clone();
+        // Start the subscription loop
+        let counters_clone = counters.clone();
+        let request_to_workers_clone = request_to_workers.clone();
         let component_clone = component.clone();
         let router_id_clone = router_id;
 
         tokio::spawn(async move {
-            if let Err(e) =
-                Self::subscribe_to_events(state_clone, component_clone, router_id_clone).await
+            if let Err(e) = Self::subscribe_to_events(
+                counters_clone,
+                request_to_workers_clone,
+                component_clone,
+                router_id_clone,
+            )
+            .await
             {
                 tracing::error!("Error in prefill events subscription: {}", e);
             }
         });
 
-        counter
+        multi_worker
     }
 
-    /// Background task to subscribe to prefill events and update internal state
+    /// Background task to subscribe to prefill events and update all counters
     async fn subscribe_to_events(
-        state: Arc<RwLock<PrefillCounterState>>,
+        counters: Arc<DashMap<i64, PrefillCounter>>,
+        request_to_workers: Arc<DashMap<String, i64>>,
         component: Component,
         router_id: Uuid,
     ) -> Result<()> {
@@ -135,27 +181,54 @@
             match event.data {
                 PrefillEventData::NewPrefill(tokens) => {
-                    let state_write = state.write().await;
-                    state_write.insert(event.request_id.clone(), tokens);
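+                    // The request→worker mapping is only populated by the
+                    // local add_prefill path, so a NewPrefill arriving from a
+                    // peer router that we have no mapping for is dropped here.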
+                    let Some(worker_id_ref) = request_to_workers.get(&event.request_id) else {
+                        continue;
+                    };
+
+                    let worker_id = *worker_id_ref;
+                    let Some(counter) = counters.get(&worker_id) else {
+                        tracing::warn!(
+                            "No counter found for worker {} when handling NewPrefill for request {}",
+                            worker_id,
+                            event.request_id
+                        );
+                        continue;
+                    };
+
+                    counter.insert_direct(event.request_id.clone(), tokens);
                 }
                 PrefillEventData::UpdatePrefill(new_tokens) => {
-                    let state_write = state.write().await;
-                    let Some(old_tokens_ref) = state_write.tokens_map.get(&event.request_id) else {
+                    let Some(worker_id_ref) = request_to_workers.get(&event.request_id) else {
                         continue;
                     };
-                    let old_tokens = *old_tokens_ref;
-
-                    let delta = new_tokens as isize - old_tokens as isize;
-                    state_write
-                        .running_sum
-                        .fetch_add(delta as usize, Ordering::SeqCst);
-                    state_write
-                        .tokens_map
-                        .insert(event.request_id.clone(), new_tokens);
+
+                    let worker_id = *worker_id_ref;
+                    let Some(counter) = counters.get(&worker_id) else {
+                        tracing::warn!(
+                            "No counter found for worker {} when handling UpdatePrefill for request {}",
+                            worker_id,
+                            event.request_id
+                        );
+                        continue;
+                    };
+
+                    counter.update_direct(event.request_id.clone(), new_tokens);
                 }
                 PrefillEventData::CompletePrefill => {
-                    let state_write = state.write().await;
-                    if state_write.remove(&event.request_id).is_none() {
+                    let Some((_, worker_id)) = request_to_workers.remove(&event.request_id) else {
+                        continue;
+                    };
+
+                    let Some(counter) = counters.get(&worker_id) else {
+                        tracing::warn!(
+                            "No counter found for worker {} when handling CompletePrefill for request {}",
+                            worker_id,
+                            event.request_id
+                        );
+                        continue;
+                    };
+
+                    if counter.remove_direct(&event.request_id).is_none() {
                         tracing::warn!(
                             "Attempted to remove non-existent request: {}",
                             event.request_id
                         );
                     }
                 }
             }
         }
 
         Ok(())
     }
 
-    pub async fn insert(&self, request_id: String, tokens: usize) -> Result<Option<usize>> {
-        let state = self.state.write().await;
-        let old_value = state.insert(request_id.clone(), tokens);
-
-        // Send appropriate event based on whether this is a new prefill or an update
-        let event = PrefillEvent {
-            request_id,
-            data: if old_value.is_some() {
-                PrefillEventData::UpdatePrefill(tokens)
-            } else {
-                PrefillEventData::NewPrefill(tokens)
-            },
-            router_id: self.router_id,
-        };
-        self.component.publish(PREFILL_SUBJECT, &event).await?;
-
-        Ok(old_value)
-    }
-
-    pub async fn remove(&self, request_id: &str) -> Result<Option<usize>> {
-        let state = self.state.write().await;
-        let removed_tokens = state.remove(request_id);
-
-        if removed_tokens.is_some() {
-            let event = PrefillEvent {
-                request_id: request_id.to_string(),
-                data: PrefillEventData::CompletePrefill,
-                router_id: self.router_id,
-            };
-            self.component.publish(PREFILL_SUBJECT, &event).await?;
-        }
-
-        Ok(removed_tokens)
-    }
-
-    pub async fn get(&self, request_id: &str) -> Option<usize> {
-        let state = self.state.read().await;
-        state.tokens_map.get(request_id).map(|entry| *entry)
-    }
-
-    pub async fn running_sum(&self) -> usize {
-        let state = self.state.read().await;
-        state.running_sum()
-    }
-
-    pub async fn len(&self) -> usize {
-        let state = self.state.read().await;
-        state.tokens_map.len()
-    }
-
-    pub async fn is_empty(&self) -> bool {
-        let state = self.state.read().await;
-        state.tokens_map.is_empty()
-    }
-
-    /// Returns a snapshot of the current state as a HashMap
-    pub async fn snapshot(&self) -> HashMap<String, usize> {
-        let state = self.state.read().await;
-        get_snapshot(&state.tokens_map)
-    }
-}
-
-/// A collection of PrefillCounters for multiple workers
-pub struct PrefillCountersMultiWorker {
-    pub counters: DashMap<i64, PrefillCounter>,
-    pub request_to_workers: DashMap<String, i64>,
-    component: Component,
-}
-
-impl PrefillCountersMultiWorker {
-    pub fn new(component: Component) -> Self {
-        Self {
-            counters: DashMap::new(),
-            request_to_workers: DashMap::new(),
-            component,
-        }
-    }
-
     pub async fn add_prefill(
         &self,
         worker_id: i64,
         request_id: String,
         new_tokens: usize,
     ) -> Result<()> {
         if let Some(existing_worker_id) = self.request_to_workers.get(&request_id) {
             tracing::warn!(
                 "Request {} already exists for worker {}, but trying to add to worker {}",
                 request_id,
                 *existing_worker_id,
                 worker_id
             );
         }
         self.request_to_workers
             .insert(request_id.clone(), worker_id);
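+        // If the worker has not been seen yet (e.g. discovered late), create
+        // its counter on the fly rather than dropping the request.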
-        if let Some(counter) = self.counters.get(&worker_id) {
-            counter.insert(request_id, new_tokens).await?;
+        let counter = if let Some(counter) = self.counters.get(&worker_id) {
+            counter.clone()
         } else {
             tracing::warn!(
                 "Worker {} does not exist, creating new PrefillCounter",
                 worker_id
             );
-            let new_counter = PrefillCounter::new(self.component.clone());
-            new_counter.insert(request_id, new_tokens).await?;
-            self.counters.insert(worker_id, new_counter);
-        }
+            let new_counter = PrefillCounter::default();
+            self.counters.insert(worker_id, new_counter.clone());
+            new_counter
+        };
+
+        let old_value = counter.insert_direct(request_id.clone(), new_tokens);
+
+        // Publish the event
+        let event = PrefillEvent {
+            request_id,
+            data: if old_value.is_some() {
+                PrefillEventData::UpdatePrefill(new_tokens)
+            } else {
+                PrefillEventData::NewPrefill(new_tokens)
+            },
+            router_id: self.router_id,
+        };
+        self.component.publish(PREFILL_SUBJECT, &event).await?;
 
         Ok(())
     }
@@ -285,7 +294,18 @@
         };
 
         if let Some(counter) = self.counters.get(&worker_id) {
-            counter.remove(request_id).await
+            let removed_tokens = counter.remove_direct(request_id);
+
+            if removed_tokens.is_some() {
+                let event = PrefillEvent {
+                    request_id: request_id.to_string(),
+                    data: PrefillEventData::CompletePrefill,
+                    router_id: self.router_id,
+                };
+                self.component.publish(PREFILL_SUBJECT, &event).await?;
+            }
+
+            Ok(removed_tokens)
         } else {
             tracing::warn!(
                 "Worker {} not found in counters for request {}",
                 worker_id,
                 request_id
             );
             Ok(None)
         }
     }
 
     /// Get the running sums for all workers as a HashMap
     pub async fn running_sums(&self) -> HashMap<i64, usize> {
-        let futures = FuturesUnordered::new();
-
-        for entry in self.counters.iter() {
-            let worker_id = *entry.key();
-            let counter = entry.value().clone();
-            futures.push(async move { (worker_id, counter.running_sum().await) });
-        }
+        self.counters
+            .iter()
+            .map(|entry| (*entry.key(), entry.value().running_sum()))
+            .collect()
+    }
 
-        futures.collect::<HashMap<_, _>>().await
+    /// Get a specific counter's running sum
+    pub async fn get_worker_sum(&self, worker_id: i64) -> Option<usize> {
+        self.counters.get(&worker_id).map(|c| c.running_sum())
     }
 }
 
@@ -314,12 +334,11 @@
 mod integration_tests {
     use super::*;
     use dynamo_runtime::{DistributedRuntime, Runtime};
-    use std::collections::HashMap;
     use tokio::time::Duration;
 
     #[tokio::test]
     #[ignore]
-    async fn test_prefill_counter_synchronization() -> Result<()> {
+    async fn test_prefill_counter_multiworker_synchronization() -> Result<()> {
         // Initialize logging
         dynamo_runtime::logging::init();
 
         // Create runtime and distributed runtime
         let runtime = Runtime::from_current()?;
         let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
 
-        // Create namespace and a single component
-        let namespace = distributed.namespace("test_prefill_counter")?;
+        // Create namespace and components
+        let namespace = distributed.namespace("test_prefill_multiworker")?;
         let component = namespace
-            .component("shared_counter")?
+            .component("counters")?
.service_builder() .create() .await?; - // Create two PrefillCounter instances using the same component (cloned) - let counter1 = PrefillCounter::new(component.clone()); - let counter2 = PrefillCounter::new(component.clone()); + // Create two PrefillCountersMultiWorker instances + let multi_worker1 = PrefillCountersMultiWorker::new(component.clone()); + let multi_worker2 = PrefillCountersMultiWorker::new(component.clone()); // Give some time for subscribers to initialize tokio::time::sleep(Duration::from_millis(2000)).await; - // Track all request_ids and their token counts for verification - let mut expected_tokens = HashMap::new(); + let worker_id_1 = 1; + let worker_id_2 = 2; let tokens_per_request = 100; - let requests_per_counter = 50; + let requests_per_worker = 10; - // Send 50 requests to counter1 - for i in 0..requests_per_counter { - let request_id = format!("counter1_request_{}", i); - counter1 - .insert(request_id.clone(), tokens_per_request) + // Send requests to multi_worker1's worker + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1 + .add_prefill(worker_id_1, request_id, tokens_per_request) .await?; - expected_tokens.insert(request_id, tokens_per_request); } - // Send 50 requests to counter2 - for i in 0..requests_per_counter { - let request_id = format!("counter2_request_{}", i); - counter2 - .insert(request_id.clone(), tokens_per_request) + // Send requests to multi_worker2's worker + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2 + .add_prefill(worker_id_2, request_id, tokens_per_request) .await?; - expected_tokens.insert(request_id, tokens_per_request); } // Wait for synchronization - tokio::time::sleep(Duration::from_millis(10)).await; + tokio::time::sleep(Duration::from_millis(100)).await; - // Verify both counters have the same running sum - let expected_sum = (requests_per_counter * 2) * tokens_per_request; - let sum1 = counter1.running_sum().await; - let sum2 = counter2.running_sum().await; + // Verify both multi-workers see all requests + let sums1 = multi_worker1.running_sums().await; + let sums2 = multi_worker2.running_sums().await; + // Each multi-worker should see both workers assert_eq!( - sum1, expected_sum, - "Counter1 running sum mismatch. Expected: {}, Got: {}", - expected_sum, sum1 + sums1.get(&worker_id_1), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker1 should see worker 1's requests" ); assert_eq!( - sum2, expected_sum, - "Counter2 running sum mismatch. 
Expected: {}, Got: {}", - expected_sum, sum2 + sums1.get(&worker_id_2), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker1 should see worker 2's requests" ); - - // Verify both counters have all 100 requests - let len1 = counter1.len().await; - let len2 = counter2.len().await; assert_eq!( - len1, - requests_per_counter * 2, - "Counter1 should have {} requests", - requests_per_counter * 2 + sums2.get(&worker_id_1), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker2 should see worker 1's requests" ); assert_eq!( - len2, - requests_per_counter * 2, - "Counter2 should have {} requests", - requests_per_counter * 2 + sums2.get(&worker_id_2), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker2 should see worker 2's requests" ); - // Spot check some individual requests on both counters - for i in 0..5 { - let request_id = format!("counter1_request_{}", i); - let tokens1 = counter1.get(&request_id).await; - let tokens2 = counter2.get(&request_id).await; - assert_eq!( - tokens1, - Some(tokens_per_request), - "Counter1 missing request {}", - request_id - ); - assert_eq!( - tokens2, - Some(tokens_per_request), - "Counter2 missing request {}", - request_id - ); + // Remove all requests from multi_worker1 + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1.remove_prefill(&request_id).await?; } - // Now remove all requests from both counters - for i in 0..requests_per_counter { - let request_id = format!("counter1_request_{}", i); - counter1.remove(&request_id).await?; - } - - for i in 0..requests_per_counter { - let request_id = format!("counter2_request_{}", i); - counter2.remove(&request_id).await?; + // Remove all requests from multi_worker2 + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2.remove_prefill(&request_id).await?; } // Wait for removal synchronization - tokio::time::sleep(Duration::from_millis(10)).await; + tokio::time::sleep(Duration::from_millis(100)).await; + + // Verify both multi-workers show zero sums + let final_sums1 = multi_worker1.running_sums().await; + let final_sums2 = multi_worker2.running_sums().await; - // Verify both counters have zero running sum - let final_sum1 = counter1.running_sum().await; - let final_sum2 = counter2.running_sum().await; assert_eq!( - final_sum1, 0, - "Counter1 should have zero running sum after removal" + final_sums1.get(&worker_id_1).copied().unwrap_or(0), + 0, + "MultiWorker1 should show zero for worker 1" ); assert_eq!( - final_sum2, 0, - "Counter2 should have zero running sum after removal" + final_sums1.get(&worker_id_2).copied().unwrap_or(0), + 0, + "MultiWorker1 should show zero for worker 2" + ); + assert_eq!( + final_sums2.get(&worker_id_1).copied().unwrap_or(0), + 0, + "MultiWorker2 should show zero for worker 1" + ); + assert_eq!( + final_sums2.get(&worker_id_2).copied().unwrap_or(0), + 0, + "MultiWorker2 should show zero for worker 2" ); - - // Verify both counters are empty - assert!(counter1.is_empty().await, "Counter1 should be empty"); - assert!(counter2.is_empty().await, "Counter2 should be empty"); // Shutdown runtime runtime.shutdown(); From a9eb5b32e3afd127809880568d8b8b025ae0454e Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Mon, 4 Aug 2025 00:08:57 -0700 Subject: [PATCH 08/27] refactor --- lib/llm/src/kv_router/prefill_counter.rs | 213 ++++++++++++----------- lib/llm/src/kv_router/protocols.rs | 1 + 2 files changed, 108 insertions(+), 106 deletions(-) diff --git 
a/lib/llm/src/kv_router/prefill_counter.rs b/lib/llm/src/kv_router/prefill_counter.rs index 10e3f6b9b1..ceca42d9d4 100644 --- a/lib/llm/src/kv_router/prefill_counter.rs +++ b/lib/llm/src/kv_router/prefill_counter.rs @@ -81,6 +81,7 @@ impl PrefillCounter { self.state.remove(request_id) } + #[allow(dead_code)] fn update_direct(&self, request_id: String, new_tokens: usize) { if let Some(old_tokens_ref) = self.state.tokens_map.get(&request_id) { let old_tokens = *old_tokens_ref; @@ -123,6 +124,72 @@ pub struct PrefillCountersMultiWorker { } impl PrefillCountersMultiWorker { + // Helper function to handle new prefill logic + fn handle_new_prefill( + counters: &Arc>, + request_to_workers: &Arc>, + request_id: &str, + worker_id: i64, + tokens: usize, + ) { + // Check if request already exists + if let Some(existing_worker_id) = request_to_workers.get(request_id) { + tracing::warn!( + "Request {} already exists for worker {}, but trying to add to worker {}", + request_id, + *existing_worker_id, + worker_id + ); + } + + // Update mapping + request_to_workers.insert(request_id.to_string(), worker_id); + + // Get or create counter and insert + let Some(counter) = counters.get(&worker_id) else { + tracing::warn!( + "Worker {} does not exist, creating new PrefillCounter", + worker_id + ); + let new_counter = PrefillCounter::default(); + new_counter.insert_direct(request_id.to_string(), tokens); + counters.insert(worker_id, new_counter); + return; + }; + + counter.insert_direct(request_id.to_string(), tokens); + } + + // Helper function to handle complete prefill logic + fn handle_complete_prefill( + counters: &Arc>, + request_to_workers: &Arc>, + request_id: &str, + ) -> Option { + // Remove from request_to_workers and get the worker_id + let Some((_, worker_id)) = request_to_workers.remove(request_id) else { + tracing::warn!("Request {} not found in request_to_workers", request_id); + return None; + }; + + // Use the worker_id from request_to_workers + let Some(counter) = counters.get(&worker_id) else { + tracing::warn!( + "No counter found for worker {} for request {}", + worker_id, + request_id + ); + return None; + }; + + let removed_tokens = counter.remove_direct(request_id); + if removed_tokens.is_none() { + tracing::warn!("Attempted to remove non-existent request: {}", request_id); + } + + removed_tokens + } + pub fn new(component: Component) -> Self { let counters = Arc::new(DashMap::new()); let request_to_workers = Arc::new(DashMap::new()); @@ -181,59 +248,24 @@ impl PrefillCountersMultiWorker { match event.data { PrefillEventData::NewPrefill(tokens) => { - let Some(worker_id_ref) = request_to_workers.get(&event.request_id) else { - continue; - }; - - let worker_id = *worker_id_ref; - let Some(counter) = counters.get(&worker_id) else { - tracing::warn!( - "No counter found for worker {} when handling NewPrefill for request {}", - worker_id, - event.request_id - ); - continue; - }; - - counter.insert_direct(event.request_id.clone(), tokens); + Self::handle_new_prefill( + &counters, + &request_to_workers, + &event.request_id, + event.worker_id, + tokens, + ); } - PrefillEventData::UpdatePrefill(new_tokens) => { - let Some(worker_id_ref) = request_to_workers.get(&event.request_id) else { - continue; - }; - - let worker_id = *worker_id_ref; - let Some(counter) = counters.get(&worker_id) else { - tracing::warn!( - "No counter found for worker {} when handling UpdatePrefill for request {}", - worker_id, - event.request_id - ); - continue; - }; - - counter.update_direct(event.request_id.clone(), 
new_tokens); + PrefillEventData::UpdatePrefill(_) => { + // Do nothing for now + continue; } PrefillEventData::CompletePrefill => { - let Some((_, worker_id)) = request_to_workers.remove(&event.request_id) else { - continue; - }; - - let Some(counter) = counters.get(&worker_id) else { - tracing::warn!( - "No counter found for worker {} when handling CompletePrefill for request {}", - worker_id, - event.request_id - ); - continue; - }; - - if counter.remove_direct(&event.request_id).is_none() { - tracing::warn!( - "Attempted to remove non-existent request: {}", - event.request_id - ); - } + Self::handle_complete_prefill( + &counters, + &request_to_workers, + &event.request_id, + ); } } } @@ -247,73 +279,42 @@ impl PrefillCountersMultiWorker { request_id: String, new_tokens: usize, ) -> Result<()> { - if let Some(existing_worker_id) = self.request_to_workers.get(&request_id) { - tracing::warn!( - "Request {} already exists for worker {}, but trying to add to worker {}", - request_id, - *existing_worker_id, - worker_id - ); - } - self.request_to_workers - .insert(request_id.clone(), worker_id); - - let counter = if let Some(counter) = self.counters.get(&worker_id) { - counter.clone() - } else { - tracing::warn!( - "Worker {} does not exist, creating new PrefillCounter", - worker_id - ); - let new_counter = PrefillCounter::default(); - self.counters.insert(worker_id, new_counter.clone()); - new_counter - }; - - let old_value = counter.insert_direct(request_id.clone(), new_tokens); - - // Publish the event let event = PrefillEvent { - request_id, - data: if old_value.is_some() { - PrefillEventData::UpdatePrefill(new_tokens) - } else { - PrefillEventData::NewPrefill(new_tokens) - }, + request_id: request_id.clone(), + worker_id, + data: PrefillEventData::NewPrefill(new_tokens), router_id: self.router_id, }; self.component.publish(PREFILL_SUBJECT, &event).await?; + // Use the helper function + Self::handle_new_prefill( + &self.counters, + &self.request_to_workers, + &request_id, + worker_id, + new_tokens, + ); + Ok(()) } pub async fn remove_prefill(&self, request_id: &str) -> Result> { - let Some((_request_id, worker_id)) = self.request_to_workers.remove(request_id) else { - tracing::warn!("Request {} not found", request_id); - return Ok(None); + // Send the event first with dummy worker_id + let event = PrefillEvent { + request_id: request_id.to_string(), + worker_id: 0, // Dummy worker_id + data: PrefillEventData::CompletePrefill, + router_id: self.router_id, }; + self.component.publish(PREFILL_SUBJECT, &event).await?; - if let Some(counter) = self.counters.get(&worker_id) { - let removed_tokens = counter.remove_direct(request_id); - - if removed_tokens.is_some() { - let event = PrefillEvent { - request_id: request_id.to_string(), - data: PrefillEventData::CompletePrefill, - router_id: self.router_id, - }; - self.component.publish(PREFILL_SUBJECT, &event).await?; - } - - Ok(removed_tokens) - } else { - tracing::warn!( - "Worker {} not found in counters for request {}", - worker_id, - request_id - ); - Ok(None) - } + // Use the helper function + Ok(Self::handle_complete_prefill( + &self.counters, + &self.request_to_workers, + request_id, + )) } /// Get the running sums for all workers as a HashMap diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 09a593a9e9..28170474ee 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -120,6 +120,7 @@ impl From for ExternalSequenceBlockHash { #[derive(Serialize, Deserialize, 
Debug, Clone)]
 pub struct PrefillEvent {
     pub request_id: String,
+    pub worker_id: i64,
     pub data: PrefillEventData,
     pub router_id: Uuid,
 }

From 32d2739919e43fed653e5a67a3f92288b55e8a58 Mon Sep 17 00:00:00 2001
From: PeaBrane
Date: Mon, 4 Aug 2025 01:03:20 -0700
Subject: [PATCH 09/27] don't bother keeping the 1 as decode token

---
 lib/llm/src/kv_router/sequence.rs | 20 ++------------------
 1 file changed, 2 insertions(+), 18 deletions(-)

diff --git a/lib/llm/src/kv_router/sequence.rs b/lib/llm/src/kv_router/sequence.rs
index 060eaf1ed3..0fb822659a 100644
--- a/lib/llm/src/kv_router/sequence.rs
+++ b/lib/llm/src/kv_router/sequence.rs
@@ -1,17 +1,5 @@
 // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
 
 //! KV Cache Sequence Management for LLM Inference
 //!
@@ -134,7 +122,7 @@ impl ActiveSequences {
         if let Some(tokens) = self.prefill_tokens.remove(request_id) {
             self.active_tokens = self
                 .active_tokens
-                .checked_sub(tokens.saturating_sub(1)) // Keep 1 token for decoding
+                .checked_sub(tokens)
                 .expect("active_tokens underflow");
         }
     }
@@ -171,11 +159,7 @@ impl ActiveSequences {
     /// Free all blocks associated with a request
     pub fn free(&mut self, request_id: &RequestId) -> usize {
-        // decoding has one active token
-        self.active_tokens = self
-            .active_tokens
-            .checked_sub(self.prefill_tokens.remove(request_id).unwrap_or(1))
-            .expect("active_tokens < 0");
+        self.mark_prefill_completed(request_id);
 
         let Some(token_seq) = self.active_seqs.get(request_id) else {
             tracing::warn!("Trying to free non-existent request {request_id}");

From f0f5ec3fffe69e2fce40250b9239b7cf88d0c3d4 Mon Sep 17 00:00:00 2001
From: PeaBrane
Date: Mon, 4 Aug 2025 08:45:20 -0700
Subject: [PATCH 10/27] lose the dashmap in the inner counter

---
 lib/llm/src/kv_router/prefill_counter.rs | 285 +++++++++++++++--------
 1 file changed, 185 insertions(+), 100 deletions(-)

diff --git a/lib/llm/src/kv_router/prefill_counter.rs b/lib/llm/src/kv_router/prefill_counter.rs
index ceca42d9d4..5052faf752 100644
--- a/lib/llm/src/kv_router/prefill_counter.rs
+++ b/lib/llm/src/kv_router/prefill_counter.rs
@@ -5,10 +5,8 @@
 use anyhow::Result;
 use dynamo_runtime::component::Component;
 use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
 use futures::StreamExt;
-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 use uuid::Uuid;
-// Remove the Mutex import since we're using DashMap
 
 use super::protocols::{PrefillEvent, PrefillEventData};
 use crate::kv_router::PREFILL_SUBJECT;
@@ -29,36 +27,38 @@ where
 
 #[derive(Default)]
 struct PrefillCounterState {
-    tokens_map: DashMap<String, usize>,
-    running_sum: AtomicUsize,
+    tokens_map: HashMap<String, usize>,
+    running_sum: usize,
 }
 
 impl PrefillCounterState {
-    fn insert(&self, key: String, value: usize) -> Option<usize> {
+    fn insert(&mut self, key: String, value: usize) -> Option<usize> {
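+        // HashMap::insert returns the previous value for the key, which is
+        // what lets us keep `running_sum` in lockstep with `tokens_map` below.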
         let old_value = self.tokens_map.insert(key, value);
 
         if let Some(old) = old_value {
-            self.running_sum.fetch_sub(old, Ordering::SeqCst);
-            self.running_sum.fetch_add(value, Ordering::SeqCst);
+            self.running_sum -= old;
+            self.running_sum += value;
         } else {
-            self.running_sum.fetch_add(value, Ordering::SeqCst);
+            self.running_sum += value;
         }
 
         old_value
     }
 
-    fn remove(&self, key: &str) -> Option<usize> {
-        let removed = self.tokens_map.remove(key).map(|(_, v)| v);
+    fn remove(&mut self, key: &str) -> Option<usize> {
+        let removed = self.tokens_map.remove(key);
 
         if let Some(value) = removed {
-            self.running_sum.fetch_sub(value, Ordering::SeqCst);
+            self.running_sum -= value;
         }
 
         removed
     }
 
     fn running_sum(&self) -> usize {
-        self.running_sum.load(Ordering::SeqCst)
+        self.running_sum
     }
 }
 
@@ -66,35 +66,35 @@
 ///
 /// This struct maintains a local hashmap of request_id to token count,
 /// and a running sum of all tokens. It no longer handles its own subscriptions.
-#[derive(Clone, Default)]
+#[derive(Default)]
 pub struct PrefillCounter {
-    state: Arc<PrefillCounterState>,
+    state: PrefillCounterState,
 }
 
 impl PrefillCounter {
     // Internal methods for direct state manipulation (no publishing)
-    fn insert_direct(&self, request_id: String, tokens: usize) -> Option<usize> {
+    fn insert_direct(&mut self, request_id: String, tokens: usize) -> Option<usize> {
         self.state.insert(request_id, tokens)
     }
 
-    fn remove_direct(&self, request_id: &str) -> Option<usize> {
+    fn remove_direct(&mut self, request_id: &str) -> Option<usize> {
         self.state.remove(request_id)
     }
 
     #[allow(dead_code)]
-    fn update_direct(&self, request_id: String, new_tokens: usize) {
-        if let Some(old_tokens_ref) = self.state.tokens_map.get(&request_id) {
-            let old_tokens = *old_tokens_ref;
+    fn update_direct(&mut self, request_id: String, new_tokens: usize) {
+        if let Some(old_tokens) = self.state.tokens_map.get(&request_id).copied() {
             let delta = new_tokens as isize - old_tokens as isize;
-            self.state
-                .running_sum
-                .fetch_add(delta as usize, Ordering::SeqCst);
+            self.state.running_sum = (self.state.running_sum as isize + delta) as usize;
             self.state.tokens_map.insert(request_id, new_tokens);
         }
     }
 
     pub fn get(&self, request_id: &str) -> Option<usize> {
-        self.state.tokens_map.get(request_id).map(|entry| *entry)
+        self.state.tokens_map.get(request_id).copied()
     }
 
     pub fn running_sum(&self) -> usize {
@@ -108,11 +108,6 @@
     pub fn is_empty(&self) -> bool {
         self.state.tokens_map.is_empty()
     }
-
-    /// Returns a snapshot of the current state as a HashMap
-    pub fn snapshot(&self) -> HashMap<String, usize> {
-        get_snapshot(&self.state.tokens_map)
-    }
 }
 
@@ -145,19 +140,18 @@
         // Update mapping
         request_to_workers.insert(request_id.to_string(), worker_id);
 
-        // Get or create counter and insert
-        let Some(counter) = counters.get(&worker_id) else {
+        // Get or create counter and insert using get_mut
+        if let Some(mut counter) = counters.get_mut(&worker_id) {
+            counter.insert_direct(request_id.to_string(), tokens);
+        } else {
             tracing::warn!(
                 "Worker {} does not exist, creating new PrefillCounter",
                 worker_id
             );
-            let new_counter = PrefillCounter::default();
+            let mut new_counter = PrefillCounter::default();
             new_counter.insert_direct(request_id.to_string(), tokens);
             counters.insert(worker_id, new_counter);
-            return;
         };
-
-        counter.insert_direct(request_id.to_string(), tokens);
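+        // NOTE: get_mut-then-insert is not atomic across the two map calls;
+        // DashMap's entry() API would close the small window in which two
+        // callers could race to create the same worker's counter.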
} // Helper function to handle complete prefill logic @@ -172,8 +166,8 @@ impl PrefillCountersMultiWorker { return None; }; - // Use the worker_id from request_to_workers - let Some(counter) = counters.get(&worker_id) else { + // Use the worker_id from request_to_workers with get_mut + let Some(mut counter) = counters.get_mut(&worker_id) else { tracing::warn!( "No counter found for worker {} for request {}", worker_id, @@ -335,62 +329,174 @@ impl PrefillCountersMultiWorker { mod integration_tests { use super::*; use dynamo_runtime::{DistributedRuntime, Runtime}; + use std::sync::{Arc, Mutex}; + use std::thread; use tokio::time::Duration; - #[tokio::test] + #[test] #[ignore] - async fn test_prefill_counter_multiworker_synchronization() -> Result<()> { - // Initialize logging + fn test_prefill_counter_multiworker_synchronization() -> Result<()> { + // Initialize logging once dynamo_runtime::logging::init(); - // Create runtime and distributed runtime - let runtime = Runtime::from_current()?; - let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; - - // Create namespace and components - let namespace = distributed.namespace("test_prefill_multiworker")?; - let component = namespace - .component("counters")? - .service_builder() - .create() - .await?; - - // Create two PrefillCountersMultiWorker instances - let multi_worker1 = PrefillCountersMultiWorker::new(component.clone()); - let multi_worker2 = PrefillCountersMultiWorker::new(component.clone()); - - // Give some time for subscribers to initialize - tokio::time::sleep(Duration::from_millis(2000)).await; - let worker_id_1 = 1; let worker_id_2 = 2; let tokens_per_request = 100; let requests_per_worker = 10; - // Send requests to multi_worker1's worker - for i in 0..requests_per_worker { - let request_id = format!("mw1_request_{}", i); - multi_worker1 - .add_prefill(worker_id_1, request_id, tokens_per_request) - .await?; - } + // Shared state for collecting results from both threads + let results1 = Arc::new(Mutex::new(None)); + let results2 = Arc::new(Mutex::new(None)); + let final_results1 = Arc::new(Mutex::new(None)); + let final_results2 = Arc::new(Mutex::new(None)); + + let results1_clone = results1.clone(); + let results2_clone = results2.clone(); + let final_results1_clone = final_results1.clone(); + let final_results2_clone = final_results2.clone(); + + // Thread 1: First distributed runtime with multi_worker1 + let handle1 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and components with same names + let namespace = distributed.namespace("test_prefill_multiworker")?; + let component = namespace + .component("counters")? 
+ .service_builder() + .create() + .await?; + + // Create first PrefillCountersMultiWorker instance + let multi_worker1 = PrefillCountersMultiWorker::new(component); + + // Give some time for subscribers to initialize + tokio::time::sleep(Duration::from_millis(3000)).await; + + // Send requests to multi_worker1's worker + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1 + .add_prefill(worker_id_1, request_id, tokens_per_request) + .await?; + } - // Send requests to multi_worker2's worker - for i in 0..requests_per_worker { - let request_id = format!("mw2_request_{}", i); - multi_worker2 - .add_prefill(worker_id_2, request_id, tokens_per_request) - .await?; - } + // Wait for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; - // Wait for synchronization - tokio::time::sleep(Duration::from_millis(100)).await; + // Get running sums after additions + let sums1 = multi_worker1.running_sums().await; + *results1_clone.lock().unwrap() = Some(sums1); - // Verify both multi-workers see all requests - let sums1 = multi_worker1.running_sums().await; - let sums2 = multi_worker2.running_sums().await; + // Wait for other thread to add its requests + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Remove all requests from multi_worker1 + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1.remove_prefill(&request_id).await?; + } + + // Wait for removal synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get final running sums + let final_sums1 = multi_worker1.running_sums().await; + *final_results1_clone.lock().unwrap() = Some(final_sums1); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Thread 2: Second distributed runtime with multi_worker2 + let handle2 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and components with same names + let namespace = distributed.namespace("test_prefill_multiworker")?; + let component = namespace + .component("counters")? 
+ .service_builder() + .create() + .await?; + + // Create second PrefillCountersMultiWorker instance + let multi_worker2 = PrefillCountersMultiWorker::new(component); + + // Give some time for subscribers to initialize + tokio::time::sleep(Duration::from_millis(3000)).await; + + // Wait a bit to ensure multi_worker1 has started + tokio::time::sleep(Duration::from_millis(500)).await; + + // Send requests to multi_worker2's worker + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2 + .add_prefill(worker_id_2, request_id, tokens_per_request) + .await?; + } + + // Wait for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get running sums after additions + let sums2 = multi_worker2.running_sums().await; + *results2_clone.lock().unwrap() = Some(sums2); + + // Wait for other thread to remove its requests + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Remove all requests from multi_worker2 + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2.remove_prefill(&request_id).await?; + } + + // Wait for removal synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get final running sums + let final_sums2 = multi_worker2.running_sums().await; + *final_results2_clone.lock().unwrap() = Some(final_sums2); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Wait for both threads to complete + handle1.join().unwrap()?; + handle2.join().unwrap()?; + + // Extract results + let sums1 = results1.lock().unwrap().take().unwrap(); + let sums2 = results2.lock().unwrap().take().unwrap(); + let final_sums1 = final_results1.lock().unwrap().take().unwrap(); + let final_sums2 = final_results2.lock().unwrap().take().unwrap(); - // Each multi-worker should see both workers + // Verify both multi-workers see all requests assert_eq!( sums1.get(&worker_id_1), Some(&(requests_per_worker * tokens_per_request)), @@ -412,25 +518,7 @@ mod integration_tests { "MultiWorker2 should see worker 2's requests" ); - // Remove all requests from multi_worker1 - for i in 0..requests_per_worker { - let request_id = format!("mw1_request_{}", i); - multi_worker1.remove_prefill(&request_id).await?; - } - - // Remove all requests from multi_worker2 - for i in 0..requests_per_worker { - let request_id = format!("mw2_request_{}", i); - multi_worker2.remove_prefill(&request_id).await?; - } - - // Wait for removal synchronization - tokio::time::sleep(Duration::from_millis(100)).await; - - // Verify both multi-workers show zero sums - let final_sums1 = multi_worker1.running_sums().await; - let final_sums2 = multi_worker2.running_sums().await; - + // Verify both multi-workers show zero sums after removal assert_eq!( final_sums1.get(&worker_id_1).copied().unwrap_or(0), 0, @@ -452,9 +540,6 @@ mod integration_tests { "MultiWorker2 should show zero for worker 2" ); - // Shutdown runtime - runtime.shutdown(); - Ok(()) } } From fad078d5bc5b1d50f000b0448536e23eea08f6f5 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Mon, 4 Aug 2025 09:49:59 -0700 Subject: [PATCH 11/27] distributed active sequence --- lib/llm/src/kv_router.rs | 9 +- lib/llm/src/kv_router/protocols.rs | 19 + lib/llm/src/kv_router/sequence.rs | 666 +++++++++++++++++++++++------ 3 files changed, 560 insertions(+), 134 deletions(-) diff --git a/lib/llm/src/kv_router.rs 
b/lib/llm/src/kv_router.rs index 2e58f83cb6..71bb8397d4 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -48,11 +48,18 @@ use dynamo_runtime::traits::events::EventSubscriber; // [gluo TODO] shouldn't need to be public // this should be discovered from the component + +// for metric scraping (pull-based) +pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; + +// for metric publishing (push-based) pub const KV_EVENT_SUBJECT: &str = "kv_events"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; -pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; pub const KV_METRICS_SUBJECT: &str = "kv_metrics"; + +// for inter-router comms pub const PREFILL_SUBJECT: &str = "prefill_events"; +pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; /// A trait that users can implement to define custom selection logic pub trait WorkerSelector { diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 28170474ee..24f33bd0d1 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -136,6 +136,25 @@ pub enum PrefillEventData { CompletePrefill, } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActiveSequenceEvent { + pub request_id: String, + pub worker_id: i64, + pub data: ActiveSequenceEventData, + pub router_id: Uuid, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum ActiveSequenceEventData { + AddRequest { + token_sequence: Vec, + isl: usize, + overlap: u32, + }, + Free, + MarkPrefillCompleted, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ActiveBlockEvent { pub request_id: String, diff --git a/lib/llm/src/kv_router/sequence.rs b/lib/llm/src/kv_router/sequence.rs index 0fb822659a..7cd67b1615 100644 --- a/lib/llm/src/kv_router/sequence.rs +++ b/lib/llm/src/kv_router/sequence.rs @@ -25,11 +25,18 @@ use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::WorkerId; use crate::tokens::SequenceHash; +use anyhow::Result; +use dashmap::DashMap; use derive_getters::Getters; +use dynamo_runtime::component::Component; +use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; +use futures::StreamExt; use std::collections::{HashMap, HashSet}; -use std::sync::{mpsc, Arc}; -use std::thread; -use std::time::Duration; +use std::sync::Arc; +use uuid::Uuid; + +use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData}; +use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT; // TODO: use the common request_id if it exists in the repo pub type RequestId = String; @@ -191,41 +198,45 @@ enum UpdateSequences { }, NewBlocks { token_sequence: Arc>, - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, PotentialBlocks { token_sequence: Arc>, - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, PotentialBlocksAndTokens { token_sequence: Arc>, isl: usize, overlap: u32, - resp_tx: mpsc::SyncSender<(usize, usize)>, + resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>, }, ActiveBlocks { - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, ActiveTokens { - resp_tx: mpsc::SyncSender, + resp_tx: tokio::sync::oneshot::Sender, }, Shutdown, } /// Multi-worker extension of ActiveSequences that distributes requests across multiple threads pub struct ActiveSequencesMultiWorker { - senders: HashMap>, - request_to_worker: HashMap, - handles: HashMap>, + senders: Arc>>, + request_to_worker: Arc>, + handles: Arc>>, block_size: usize, + component: Component, + router_id: Uuid, } impl 
ActiveSequencesMultiWorker { - pub fn new(block_size: usize, worker_ids: Vec) -> Self { + pub fn new(component: Component, block_size: usize, worker_ids: Vec) -> Self { assert!(block_size > 1, "block_size must be greater than 1"); - let mut senders = HashMap::new(); - let mut handles = HashMap::new(); + let senders = Arc::new(DashMap::new()); + let handles = Arc::new(DashMap::new()); + let request_to_worker = Arc::new(DashMap::new()); + let router_id = Uuid::new_v4(); for worker_id in worker_ids { let (sender, handle) = Self::start_worker(block_size); @@ -233,22 +244,50 @@ impl ActiveSequencesMultiWorker { handles.insert(worker_id, handle); } - Self { - senders, - request_to_worker: HashMap::new(), + let multi_worker = Self { + senders: senders.clone(), + request_to_worker: request_to_worker.clone(), handles, block_size, - } + component: component.clone(), + router_id, + }; + + // Start the subscription loop + let senders_clone = senders.clone(); + let request_to_worker_clone = request_to_worker.clone(); + let component_clone = component.clone(); + let router_id_clone = router_id; + + tokio::spawn(async move { + if let Err(e) = Self::subscribe_to_events( + senders_clone, + request_to_worker_clone, + component_clone, + router_id_clone, + ) + .await + { + tracing::error!("Error in active sequences events subscription: {}", e); + } + }); + + multi_worker } - /// Helper method to start a worker thread - fn start_worker(block_size: usize) -> (mpsc::Sender, thread::JoinHandle<()>) { - let (request_tx, request_rx) = mpsc::channel::(); + /// Helper method to start a worker task + fn start_worker( + block_size: usize, + ) -> ( + tokio::sync::mpsc::UnboundedSender, + tokio::task::JoinHandle<()>, + ) { + let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel(); - let handle = thread::spawn(move || { + let handle = tokio::spawn(async move { let mut active_sequences = ActiveSequences::new(block_size); - while let Ok(command) = request_rx.recv() { + while let Some(command) = request_rx.recv().await { match command { UpdateSequences::AddRequest { request_id, @@ -309,9 +348,81 @@ impl ActiveSequencesMultiWorker { (request_tx, handle) } + /// Background task to subscribe to active sequence events and update all workers + async fn subscribe_to_events( + senders: Arc>>, + request_to_worker: Arc>, + component: Component, + router_id: Uuid, + ) -> Result<()> { + let mut subscriber = component + .subscribe_with_type::(ACTIVE_SEQUENCES_SUBJECT) + .await?; + + while let Some(result) = subscriber.next().await { + let Ok(event) = result else { + tracing::error!( + "Error receiving active sequence event: {}", + result.unwrap_err() + ); + continue; + }; + + // Skip events emitted by itself + if event.router_id == router_id { + continue; + } + + match &event.data { + ActiveSequenceEventData::AddRequest { + token_sequence, + isl, + overlap, + } => { + request_to_worker.insert(event.request_id.clone(), event.worker_id); + + if let Some(sender) = senders.get(&event.worker_id) { + let _ = sender.send(UpdateSequences::AddRequest { + request_id: event.request_id.clone(), + token_sequence: token_sequence.clone(), + isl: *isl, + overlap: *overlap, + }); + } else { + tracing::warn!( + "Worker {} not found, cannot process AddRequest", + event.worker_id + ); + } + } + ActiveSequenceEventData::Free => { + if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id) { + if let Some(sender) = senders.get(&worker_id) { + let _ = sender.send(UpdateSequences::Free { + request_id: event.request_id.clone(), 
+ }); + } + } + } + ActiveSequenceEventData::MarkPrefillCompleted => { + if let Some(worker_id) = request_to_worker.get(&event.request_id) { + if let Some(sender) = senders.get(&*worker_id) { + let _ = sender.send(UpdateSequences::MarkPrefillCompleted { + request_id: event.request_id.clone(), + }); + } + } + } + } + } + + Ok(()) + } + /// Update the set of workers, adding and removing as needed - pub fn update_workers(&mut self, new_worker_ids: Vec) -> HashMap { - let current_workers: HashSet = self.senders.keys().copied().collect(); + pub fn update_workers(&self, new_worker_ids: Vec) { + let current_workers: HashSet = + self.senders.iter().map(|entry| *entry.key()).collect(); let new_workers: HashSet = new_worker_ids.into_iter().collect(); let workers_to_remove: Vec = @@ -324,11 +435,11 @@ impl ActiveSequencesMultiWorker { tracing::warn!("Removing worker {}", worker_id); // Send shutdown command to the worker - if let Some(sender) = self.senders.remove(worker_id) { + if let Some((_, sender)) = self.senders.remove(worker_id) { let _ = sender.send(UpdateSequences::Shutdown); } - if let Some(handle) = self.handles.remove(worker_id) { - let _ = handle.join(); + if let Some((_, handle)) = self.handles.remove(worker_id) { + handle.abort(); } } @@ -340,64 +451,115 @@ impl ActiveSequencesMultiWorker { self.senders.insert(*worker_id, sender); self.handles.insert(*worker_id, handle); } - - // Return active blocks for all workers - self.active_blocks() } - pub fn add_request( - &mut self, + pub async fn add_request( + &self, request_id: RequestId, token_sequence: Vec, isl: usize, overlap: u32, worker_id: WorkerId, - ) { + ) -> Result<()> { if !self.senders.contains_key(&worker_id) { - panic!("Worker ID {worker_id} not found"); + return Err(anyhow::anyhow!("Worker ID {worker_id} not found")); } + // Publish event + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::AddRequest { + token_sequence: token_sequence.clone(), + isl, + overlap, + }, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + + // Update local state self.request_to_worker.insert(request_id.clone(), worker_id); - self.senders[&worker_id] + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::AddRequest { request_id, token_sequence, isl, overlap, }) - .expect("Failed to send add_request command to worker"); + .map_err(|_| anyhow::anyhow!("Failed to send add_request command to worker"))?; + + Ok(()) } - pub fn free(&mut self, request_id: &RequestId) { + pub async fn free(&self, request_id: &RequestId) -> Result<()> { let worker_id = self .request_to_worker .get(request_id) - .copied() - .expect("Request ID not found in request_to_worker mapping"); - - self.senders[&worker_id] + .map(|entry| *entry) + .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?; + + // Publish event + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::Free, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + + // Update local state + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::Free { request_id: request_id.clone(), }) - .expect("Failed to send free command to worker"); + .map_err(|_| anyhow::anyhow!("Failed to send free command to worker"))?; self.request_to_worker.remove(request_id); + + Ok(()) } /// Mark prefill as completed for a request - pub fn 
mark_prefill_completed(&mut self, request_id: &RequestId) { + pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> { let worker_id = self .request_to_worker .get(request_id) - .copied() - .expect("Request ID not found in request_to_worker mapping"); - - self.senders[&worker_id] + .map(|entry| *entry) + .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?; + + // Publish event + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::MarkPrefillCompleted, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + + // Update local state + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::MarkPrefillCompleted { request_id: request_id.clone(), }) - .expect("Failed to send mark_prefill_completed command to worker"); + .map_err(|_| { + anyhow::anyhow!("Failed to send mark_prefill_completed command to worker") + })?; + + Ok(()) } /// Get the number of workers @@ -406,37 +568,49 @@ impl ActiveSequencesMultiWorker { } /// Generic method to query all workers with a given command - fn query_workers( + async fn query_workers( &self, token_sequence: Option>, - command_fn: impl Fn(Option>>, mpsc::SyncSender) -> UpdateSequences, - ) -> HashMap { + command_fn: impl Fn( + Option>>, + tokio::sync::oneshot::Sender, + ) -> UpdateSequences, + ) -> HashMap { let mut results = HashMap::new(); let token_sequence_shared = token_sequence.map(Arc::new); let mut receivers = Vec::new(); // Send queries to all workers in parallel - for (worker_id, sender) in &self.senders { - let (resp_tx, resp_rx) = mpsc::sync_channel(0); + for entry in self.senders.iter() { + let worker_id = *entry.key(); + let sender = entry.value(); + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); receivers.push((worker_id, resp_rx)); - sender - .send(command_fn(token_sequence_shared.clone(), resp_tx)) - .expect("Failed to send command to worker"); + if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) { + tracing::error!("Failed to send command to worker {}: {}", worker_id, e); + } } // Collect results from all workers for (worker_id, receiver) in receivers { - let result = receiver - .recv_timeout(Duration::from_secs(1)) - .expect("Failed to receive response from worker"); - results.insert(*worker_id, result); + match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { + Ok(Ok(result)) => { + results.insert(worker_id, result); + } + Ok(Err(_)) => { + tracing::error!("Worker {} dropped response channel", worker_id); + } + Err(_) => { + tracing::error!("Timeout waiting for response from worker {}", worker_id); + } + } } results } /// Query all workers for the number of new blocks that would be added by a token sequence - pub fn new_blocks(&self, token_sequence: Vec) -> HashMap { + pub async fn new_blocks(&self, token_sequence: Vec) -> HashMap { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { Some(ts) => UpdateSequences::NewBlocks { token_sequence: ts, @@ -444,10 +618,14 @@ impl ActiveSequencesMultiWorker { }, None => unreachable!("token_sequence should always be Some for new_blocks"), }) + .await } /// Query all workers for the total number of blocks (new + active) that would be used by a token sequence - pub fn potential_blocks(&self, token_sequence: Vec) -> HashMap { + pub async fn potential_blocks( + &self, + token_sequence: Vec, + ) -> HashMap { 
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { Some(ts) => UpdateSequences::PotentialBlocks { token_sequence: ts, @@ -455,10 +633,11 @@ impl ActiveSequencesMultiWorker { }, None => unreachable!("token_sequence should always be Some for potential_blocks"), }) + .await } /// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap - pub fn potential_blocks_and_tokens( + pub async fn potential_blocks_and_tokens( &self, token_sequence: Vec, isl: usize, @@ -470,53 +649,68 @@ impl ActiveSequencesMultiWorker { let mut receivers = Vec::new(); // Send queries to all workers in parallel - for (worker_id, sender) in &self.senders { - let (resp_tx, resp_rx) = mpsc::sync_channel(0); + for entry in self.senders.iter() { + let worker_id = *entry.key(); + let sender = entry.value(); + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); receivers.push((worker_id, resp_rx)); - sender - .send(UpdateSequences::PotentialBlocksAndTokens { - token_sequence: token_sequence_shared.clone(), - isl, - overlap: overlaps.scores.get(worker_id).copied().unwrap_or(0), - resp_tx, - }) - .expect("Failed to send potential_tokens command to worker"); + if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens { + token_sequence: token_sequence_shared.clone(), + isl, + overlap: overlaps.scores.get(&worker_id).copied().unwrap_or(0), + resp_tx, + }) { + tracing::error!( + "Failed to send potential_tokens command to worker {}: {}", + worker_id, + e + ); + } } // Collect results from all workers for (worker_id, receiver) in receivers { - let (blocks, tokens) = receiver - .recv_timeout(Duration::from_secs(1)) - .expect("Failed to receive response from worker"); - potential_blocks.insert(*worker_id, blocks); - potential_tokens.insert(*worker_id, tokens); + match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { + Ok(Ok((blocks, tokens))) => { + potential_blocks.insert(worker_id, blocks); + potential_tokens.insert(worker_id, tokens); + } + Ok(Err(_)) => { + tracing::error!("Worker {} dropped response channel", worker_id); + } + Err(_) => { + tracing::error!("Timeout waiting for response from worker {}", worker_id); + } + } } (potential_blocks, potential_tokens) } /// Query all workers for their current number of active blocks - pub fn active_blocks(&self) -> HashMap { + pub async fn active_blocks(&self) -> HashMap { self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx }) + .await } /// Query all workers for their current number of active tokens - pub fn active_tokens(&self) -> HashMap { + pub async fn active_tokens(&self) -> HashMap { self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx }) + .await } } impl Drop for ActiveSequencesMultiWorker { fn drop(&mut self) { - // Send shutdown command to all workers - for sender in self.senders.values() { - let _ = sender.send(UpdateSequences::Shutdown); + // Send shutdown to all workers + for entry in self.senders.iter() { + let _ = entry.value().send(UpdateSequences::Shutdown); } - // Wait for all threads to finish - for (_, handle) in self.handles.drain() { - let _ = handle.join(); + // Abort all tasks + for entry in self.handles.iter() { + entry.value().abort(); } } } @@ -524,44 +718,248 @@ impl Drop for ActiveSequencesMultiWorker { #[cfg(test)] mod tests { use super::*; + use dynamo_runtime::{DistributedRuntime, Runtime}; + use std::sync::{Arc, Mutex}; + use std::thread; #[test] - fn test_multi_worker_block_sharing() { - // 
Create multi-worker sequence manager with 3 workers + #[ignore] + fn test_multi_worker_block_sharing() -> Result<()> { + // Initialize logging once + dynamo_runtime::logging::init(); + let block_size = 4; // arbitrary block size - let worker_ids = vec![0, 1, 2]; - let mut seq_manager = ActiveSequencesMultiWorker::new(block_size, worker_ids); - - // Add requests to each worker - // Worker 0: sequence [0, 1, 2] - seq_manager.add_request( - "request_0".to_string(), - vec![0, 1, 2], - 12, // ISL (3 blocks * 4 block_size) - 0, // no overlap - 0, // worker_id - ); - // Worker 1: sequence [3, 4] - seq_manager.add_request( - "request_1".to_string(), - vec![3, 4], - 8, // ISL (2 blocks * 4 block_size) - 0, // no overlap - 1, // worker_id - ); + // Shared state for collecting results from both threads + let active_tokens_after_add = Arc::new(Mutex::new(HashMap::new())); + let potential_blocks_result = Arc::new(Mutex::new(HashMap::new())); + let active_blocks_after_free = Arc::new(Mutex::new(HashMap::new())); + let active_tokens_after_free = Arc::new(Mutex::new(HashMap::new())); + + let active_tokens_after_add_clone = active_tokens_after_add.clone(); + let potential_blocks_result_clone = potential_blocks_result.clone(); + let active_blocks_after_free_clone = active_blocks_after_free.clone(); + let active_tokens_after_free_clone = active_tokens_after_free.clone(); + + // Clone again for the second thread + let active_tokens_after_add_clone2 = active_tokens_after_add.clone(); + let potential_blocks_result_clone2 = potential_blocks_result.clone(); + let active_blocks_after_free_clone2 = active_blocks_after_free.clone(); + let active_tokens_after_free_clone2 = active_tokens_after_free.clone(); + + // Thread 1: First runtime with workers 0 and 1 + let handle1 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and component with same names as thread 2 + let namespace = distributed.namespace("test_multiworker_sequences")?; + let component = namespace + .component("sequences")? 
+ .service_builder() + .create() + .await?; + + // Create multi-worker sequence manager with workers 0 and 1 + let worker_ids = vec![0, 1]; + let seq_manager = + ActiveSequencesMultiWorker::new(component, block_size, worker_ids); + + // Give some time for the subscription loop to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Add requests to workers + // Worker 0: sequence [0, 1, 2] + seq_manager + .add_request( + "request_0".to_string(), + vec![0, 1, 2], + 12, // ISL (3 blocks * 4 block_size) + 0, // no overlap + 0, // worker_id + ) + .await?; + + // Worker 1: sequence [3, 4] + seq_manager + .add_request( + "request_1".to_string(), + vec![3, 4], + 8, // ISL (2 blocks * 4 block_size) + 0, // no overlap + 1, // worker_id + ) + .await?; + + // Give some time for the commands to be processed and synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get active tokens from workers 0 and 1 + let tokens = seq_manager.active_tokens().await; + active_tokens_after_add_clone + .lock() + .unwrap() + .insert(0, tokens.get(&0).copied().unwrap_or(0)); + active_tokens_after_add_clone + .lock() + .unwrap() + .insert(1, tokens.get(&1).copied().unwrap_or(0)); + + // Test potential blocks for sequence [0, 1] + let potential = seq_manager.potential_blocks(vec![0, 1]).await; + potential_blocks_result_clone + .lock() + .unwrap() + .insert(0, potential.get(&0).copied().unwrap_or(0)); + potential_blocks_result_clone + .lock() + .unwrap() + .insert(1, potential.get(&1).copied().unwrap_or(0)); + + // Wait for second thread to process its requests + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Free requests from workers 0 and 1 + seq_manager.free(&"request_0".to_string()).await?; + seq_manager.free(&"request_1".to_string()).await?; + + // Give some time for the commands to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get final active blocks and tokens + let blocks = seq_manager.active_blocks().await; + let tokens = seq_manager.active_tokens().await; + + active_blocks_after_free_clone + .lock() + .unwrap() + .insert(0, blocks.get(&0).copied().unwrap_or(0)); + active_blocks_after_free_clone + .lock() + .unwrap() + .insert(1, blocks.get(&1).copied().unwrap_or(0)); + active_tokens_after_free_clone + .lock() + .unwrap() + .insert(0, tokens.get(&0).copied().unwrap_or(0)); + active_tokens_after_free_clone + .lock() + .unwrap() + .insert(1, tokens.get(&1).copied().unwrap_or(0)); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); - // Worker 2: sequence [0, 1, 2, 3] - seq_manager.add_request( - "request_2".to_string(), - vec![0, 1, 2, 3], - 16, // ISL (4 blocks * 4 block_size) - 0, // no overlap - 2, // worker_id - ); + // Thread 2: Second runtime with worker 2 + let handle2 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and component with same names as thread 1 + let namespace = distributed.namespace("test_multiworker_sequences")?; + let component = namespace + .component("sequences")? 
+ .service_builder() + .create() + .await?; + + // Create multi-worker sequence manager with worker 2 + let worker_ids = vec![2]; + let seq_manager = + ActiveSequencesMultiWorker::new(component, block_size, worker_ids); + + // Give some time for the subscription loop to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Wait a bit to ensure thread 1 has started + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Worker 2: sequence [0, 1, 2, 3] + seq_manager + .add_request( + "request_2".to_string(), + vec![0, 1, 2, 3], + 16, // ISL (4 blocks * 4 block_size) + 0, // no overlap + 2, // worker_id + ) + .await?; + + // Give some time for the commands to be processed and synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get active tokens from worker 2 + let tokens = seq_manager.active_tokens().await; + active_tokens_after_add_clone2 + .lock() + .unwrap() + .insert(2, tokens.get(&2).copied().unwrap_or(0)); + + // Test potential blocks for sequence [0, 1] + let potential = seq_manager.potential_blocks(vec![0, 1]).await; + potential_blocks_result_clone2 + .lock() + .unwrap() + .insert(2, potential.get(&2).copied().unwrap_or(0)); + + // Wait for first thread to free its requests + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Free request from worker 2 + seq_manager.free(&"request_2".to_string()).await?; + + // Give some time for the commands to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get final active blocks and tokens + let blocks = seq_manager.active_blocks().await; + let tokens = seq_manager.active_tokens().await; + + active_blocks_after_free_clone2 + .lock() + .unwrap() + .insert(2, blocks.get(&2).copied().unwrap_or(0)); + active_tokens_after_free_clone2 + .lock() + .unwrap() + .insert(2, tokens.get(&2).copied().unwrap_or(0)); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Wait for both threads to complete + handle1.join().unwrap()?; + handle2.join().unwrap()?; + + // Extract results + let tokens_after_add = active_tokens_after_add.lock().unwrap(); + let potential_blocks = potential_blocks_result.lock().unwrap(); + let blocks_after_free = active_blocks_after_free.lock().unwrap(); + let tokens_after_free = active_tokens_after_free.lock().unwrap(); // Verify active tokens after adding requests - let tokens_after_add = seq_manager.active_tokens(); assert_eq!( tokens_after_add[&0], 12, "Worker 0 should have 12 active tokens" @@ -576,8 +974,6 @@ mod tests { ); // Test potential blocks for sequence [0, 1] - let potential_blocks = seq_manager.potential_blocks(vec![0, 1]); - // Worker 0 should return 3 (already has blocks 0, 1, 2, so no new blocks needed for [0, 1]) assert_eq!( potential_blocks[&0], 3, @@ -596,30 +992,34 @@ mod tests { "Worker 2 should have 4 potential blocks" ); - // Free all original requests - seq_manager.free(&"request_0".to_string()); - seq_manager.free(&"request_1".to_string()); - seq_manager.free(&"request_2".to_string()); - // Verify active blocks are zero for all workers - let active_blocks = seq_manager.active_blocks(); - assert_eq!(active_blocks[&0], 0, "Worker 0 should have 0 active blocks"); - assert_eq!(active_blocks[&1], 0, "Worker 1 should have 0 active blocks"); - assert_eq!(active_blocks[&2], 0, "Worker 2 should have 0 
active blocks"); + assert_eq!( + blocks_after_free[&0], 0, + "Worker 0 should have 0 active blocks" + ); + assert_eq!( + blocks_after_free[&1], 0, + "Worker 1 should have 0 active blocks" + ); + assert_eq!( + blocks_after_free[&2], 0, + "Worker 2 should have 0 active blocks" + ); // Verify active tokens are zero for all workers - let final_tokens = seq_manager.active_tokens(); assert_eq!( - final_tokens[&0], 0, + tokens_after_free[&0], 0, "Worker 0 should have 0 active tokens after freeing all" ); assert_eq!( - final_tokens[&1], 0, + tokens_after_free[&1], 0, "Worker 1 should have 0 active tokens after freeing all" ); assert_eq!( - final_tokens[&2], 0, + tokens_after_free[&2], 0, "Worker 2 should have 0 active tokens after freeing all" ); + + Ok(()) } } From 1d9f1914e26b30b5eeeb8cc4264238a0eb6cb198 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Mon, 4 Aug 2025 14:26:00 -0700 Subject: [PATCH 12/27] e2e test passed --- lib/llm/src/kv_router.rs | 9 +- lib/llm/src/kv_router/scheduler.rs | 130 +++++++++++++++------------- lib/llm/src/kv_router/sequence.rs | 132 +++++++++++++++++------------ 3 files changed, 159 insertions(+), 112 deletions(-) diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 71bb8397d4..1c3dc2b338 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -229,16 +229,21 @@ impl KvRouter { let isl_tokens = tokens.len(); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); + let seq_hashes = compute_seq_hash_for_block(&block_hashes); let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?; let best_worker_id = self .scheduler - .schedule(context_id.to_string(), isl_tokens, overlap_scores.clone()) + .schedule( + context_id.to_string(), + isl_tokens, + seq_hashes.clone(), + overlap_scores.clone(), + ) .await?; if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { - let seq_hashes = compute_seq_hash_for_block(&block_hashes); indexer .process_routing_decision(best_worker_id, block_hashes, seq_hashes) .await diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index a1ad2a3053..08755e5e39 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -1,10 +1,8 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use dashmap::DashMap; use dynamo_runtime::component::{Component, Instance}; -use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; -use futures::StreamExt; +use dynamo_runtime::traits::events::EventPublisher; use rand::Rng; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -13,12 +11,13 @@ use std::time::Duration; use tokio::sync::Mutex; use super::indexer::OverlapScores; -use super::prefill_counter::{get_snapshot, PrefillCountersMultiWorker}; use super::protocols::WorkerSelectionResult; -use super::scoring::LoadEvent; +use super::sequence::ActiveSequencesMultiWorker; use super::KvRouterConfig; use super::WorkerSelector; -use super::{KV_HIT_RATE_SUBJECT, KV_METRICS_SUBJECT}; +use super::KV_HIT_RATE_SUBJECT; + +use crate::tokens::SequenceHash; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KVHitRateEvent { @@ -42,28 +41,36 @@ pub enum KvSchedulerError { #[derive(Debug)] pub struct SchedulingResponse { pub best_worker_id: i64, - pub new_tokens: usize, + pub overlap_blocks: u32, } pub struct SchedulingRequest { + pub request_id: String, + pub token_seq: Vec, pub isl_tokens: usize, pub overlaps: OverlapScores, pub decode_blocks: HashMap, pub prefill_tokens: HashMap, - resp_tx: tokio::sync::oneshot::Sender, + resp_tx: Option>, // Changed to Option } impl SchedulingRequest { - pub fn respond(self, response: SchedulingResponse) { - if self.resp_tx.send(response).is_err() { - tracing::error!("failed to send response to requestor"); + pub fn respond(&mut self, response: SchedulingResponse) { + // Changed to &mut self + if let Some(tx) = self.resp_tx.take() { + // Use take() to extract the sender + if tx.send(response).is_err() { + tracing::error!("failed to send response to requestor"); + } + } else { + tracing::error!("respond called multiple times on same request"); } } } pub struct KvScheduler { request_tx: tokio::sync::mpsc::Sender, - prefill_tokens: Arc, + slots: Arc, barrier: Mutex<()>, } @@ -88,33 +95,17 @@ impl KvScheduler { } }); - // token and block states - let active_blocks = Arc::new(DashMap::::new()); // Create dashmap for worker_id -> active_blocks - let prefill_tokens = Arc::new(PrefillCountersMultiWorker::new(component.clone())); - - // listen on load metrics from backends to update active block states - let active_blocks_clone = active_blocks.clone(); - let ns_clone = component.namespace().clone(); - tokio::spawn(async move { - let mut load_subscriber = ns_clone - .subscribe_with_type::(KV_METRICS_SUBJECT) - .await - .expect("Cannot launch load subscriber"); - - while let Some(event_result) = load_subscriber.next().await { - let Ok(load_event) = event_result else { - tracing::warn!("Error receiving load event: {}", event_result.unwrap_err()); - continue; - }; - - active_blocks_clone.insert( - load_event.worker_id, - load_event.data.kv_stats.kv_active_blocks as usize, - ); - } - }); + let worker_ids: Vec = instances + .iter() + .map(|instance| instance.instance_id) + .collect(); + let slots = Arc::new(ActiveSequencesMultiWorker::new( + component, + block_size as usize, + worker_ids, + )); - let prefill_tokens_clone = prefill_tokens.clone(); + let slots_clone = slots.clone(); let (request_tx, request_rx) = tokio::sync::mpsc::channel::(1024); // Background task to handle scheduling requests tokio::spawn(async move { @@ -125,6 +116,11 @@ impl KvScheduler { // First, check for instance updates (non-blocking) if instances_rx.has_changed().unwrap_or(false) { instances = 
instances_rx.borrow_and_update().clone(); + let worker_ids: Vec = instances + .iter() + .map(|instance| instance.instance_id) + .collect(); + slots_clone.update_workers(worker_ids); } // Then, wait for a new request @@ -134,8 +130,15 @@ impl KvScheduler { }; tracing::trace!("received request to be scheduled"); - request.prefill_tokens = prefill_tokens_clone.running_sums().await; - request.decode_blocks = get_snapshot(active_blocks.as_ref()); + let (decode_blocks, prefill_tokens) = slots_clone + .potential_blocks_and_tokens( + request.token_seq.clone(), + request.isl_tokens, + request.overlaps.clone(), + ) + .await; + request.decode_blocks = decode_blocks; + request.prefill_tokens = prefill_tokens; match selector.select_worker(&instances, &request, block_size) { Ok(selection) => { @@ -149,10 +152,20 @@ impl KvScheduler { let response = SchedulingResponse { best_worker_id: selection.worker_id, - new_tokens: request.isl_tokens - - (selection.overlap_blocks * block_size) as usize, + overlap_blocks: selection.overlap_blocks, }; request.respond(response); + + let _ = slots_clone + .add_request( + request.request_id, + request.token_seq, + request.isl_tokens, + selection.overlap_blocks, + selection.worker_id, + ) + .await; + continue; } Err(KvSchedulerError::NoEndpoints) => { @@ -178,7 +191,7 @@ impl KvScheduler { Ok(KvScheduler { request_tx, - prefill_tokens, + slots, barrier: Mutex::new(()), }) } @@ -187,6 +200,7 @@ impl KvScheduler { &self, request_id: String, isl_tokens: usize, + token_seq: Vec, overlaps: OverlapScores, ) -> Result { // TODO: this is temporary needed for now to ensure blocking of scheduling and updating @@ -195,11 +209,13 @@ impl KvScheduler { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { + request_id, + token_seq, isl_tokens, overlaps, decode_blocks: HashMap::new(), prefill_tokens: HashMap::new(), - resp_tx, + resp_tx: Some(resp_tx), // Wrap in Some() }; self.request_tx @@ -211,15 +227,14 @@ impl KvScheduler { .map_err(|_| KvSchedulerError::SubscriberShutdown)?; let best_worker_id = response.best_worker_id; - let _ = self - .prefill_tokens - .add_prefill(best_worker_id, request_id, response.new_tokens) - .await; Ok(best_worker_id) } pub async fn mark_prefill_completed(&self, request_id: &str) { - let _ = self.prefill_tokens.remove_prefill(request_id).await; + let _ = self + .slots + .mark_prefill_completed(&request_id.to_string()) + .await; } } @@ -335,16 +350,17 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate logits for each worker for instance in workers.iter() { let worker_id = instance.instance_id; - - // this is the number of prefill tokens the worker would have if the request were scheduled there - let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&0); let overlap = *overlaps.get(&worker_id).unwrap_or(&0); - let new_token = isl - (overlap * block_size) as usize; - let potential_prefill_token = prefill_token + new_token; - let potential_prefill_block = (potential_prefill_token as f64) / (block_size as f64); - // this is the number of decode blocks currently on the worker - let decode_block = *decode_blocks.get(&worker_id).unwrap_or(&0) as f64; + // this is the number of prefill tokens the worker would have if the request were scheduled there + let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&isl); + let potential_prefill_block = (prefill_token as f64) / (block_size as f64); + + // this is the number of decode blocks the worker would have if the request were scheduled there + let 
decode_block = *decode_blocks + .get(&worker_id) + .unwrap_or(&(potential_prefill_block.floor() as usize)) + as f64; // Calculate logit (lower is better) let logit = diff --git a/lib/llm/src/kv_router/sequence.rs b/lib/llm/src/kv_router/sequence.rs index 7cd67b1615..9460ff48e8 100644 --- a/lib/llm/src/kv_router/sequence.rs +++ b/lib/llm/src/kv_router/sequence.rs @@ -30,6 +30,7 @@ use dashmap::DashMap; use derive_getters::Getters; use dynamo_runtime::component::Component; use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; +use dynamo_runtime::traits::DistributedRuntimeProvider; use futures::StreamExt; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -37,6 +38,7 @@ use uuid::Uuid; use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData}; use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT; +use dynamo_runtime::CancellationToken; // TODO: use the common request_id if it exists in the repo pub type RequestId = String; @@ -239,7 +241,9 @@ impl ActiveSequencesMultiWorker { let router_id = Uuid::new_v4(); for worker_id in worker_ids { - let (sender, handle) = Self::start_worker(block_size); + // Create a child cancellation token from the component's runtime + let cancel_token = component.drt().runtime().child_token(); + let (sender, handle) = Self::start_worker(block_size, cancel_token); senders.insert(worker_id, sender); handles.insert(worker_id, handle); } @@ -278,6 +282,7 @@ impl ActiveSequencesMultiWorker { /// Helper method to start a worker task fn start_worker( block_size: usize, + cancel_token: CancellationToken, // Add cancellation token parameter ) -> ( tokio::sync::mpsc::UnboundedSender, tokio::task::JoinHandle<()>, @@ -287,58 +292,76 @@ impl ActiveSequencesMultiWorker { let handle = tokio::spawn(async move { let mut active_sequences = ActiveSequences::new(block_size); - while let Some(command) = request_rx.recv().await { - match command { - UpdateSequences::AddRequest { - request_id, - token_sequence, - isl, - overlap, - } => { - active_sequences.add_request(request_id, token_sequence, isl, overlap); - } - UpdateSequences::Free { request_id } => { - active_sequences.free(&request_id); - } - UpdateSequences::MarkPrefillCompleted { request_id } => { - active_sequences.mark_prefill_completed(&request_id); - } - UpdateSequences::NewBlocks { - token_sequence, - resp_tx, - } => { - let new_blocks = active_sequences.new_blocks(&token_sequence); - let _ = resp_tx.send(new_blocks); - } - UpdateSequences::PotentialBlocks { - token_sequence, - resp_tx, - } => { - let potential_blocks = active_sequences.potential_blocks(&token_sequence); - let _ = resp_tx.send(potential_blocks); - } - UpdateSequences::PotentialBlocksAndTokens { - token_sequence, - isl, - overlap, - resp_tx, - } => { - let potential_tokens = active_sequences.potential_blocks_and_tokens( - &token_sequence, - isl, - overlap, - ); - let _ = resp_tx.send(potential_tokens); - } - UpdateSequences::ActiveBlocks { resp_tx } => { - let active_blocks = active_sequences.active_blocks(); - let _ = resp_tx.send(active_blocks); - } - UpdateSequences::ActiveTokens { resp_tx } => { - let active_tokens = active_sequences.active_tokens(); - let _ = resp_tx.send(active_tokens); + loop { + tokio::select! 
{ + // Handle incoming commands + command = request_rx.recv() => { + match command { + Some(command) => { + match command { + UpdateSequences::AddRequest { + request_id, + token_sequence, + isl, + overlap, + } => { + active_sequences.add_request(request_id, token_sequence, isl, overlap); + } + UpdateSequences::Free { request_id } => { + active_sequences.free(&request_id); + } + UpdateSequences::MarkPrefillCompleted { request_id } => { + active_sequences.mark_prefill_completed(&request_id); + } + UpdateSequences::NewBlocks { + token_sequence, + resp_tx, + } => { + let new_blocks = active_sequences.new_blocks(&token_sequence); + let _ = resp_tx.send(new_blocks); + } + UpdateSequences::PotentialBlocks { + token_sequence, + resp_tx, + } => { + let potential_blocks = active_sequences.potential_blocks(&token_sequence); + let _ = resp_tx.send(potential_blocks); + } + UpdateSequences::PotentialBlocksAndTokens { + token_sequence, + isl, + overlap, + resp_tx, + } => { + let potential_tokens = active_sequences.potential_blocks_and_tokens( + &token_sequence, + isl, + overlap, + ); + let _ = resp_tx.send(potential_tokens); + } + UpdateSequences::ActiveBlocks { resp_tx } => { + let active_blocks = active_sequences.active_blocks(); + let _ = resp_tx.send(active_blocks); + } + UpdateSequences::ActiveTokens { resp_tx } => { + let active_tokens = active_sequences.active_tokens(); + let _ = resp_tx.send(active_tokens); + } + UpdateSequences::Shutdown => { + break; + } + } + } + None => { + // Channel closed, exit + break; + } + } } - UpdateSequences::Shutdown => { + // Handle cancellation + _ = cancel_token.cancelled() => { + tracing::debug!("Worker task cancelled"); break; } } @@ -447,7 +470,10 @@ impl ActiveSequencesMultiWorker { for worker_id in &workers_to_add { tracing::warn!("Adding worker {}", worker_id); - let (sender, handle) = Self::start_worker(self.block_size); + let (sender, handle) = Self::start_worker( + self.block_size, + self.component.drt().runtime().child_token(), + ); self.senders.insert(*worker_id, sender); self.handles.insert(*worker_id, handle); } From 9f71c972ad7fdcb7d23cbddd0d7c537e9c5b1505 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Mon, 4 Aug 2025 14:28:59 -0700 Subject: [PATCH 13/27] dummy mutex was useless --- lib/llm/src/kv_router/scheduler.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 08755e5e39..c00d2494ff 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -71,7 +71,6 @@ impl SchedulingRequest { pub struct KvScheduler { request_tx: tokio::sync::mpsc::Sender, slots: Arc, - barrier: Mutex<()>, } impl KvScheduler { @@ -192,7 +191,6 @@ impl KvScheduler { Ok(KvScheduler { request_tx, slots, - barrier: Mutex::new(()), }) } @@ -203,10 +201,6 @@ impl KvScheduler { token_seq: Vec, overlaps: OverlapScores, ) -> Result { - // TODO: this is temporary needed for now to ensure blocking of scheduling and updating - // Need to remove in future if we have better algo to truly enable async processing - let _guard = self.barrier.lock().await; - let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { request_id, From bffeb7ba5656603ffae9af8a751208e2d9b8c61e Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Mon, 4 Aug 2025 14:29:51 -0700 Subject: [PATCH 14/27] clippy --- lib/llm/src/kv_router/scheduler.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 
c00d2494ff..89cebd9fdc 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -8,7 +8,6 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use tokio::sync::Mutex; use super::indexer::OverlapScores; use super::protocols::WorkerSelectionResult; From e77a28cbe25fc8f18cdb34f6054f8668644a8109 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 5 Aug 2025 09:47:07 -0700 Subject: [PATCH 15/27] disable nats metrics publishing --- lib/llm/src/kv_router/publisher.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 4415738b6e..09d3a3447e 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -512,8 +512,6 @@ impl WorkerMetricsPublisher { 0 }); - self.start_nats_metrics_publishing(component.namespace().clone(), worker_id); - component .endpoint(KV_METRICS_ENDPOINT) .endpoint_builder() @@ -530,6 +528,7 @@ impl WorkerMetricsPublisher { /// /// This task monitors metric changes (specifically kv_active_blocks and num_requests_waiting) /// and publishes stable metrics to NATS after they've been unchanged for 1ms. + #[allow(dead_code)] fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: i64) { let nats_rx = self.rx.clone(); From c7543d1496a649257caabb5d03921dbb00d1eaf3 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 5 Aug 2025 09:47:53 -0700 Subject: [PATCH 16/27] fmt --- lib/llm/src/kv_router/scheduler.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 89cebd9fdc..856d2907ac 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -187,10 +187,7 @@ impl KvScheduler { tracing::trace!("background endpoint subscriber shutting down"); }); - Ok(KvScheduler { - request_tx, - slots, - }) + Ok(KvScheduler { request_tx, slots }) } pub async fn schedule( From 1a79bb71fd0b7c888cf91a28e8e675ffc7287272 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 5 Aug 2025 09:49:26 -0700 Subject: [PATCH 17/27] clippy --- lib/llm/src/kv_router/publisher.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 09d3a3447e..22c88b08d0 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -503,14 +503,14 @@ impl WorkerMetricsPublisher { let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; - let worker_id = component - .drt() - .primary_lease() - .map(|lease| lease.id()) - .unwrap_or_else(|| { - tracing::warn!("Component is static, assuming worker_id of 0"); - 0 - }); + // let worker_id = component + // .drt() + // .primary_lease() + // .map(|lease| lease.id()) + // .unwrap_or_else(|| { + // tracing::warn!("Component is static, assuming worker_id of 0"); + // 0 + // }); component .endpoint(KV_METRICS_ENDPOINT) From bd4bcff94124348b36c4365b49ef3abed770aa99 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 5 Aug 2025 12:06:08 -0700 Subject: [PATCH 18/27] free requests --- lib/llm/src/kv_router.rs | 8 ++++++-- lib/llm/src/kv_router/scheduler.rs | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 1c3dc2b338..6a2a0beb79 100644 --- a/lib/llm/src/kv_router.rs +++ 
b/lib/llm/src/kv_router.rs @@ -258,12 +258,14 @@ impl KvRouter { Ok((best_worker_id, overlap_amount)) } - /// Free all blocks associated with a request pub async fn mark_prefill_completed(&self, request_id: &str) { self.scheduler.mark_prefill_completed(request_id).await } - /// Get the block size this router was configured with + pub async fn free(&self, request_id: &str) { + self.scheduler.free(request_id).await + } + pub fn block_size(&self) -> u32 { self.block_size } @@ -349,6 +351,8 @@ impl AsyncEngine, ManyOut Date: Tue, 5 Aug 2025 23:47:10 -0700 Subject: [PATCH 19/27] make state sharing a flag, defaulting to false --- .../frontend/src/dynamo/frontend/main.py | 7 + components/router/src/main.rs | 2 +- launch/dynamo-run/src/flags.rs | 7 + lib/bindings/python/rust/llm/entrypoint.rs | 10 +- lib/llm/src/discovery/model_manager.rs | 4 +- lib/llm/src/kv_router.rs | 16 ++- lib/llm/src/kv_router/scheduler.rs | 2 + lib/llm/src/kv_router/sequence.rs | 125 ++++++++++-------- 8 files changed, 112 insertions(+), 61 deletions(-) diff --git a/components/frontend/src/dynamo/frontend/main.py b/components/frontend/src/dynamo/frontend/main.py index c7903d3a56..d51bbc7951 100644 --- a/components/frontend/src/dynamo/frontend/main.py +++ b/components/frontend/src/dynamo/frontend/main.py @@ -72,6 +72,12 @@ def parse_args(): help=" KV Router. Disable KV events.", ) parser.set_defaults(use_kv_events=True) + parser.add_argument( + "--router-replica-sync", + action="store_true", + default=False, + help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.", + ) return parser.parse_args() @@ -86,6 +92,7 @@ async def async_main(): overlap_score_weight=flags.kv_overlap_score_weight, router_temperature=flags.router_temperature, use_kv_events=flags.use_kv_events, + replica_sync=flags.router_replica_sync, ) elif flags.router_mode == "random": router_mode = RouterMode.Random diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 5f741338fd..093ba596b5 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> { let selector = Box::new(CustomWorkerSelector::default()); - let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?; + let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true, false).await?; let router = Ingress::for_engine(Arc::new(router))?; component diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs index b2b641f6a1..ed58096cc8 100644 --- a/launch/dynamo-run/src/flags.rs +++ b/launch/dynamo-run/src/flags.rs @@ -96,6 +96,12 @@ pub struct Flags { #[arg(long)] pub use_kv_events: Option, + /// KV Router: Whether to enable replica synchronization across multiple router instances. + /// When true, routers will publish and subscribe to events to maintain consistent state. + /// Default: false + #[arg(long)] + pub router_replica_sync: Option, + /// Max model context length. Reduce this if you don't have enough VRAM for the full model /// context length (e.g. Llama 4). /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. 
@@ -179,6 +185,7 @@ impl Flags { self.kv_overlap_score_weight, self.router_temperature, self.use_kv_events, + self.router_replica_sync, self.max_num_batched_tokens, ), ) diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index adfb16f6dd..6daebada26 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -34,13 +34,19 @@ pub struct KvRouterConfig { #[pymethods] impl KvRouterConfig { #[new] - #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true))] - fn new(overlap_score_weight: f64, router_temperature: f64, use_kv_events: bool) -> Self { + #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, replica_sync=false))] + fn new( + overlap_score_weight: f64, + router_temperature: f64, + use_kv_events: bool, + replica_sync: bool, + ) -> Self { KvRouterConfig { inner: RsKvRouterConfig { overlap_score_weight, router_temperature, use_kv_events, + replica_sync, ..Default::default() }, } diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs index caa105e921..cef298459e 100644 --- a/lib/llm/src/discovery/model_manager.rs +++ b/lib/llm/src/discovery/model_manager.rs @@ -213,11 +213,13 @@ impl ModelManager { kv_router_config: Option, ) -> anyhow::Result> { let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); + let config = kv_router_config.unwrap_or_default(); let chooser = KvRouter::new( component.clone(), kv_cache_block_size, Some(selector), - kv_router_config.unwrap_or_default().use_kv_events, + config.use_kv_events, + config.router_replica_sync, ) .await?; let new_kv_chooser = Arc::new(chooser); diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 6a2a0beb79..ca8aae1660 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -80,6 +80,8 @@ pub struct KvRouterConfig { pub use_kv_events: bool, + pub router_replica_sync: bool, + // TODO: this is not actually used for now // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting pub max_num_batched_tokens: u32, @@ -91,6 +93,7 @@ impl Default for KvRouterConfig { overlap_score_weight: 1.0, router_temperature: 0.0, use_kv_events: true, + router_replica_sync: false, max_num_batched_tokens: 8192, } } @@ -103,6 +106,7 @@ impl KvRouterConfig { overlap_score_weight: Option, temperature: Option, use_kv_events: Option, + replica_sync: Option, max_num_batched_tokens: Option, ) -> Self { let default = Self::default(); @@ -110,6 +114,7 @@ impl KvRouterConfig { overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight), router_temperature: temperature.unwrap_or(default.router_temperature), use_kv_events: use_kv_events.unwrap_or(default.use_kv_events), + router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync), max_num_batched_tokens: max_num_batched_tokens .unwrap_or(default.max_num_batched_tokens), } @@ -152,6 +157,7 @@ impl KvRouter { block_size: u32, selector: Option>, use_kv_events: bool, + replica_sync: bool, ) -> Result { let cancellation_token = component .drt() @@ -180,8 +186,14 @@ impl KvRouter { )) }; - let scheduler = - KvScheduler::start(component.clone(), block_size, instances_rx, selector).await?; + let scheduler = KvScheduler::start( + component.clone(), + block_size, + instances_rx, + selector, + replica_sync, + ) + .await?; // [gluo TODO] try subscribe_with_type::, // error checking below will be different. 
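For reference, a minimal sketch of how a caller can opt into replica sync through the `KvRouterConfig::new` constructor extended above (illustrative only: the `use` path is an assumption, and each `None` argument falls back to the `Default` values shown earlier):

```rust
use dynamo_llm::kv_router::KvRouterConfig; // crate path assumed for illustration

fn main() {
    // Override only the replica-sync flag; every `None` keeps its default
    // (overlap_score_weight = 1.0, router_temperature = 0.0,
    // use_kv_events = true, max_num_batched_tokens = 8192).
    let config = KvRouterConfig::new(
        None,       // overlap_score_weight
        None,       // router temperature
        None,       // use_kv_events
        Some(true), // router_replica_sync: enable cross-replica state sharing
        None,       // max_num_batched_tokens
    );
    assert!(config.router_replica_sync);
}
```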
diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index ea95d009a9..ca3b1a666a 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -78,6 +78,7 @@ impl KvScheduler { block_size: u32, mut instances_rx: tokio::sync::watch::Receiver>, // Changed from ProcessedEndpoints selector: Option>, + replica_sync: bool, ) -> Result { let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let mut instances: Vec = instances_rx.borrow_and_update().clone(); @@ -101,6 +102,7 @@ impl KvScheduler { component, block_size as usize, worker_ids, + replica_sync, )); let slots_clone = slots.clone(); diff --git a/lib/llm/src/kv_router/sequence.rs b/lib/llm/src/kv_router/sequence.rs index 9460ff48e8..86d3ca66a5 100644 --- a/lib/llm/src/kv_router/sequence.rs +++ b/lib/llm/src/kv_router/sequence.rs @@ -229,10 +229,16 @@ pub struct ActiveSequencesMultiWorker { block_size: usize, component: Component, router_id: Uuid, + replica_sync: bool, } impl ActiveSequencesMultiWorker { - pub fn new(component: Component, block_size: usize, worker_ids: Vec) -> Self { + pub fn new( + component: Component, + block_size: usize, + worker_ids: Vec, + replica_sync: bool, + ) -> Self { assert!(block_size > 1, "block_size must be greater than 1"); let senders = Arc::new(DashMap::new()); @@ -255,26 +261,29 @@ impl ActiveSequencesMultiWorker { block_size, component: component.clone(), router_id, + replica_sync, }; - // Start the subscription loop - let senders_clone = senders.clone(); - let request_to_worker_clone = request_to_worker.clone(); - let component_clone = component.clone(); - let router_id_clone = router_id; - - tokio::spawn(async move { - if let Err(e) = Self::subscribe_to_events( - senders_clone, - request_to_worker_clone, - component_clone, - router_id_clone, - ) - .await - { - tracing::error!("Error in active sequences events subscription: {}", e); - } - }); + // Start the subscription loop only if replica_sync is enabled + if replica_sync { + let senders_clone = senders.clone(); + let request_to_worker_clone = request_to_worker.clone(); + let component_clone = component.clone(); + let router_id_clone = router_id; + + tokio::spawn(async move { + if let Err(e) = Self::subscribe_to_events( + senders_clone, + request_to_worker_clone, + component_clone, + router_id_clone, + ) + .await + { + tracing::error!("Error in active sequences events subscription: {}", e); + } + }); + } multi_worker } @@ -491,20 +500,22 @@ impl ActiveSequencesMultiWorker { return Err(anyhow::anyhow!("Worker ID {worker_id} not found")); } - // Publish event - let event = ActiveSequenceEvent { - request_id: request_id.clone(), - worker_id, - data: ActiveSequenceEventData::AddRequest { - token_sequence: token_sequence.clone(), - isl, - overlap, - }, - router_id: self.router_id, - }; - self.component - .publish(ACTIVE_SEQUENCES_SUBJECT, &event) - .await?; + // Publish event only if replica_sync is enabled + if self.replica_sync { + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::AddRequest { + token_sequence: token_sequence.clone(), + isl, + overlap, + }, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + } // Update local state self.request_to_worker.insert(request_id.clone(), worker_id); @@ -530,16 +541,18 @@ impl ActiveSequencesMultiWorker { .map(|entry| *entry) .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker 
mapping"))?; - // Publish event - let event = ActiveSequenceEvent { - request_id: request_id.clone(), - worker_id, - data: ActiveSequenceEventData::Free, - router_id: self.router_id, - }; - self.component - .publish(ACTIVE_SEQUENCES_SUBJECT, &event) - .await?; + // Publish event only if replica_sync is enabled + if self.replica_sync { + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::Free, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + } // Update local state self.senders @@ -563,16 +576,18 @@ impl ActiveSequencesMultiWorker { .map(|entry| *entry) .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?; - // Publish event - let event = ActiveSequenceEvent { - request_id: request_id.clone(), - worker_id, - data: ActiveSequenceEventData::MarkPrefillCompleted, - router_id: self.router_id, - }; - self.component - .publish(ACTIVE_SEQUENCES_SUBJECT, &event) - .await?; + // Publish event only if replica_sync is enabled + if self.replica_sync { + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::MarkPrefillCompleted, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + } // Update local state self.senders @@ -793,7 +808,7 @@ mod tests { // Create multi-worker sequence manager with workers 0 and 1 let worker_ids = vec![0, 1]; let seq_manager = - ActiveSequencesMultiWorker::new(component, block_size, worker_ids); + ActiveSequencesMultiWorker::new(component, block_size, worker_ids, true); // Give some time for the subscription loop to start tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -907,7 +922,7 @@ mod tests { // Create multi-worker sequence manager with worker 2 let worker_ids = vec![2]; let seq_manager = - ActiveSequencesMultiWorker::new(component, block_size, worker_ids); + ActiveSequencesMultiWorker::new(component, block_size, worker_ids, true); // Give some time for the subscription loop to start tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; From 8e6db78706350b586101c632b6b2db0f85c3e05b Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 6 Aug 2025 07:50:23 -0700 Subject: [PATCH 20/27] rename to router_replica_sync --- components/frontend/src/dynamo/frontend/main.py | 2 +- lib/bindings/python/rust/llm/entrypoint.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/components/frontend/src/dynamo/frontend/main.py b/components/frontend/src/dynamo/frontend/main.py index d51bbc7951..e6fb893a13 100644 --- a/components/frontend/src/dynamo/frontend/main.py +++ b/components/frontend/src/dynamo/frontend/main.py @@ -92,7 +92,7 @@ async def async_main(): overlap_score_weight=flags.kv_overlap_score_weight, router_temperature=flags.router_temperature, use_kv_events=flags.use_kv_events, - replica_sync=flags.router_replica_sync, + router_replica_sync=flags.router_replica_sync, ) elif flags.router_mode == "random": router_mode = RouterMode.Random diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index 6daebada26..93692b6f2e 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -34,19 +34,19 @@ pub struct KvRouterConfig { #[pymethods] impl KvRouterConfig { #[new] - #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, 
replica_sync=false))] + #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false))] fn new( overlap_score_weight: f64, router_temperature: f64, use_kv_events: bool, - replica_sync: bool, + router_replica_sync: bool, ) -> Self { KvRouterConfig { inner: RsKvRouterConfig { overlap_score_weight, router_temperature, use_kv_events, - replica_sync, + router_replica_sync, ..Default::default() }, } From 65ee094a23f14f936bfc0821f9789ff36c53a4c8 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 6 Aug 2025 12:02:27 -0700 Subject: [PATCH 21/27] fmt --- components/router/src/main.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 093ba596b5..8a42eef7d5 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -66,7 +66,14 @@ async fn app(runtime: Runtime) -> Result<()> { let selector = Box::new(CustomWorkerSelector::default()); - let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true, false).await?; + let router = KvRouter::new( + component.clone(), + args.block_size, + Some(selector), + true, + false, + ) + .await?; let router = Ingress::for_engine(Arc::new(router))?; component From 6f27d5236ad9749c90e06abf3dabe8783f510bae Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 6 Aug 2025 15:02:49 -0700 Subject: [PATCH 22/27] reviewer comments --- components/router/src/main.rs | 9 +-------- lib/llm/src/discovery/model_manager.rs | 4 +--- lib/llm/src/kv_router.rs | 9 +++++---- lib/llm/src/kv_router/scheduler.rs | 26 ++++++++++++++++++-------- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 8a42eef7d5..1ee7fb64fd 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -66,14 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> { let selector = Box::new(CustomWorkerSelector::default()); - let router = KvRouter::new( - component.clone(), - args.block_size, - Some(selector), - true, - false, - ) - .await?; + let router = KvRouter::new(component.clone(), args.block_size, Some(selector), None).await?; let router = Ingress::for_engine(Arc::new(router))?; component diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs index cef298459e..6773c2e98e 100644 --- a/lib/llm/src/discovery/model_manager.rs +++ b/lib/llm/src/discovery/model_manager.rs @@ -213,13 +213,11 @@ impl ModelManager { kv_router_config: Option, ) -> anyhow::Result> { let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); - let config = kv_router_config.unwrap_or_default(); let chooser = KvRouter::new( component.clone(), kv_cache_block_size, Some(selector), - config.use_kv_events, - config.router_replica_sync, + kv_router_config, ) .await?; let new_kv_chooser = Arc::new(chooser); diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index ca8aae1660..1e488872e1 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -156,9 +156,10 @@ impl KvRouter { component: Component, block_size: u32, selector: Option>, - use_kv_events: bool, - replica_sync: bool, + kv_router_config: Option, ) -> Result { + let kv_router_config = kv_router_config.unwrap_or_default(); + let cancellation_token = component .drt() .primary_lease() @@ -175,7 +176,7 @@ impl KvRouter { } }; - let indexer = if use_kv_events { + let indexer = if kv_router_config.use_kv_events { 
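            // Exact view: maintain the global prefix tree from the KV block
            // creation/deletion events emitted by workers.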
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size)) } else { // hard code 120 seconds for now @@ -191,7 +192,7 @@ impl KvRouter { block_size, instances_rx, selector, - replica_sync, + kv_router_config.router_replica_sync, ) .await?; diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index ca3b1a666a..398d19ab9c 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -50,7 +50,8 @@ pub struct SchedulingRequest { pub overlaps: OverlapScores, pub decode_blocks: HashMap, pub prefill_tokens: HashMap, - resp_tx: Option>, // Changed to Option + // Option to take it out to send the response without moving the struct + resp_tx: Option>, } impl SchedulingRequest { @@ -114,13 +115,22 @@ impl KvScheduler { loop { // First, check for instance updates (non-blocking) - if instances_rx.has_changed().unwrap_or(false) { - instances = instances_rx.borrow_and_update().clone(); - let worker_ids: Vec = instances - .iter() - .map(|instance| instance.instance_id) - .collect(); - slots_clone.update_workers(worker_ids); + match instances_rx.has_changed() { + Ok(true) => { + instances = instances_rx.borrow_and_update().clone(); + let worker_ids: Vec = instances + .iter() + .map(|instance| instance.instance_id) + .collect(); + slots_clone.update_workers(worker_ids); + } + Ok(false) => { + // No changes, continue. This is the happy path. + } + Err(_) => { + tracing::warn!("endpoint watch sender shutdown"); + break; + } } // Then, wait for a new request From 038acd42c0da035c9a22e0c994f90b9fbbec52d1 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 6 Aug 2025 16:53:31 -0700 Subject: [PATCH 23/27] update docs --- docs/architecture/kv_cache_routing.md | 77 ++++++++++++++++----------- 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/docs/architecture/kv_cache_routing.md b/docs/architecture/kv_cache_routing.md index a78feef9f5..dbd1c6484f 100644 --- a/docs/architecture/kv_cache_routing.md +++ b/docs/architecture/kv_cache_routing.md @@ -23,6 +23,7 @@ The KV-aware routing arguments: - `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events. +- `--router-replica-sync`: Enables state synchronization between multiple router replicas via NATS. When enabled, router replicas share their view of KV cache distribution and active sequences, allowing all routers to make optimal routing decisions even when requests are distributed across multiple router instances. This improves fault tolerance and routing accuracy in multi-router deployments. ## Architecture @@ -45,6 +46,22 @@ We can then use the default routing methods exposed by the client class to send KV Cache routing uses direct routing with a special worker selection algorithm. +## Serving Two Router Replicas + +For improved fault tolerance, you can launch two frontend + router replicas. Since the frontend and router are currently tied together, you'll need to use two different HTTP ports for each instance. 
+ +To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend: + +```bash +# Router replica 1 +python -m dynamo.frontend --router-mode kv --port 8000 --router-replica-sync + +# Router replica 2 +python -m dynamo.frontend --router-mode kv --port 8001 --router-replica-sync +``` + +When `--router-replica-sync` is enabled, the router replicas will communicate with each other via NATS to maintain consistent state across instances. This allows both routers to have a complete view of the KV cache distribution and make optimal routing decisions, even when requests are distributed across multiple router instances. + ## Understanding KV Cache The leading Large Language Models (LLMs) today are auto-regressive and based off of the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for the future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching). @@ -88,30 +105,46 @@ Further details can be found for: [TRT-LLM](https://developer.nvidia.com/blog/in | +------------------+------------------+ | | | - | KV match: 15% | KV match: 50% | KV match: 75% + | Cached: 2 blocks | Cached: 5 blocks | Cached: 8 blocks + | Prefill: 8 blks | Prefill: 5 blks | Prefill: 2 blks + | Decode: 10 blks | Decode: 7 blks | Decode: 9 blks v v v +----------------+ +----------------+ +----------------+ | Worker 1 | | Worker 2 | | Worker 3 | - | (Load: 30%) | | (Load: 50%) | | (Load: 80%) | +----------------+ +----------------+ +----------------+ ``` Load balancing in LLM serving becomes complex when enabling KV Cache reuse. While KV Cache reuse can save significant computation, if the routing strategy is not aware of the unique KV states of each worker we can: -- miss opportunities for KV Cache reuse if routing to the “wrong” node +- miss opportunities for KV Cache reuse if routing to the "wrong" node - get into an imbalanced state where a few workers are processing many requests, lowering throughput of entire system -The best way to solve these issues is for the router to have a global view of KV Cache and load. With this view, the router can use a cost function to score the workers and make decisions to maximize cache hits while keeping the system balanced and throughput high. +The router uses a cost function that considers both the prefill cost (influenced by cached blocks) and the decode load to make optimal routing decisions: + +### Cost Calculation + +1. **Prefill blocks**: The number of tokens that need to be processed during prefill is predicted based on the request's input tokens and the cached blocks available on each worker. This is divided by the block size to get the effective "prefill blocks". This prediction is updated when the first output token is produced, signaling prefill completion. -In the above image, our cost function is (KV match - Load) so we select Worker 2 even though Worker 3 would offer the best KV match. -- Worker 1 = (0.15 - 0.30) = -0.15 -- **Worker 2 = (0.50 - 0.50) = 0** -- Worker 3 = (0.75 - 0.80) = -0.05 +2. **Decode blocks**: The number of blocks needed during the decode phase is predicted based on the request's input tokens and the current active sequences on each worker. 
This is updated when the request is freed (blocks are dereferenced or freed).
+
+3. **Cost formula**: `cost = overlap_score_weight * prefill_blocks + decode_blocks`
+   - Lower cost is better
+   - The `overlap_score_weight` parameter controls the importance of cache hits vs. load balancing
+   - A higher weight prioritizes cache reuse (better TTFT) while a lower weight prioritizes load distribution (better ITL)
+
+### Worker Selection
+
+The router selects the worker with the lowest cost. When `router_temperature` is set to a non-zero value, the router uses softmax sampling on the normalized cost logits to introduce randomness in the selection, which can help with load distribution.
+
+Example calculation with `overlap_score_weight = 1.0`:
+- Worker 1: cost = 1.0 * 8 + 10 = 18
+- Worker 2: cost = 1.0 * 5 + 7 = 12
+- **Worker 3: cost = 1.0 * 2 + 9 = 11** (selected: lowest cost)
 
 ## Events
 
-In Dynamo, we want to support KV Cache Routing and load balancing for many backends that have different implementations of KV Cache and record different metrics. To that end, we built a KVPublisher that can be plugged into any framework to publish KV Events and a WorkerMetricsPublisher that can publish Metric Events.
+In Dynamo, we support KV Cache Routing for many backends that have different implementations of KV Cache. To enable this, we built a KVPublisher that can be plugged into any framework to publish KV Events.
 
-On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree and a KvMetricsAggregator which aggregates metric events by worker.
+On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree for tracking cached blocks across all workers.
 
 ```text
 +----------------+                         +-----------------+
 |    Worker      |                         |    KVIndexer    |
 |                |  stored_kv_block()      |                 |
 | +------------+ |  remove_kv_block()      |  +-------------+|
 | |KVPublisher | |------------------------>|  |             ||
 | +------------+ |                         |  +-------------+|
-|                |  num_request_waiting    |  +--------------+|
-| +------------+ |  gpu_cache_usage_perc   |  |KvMetricsAggre||
-| |KvMetrics   | |------------------------>|  |    gator     ||
-| |Publisher   | |  ...                    |  +--------------+|
-| +------------+ |                         +-----------------+
-+----------------+
-
+|                |                         |                 |
++----------------+                         +-----------------+
 ```
 
 ### KVPublisher
@@ -144,18 +172,3 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr
 The KVIndexer has a method `find_matches_for_request`, which takes in tokens and returns a dictionary with keys of worker id and values of the number of matched KV Blocks.
 
-### WorkerMetricsPublisher
-We added a KvMetrics Publisher which sends the following metrics to the KvMetricsAggregator:
-- num_requests_waiting
-- gpu_cache_usage_perc
-- gpu_prefix_cache_hit_rate
-- request_active_slots
-- request_total_slots
-- kv_active_blocks
-- kv_total_blocks
-
-Currently, the WorkerMetricsPublisher exists as a Python binding.
-
-### KvMetricsAggregator
-The KvMetricsAggregator receives these metrics and aggregates them. It has a method `get_metrics` which returns an object of `AggregatedMetrics`.
- From 42fb6a0646798ba238ded28183b4f2ced2b6795d Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 6 Aug 2025 17:01:06 -0700 Subject: [PATCH 24/27] use dashmap 5.5.3 for now so we don't have two versions --- Cargo.lock | 18 ++---------------- lib/bindings/python/Cargo.lock | 18 ++---------------- lib/llm/Cargo.toml | 2 +- 3 files changed, 5 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 655112a2dd..7613430aaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1556,20 +1556,6 @@ dependencies = [ "parking_lot_core", ] -[[package]] -name = "dashmap" -version = "6.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" -dependencies = [ - "cfg-if 1.0.0", - "crossbeam-utils", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "data-encoding" version = "2.9.0" @@ -1913,7 +1899,7 @@ dependencies = [ "chrono", "criterion", "cudarc 0.16.2", - "dashmap 6.1.0", + "dashmap", "derive-getters", "derive_builder", "dialoguer", @@ -8738,7 +8724,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap 5.5.3", + "dashmap", "futures-channel", "futures-io", "futures-task", diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index 514bf90361..e24e54e6d8 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -978,20 +978,6 @@ dependencies = [ "parking_lot_core", ] -[[package]] -name = "dashmap" -version = "6.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" -dependencies = [ - "cfg-if 1.0.0", - "crossbeam-utils", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "data-encoding" version = "2.9.0" @@ -1201,7 +1187,7 @@ dependencies = [ "candle-core", "chrono", "cudarc", - "dashmap 6.1.0", + "dashmap", "derive-getters", "derive_builder", "dialoguer", @@ -5856,7 +5842,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap 5.5.3", + "dashmap", "futures-channel", "futures-io", "futures-task", diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 485455e39b..4905ede4c2 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -85,7 +85,7 @@ derive-getters = "0.5" offset-allocator = "0.2" regex = "1" rayon = "1" -dashmap = "6" +dashmap = { version = "5.5.3" } # input/text dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } From 45c74d10b38738cf2d53cddfdbde7cf96a68f27d Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 6 Aug 2025 17:07:01 -0700 Subject: [PATCH 25/27] inter-router comm doc --- docs/architecture/kv_cache_routing.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/architecture/kv_cache_routing.md b/docs/architecture/kv_cache_routing.md index dbd1c6484f..c8be8c3f13 100644 --- a/docs/architecture/kv_cache_routing.md +++ b/docs/architecture/kv_cache_routing.md @@ -172,3 +172,15 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr The KVIndexer has a method `find_matches_for_request`, which takes in tokens and returns a dictionary with keys of worker id and values of the number of matched KV Blocks. +### Inter-Router Communication + +In multi-router deployments, each router only observes a subset of requests. 
To maintain a consistent global view of active sequences and KV cache states, routers broadcast their local actions to other replicas through three synchronization events:
+
+1. **AddRequest**: Published when assigning a request to a worker, containing the request ID, worker ID, token sequence blocks, and overlap score. This updates other routers' tracking of which blocks are in use.
+
+2. **MarkPrefillCompleted**: Published when a request transitions from the prefill phase to the decode phase, signaling that its prefill tokens should no longer count toward the worker's active prefill load.
+
+3. **Free**: Published when a request completes and its resources are released, allowing other routers to update their block reference counts.
+
+Each event carries the sending router's unique ID, so a replica can ignore events it published itself. This asynchronous communication ensures all routers maintain synchronized KV cache state for optimal routing decisions despite handling different request streams. (A rough sketch of these event payloads follows the last patch below.)
+

From b853c99762672935aee02399b8738e30a00b2084 Mon Sep 17 00:00:00 2001
From: PeaBrane
Date: Wed, 6 Aug 2025 17:11:13 -0700
Subject: [PATCH 26/27] make clear defaults in doc

---
 docs/architecture/kv_cache_routing.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/architecture/kv_cache_routing.md b/docs/architecture/kv_cache_routing.md
index c8be8c3f13..6ca0181186 100644
--- a/docs/architecture/kv_cache_routing.md
+++ b/docs/architecture/kv_cache_routing.md
@@ -17,13 +17,13 @@ For performance testing, compare a typical workload with `--router-mode random|r

The KV-aware routing arguments:

-- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks).
+- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks). Defaults to 1.

-- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked.
+- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 (default) recovers the deterministic behavior where the min logit is picked.

-- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events.
+- `--use-kv-events`/`--no-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true (default), then we use the `KvIndexer` to listen to the block creation and deletion events.
If false, `ApproxKvIndexer` is used instead; it assumes the KV cache of recently seen prompts persists for a fixed time window (hard-coded to 120s) and predicts the KV cache hit ratio in each engine on that basis. Set this to false if your backend engine does not emit KV events.

-- `--router-replica-sync`: Enables state synchronization between multiple router replicas via NATS. When enabled, router replicas share their view of KV cache distribution and active sequences, allowing all routers to make optimal routing decisions even when requests are distributed across multiple router instances. This improves fault tolerance and routing accuracy in multi-router deployments.
+- `--router-replica-sync`: Enables state synchronization between multiple router replicas via NATS. Disabled by default; pass the flag to enable it. When enabled, router replicas share their view of KV cache distribution and active sequences, allowing all routers to make optimal routing decisions even when requests are distributed across multiple router instances. This improves fault tolerance and routing accuracy in multi-router deployments.

## Architecture

From 19a1aad471ed7ea64e8e1f826f0edf520f961c1e Mon Sep 17 00:00:00 2001
From: PeaBrane
Date: Thu, 7 Aug 2025 15:32:42 -0700
Subject: [PATCH 27/27] improve formula readability

---
 lib/llm/src/kv_router/scheduler.rs | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs
index 398d19ab9c..6603b0e906 100644
--- a/lib/llm/src/kv_router/scheduler.rs
+++ b/lib/llm/src/kv_router/scheduler.rs
@@ -375,10 +375,11 @@ impl WorkerSelector for DefaultWorkerSelector {

             worker_logits.insert(worker_id, logit);

+            let overlap_weight = self.kv_router_config.overlap_score_weight;
             tracing::info!(
-                "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_block:.3} + {decode_block:.3} (cached_blocks: {})",
-                self.kv_router_config.overlap_score_weight,
-                overlap,
+                "Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \
+                = {overlap_weight:.1} * prefill_blocks + decode_blocks \
+                = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}"
             );
}
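To make the inter-router synchronization described in the Inter-Router Communication section above more concrete, here is a rough sketch of what the three replica-sync payloads could look like on the wire. Every name here (`ReplicaSyncEvent`, `ReplicaSyncMessage`, and the field names) is a hypothetical stand-in for illustration, not the crate's actual types, and the sketch assumes `serde` (with the `derive` feature) and `serde_json` as dependencies.

```rust
use serde::{Deserialize, Serialize};

// Hypothetical shapes for the three synchronization events; field names are
// illustrative only, not the router's real structs.
#[derive(Debug, Serialize, Deserialize)]
pub enum ReplicaSyncEvent {
    /// A request was assigned to a worker; peers start tracking its blocks.
    AddRequest {
        request_id: String,
        worker_id: u64,
        block_hashes: Vec<u64>, // token sequence blocks, as block hashes
        overlap_score: u32,
    },
    /// The request moved from prefill to decode; its prefill tokens no longer
    /// count toward the worker's active prefill load.
    MarkPrefillCompleted { request_id: String },
    /// The request finished; peers decrement their block reference counts.
    Free { request_id: String },
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ReplicaSyncMessage {
    /// The sending router's unique ID; a receiver drops messages carrying its
    /// own ID so it never re-applies events it published itself.
    pub router_id: String,
    pub event: ReplicaSyncEvent,
}

fn main() -> Result<(), serde_json::Error> {
    let msg = ReplicaSyncMessage {
        router_id: "router-a".into(),
        event: ReplicaSyncEvent::MarkPrefillCompleted {
            request_id: "req-42".into(),
        },
    };
    // Each replica would publish messages like this over NATS and subscribe
    // to the same subject to apply its peers' events.
    println!("{}", serde_json::to_string(&msg)?);
    Ok(())
}
```

Filtering on `router_id` rather than deduplicating individual events keeps the receive path to a single comparison per message, which matches the document's note that each event carries the sender's ID precisely so replicas can skip their own traffic.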