diff --git a/Cargo.lock b/Cargo.lock index 5c4b0afd0c..655112a2dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1556,6 +1556,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1899,6 +1913,7 @@ dependencies = [ "chrono", "criterion", "cudarc 0.16.2", + "dashmap 6.1.0", "derive-getters", "derive_builder", "dialoguer", @@ -8723,7 +8738,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap", + "dashmap 5.5.3", "futures-channel", "futures-io", "futures-task", diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index 7fb632b881..1ef523b9ee 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -84,7 +84,7 @@ use std::net::SocketAddr; use std::time::Duration as StdDuration; use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics}; -use dynamo_llm::kv_router::scheduler::Endpoint; +use dynamo_llm::kv_router::scoring::Endpoint; use dynamo_llm::kv_router::scoring::ProcessedEndpoints; use dynamo_runtime::{ diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index d677224444..514bf90361 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -978,6 +978,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1187,6 +1201,7 @@ dependencies = [ "candle-core", "chrono", "cudarc", + "dashmap 6.1.0", "derive-getters", "derive_builder", "dialoguer", @@ -5841,7 +5856,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap", + "dashmap 5.5.3", "futures-channel", "futures-io", "futures-task", diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index f36994144d..485455e39b 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -85,6 +85,7 @@ derive-getters = "0.5" offset-allocator = "0.2" regex = "1" rayon = "1" +dashmap = "6" # input/text dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 2e5d02fe82..6a2a0beb79 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -15,11 +15,11 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; -use tokio::sync::Mutex; pub mod approx; pub mod indexer; pub mod metrics_aggregator; +pub mod prefill_counter; pub mod protocols; pub mod publisher; pub mod recorder; @@ -48,9 +48,18 @@ use dynamo_runtime::traits::events::EventSubscriber; // [gluo TODO] shouldn't need to be public // this should be discovered from the component + +// for metric scraping (pull-based) +pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; + +// for metric publishing (push-based) pub const KV_EVENT_SUBJECT: &str = "kv_events"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; -pub const KV_METRICS_ENDPOINT: 
&str = "load_metrics"; +pub const KV_METRICS_SUBJECT: &str = "kv_metrics"; + +// for inter-router comms +pub const PREFILL_SUBJECT: &str = "prefill_events"; +pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; /// A trait that users can implement to define custom selection logic pub trait WorkerSelector { @@ -135,10 +144,6 @@ pub struct KvRouter { scheduler: KvScheduler, block_size: u32, - - // To ensure blocking reads / writes - // TODO: benchmark tradeoffs - find_best_match_mutex: Mutex<()>, } impl KvRouter { @@ -175,13 +180,8 @@ impl KvRouter { )) }; - let scheduler = KvScheduler::start( - component.namespace().clone(), - block_size, - instances_rx, - selector, - ) - .await?; + let scheduler = + KvScheduler::start(component.clone(), block_size, instances_rx, selector).await?; // [gluo TODO] try subscribe_with_type::, // error checking below will be different. @@ -215,7 +215,6 @@ impl KvRouter { indexer, scheduler, block_size, - find_best_match_mutex: Mutex::new(()), // Add this }) } @@ -227,10 +226,6 @@ impl KvRouter { context_id: &str, tokens: &[u32], ) -> anyhow::Result<(i64, u32)> { - // Acquire mutex to serialize access - // TODO: may as well make all the subroutines synchronous if benchmarking favors this - let _guard = self.find_best_match_mutex.lock().await; - let isl_tokens = tokens.len(); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); @@ -263,17 +258,14 @@ impl KvRouter { Ok((best_worker_id, overlap_amount)) } - /// Free all blocks associated with a request - pub async fn mark_prefill_completed(&self, request_id: &String) { + pub async fn mark_prefill_completed(&self, request_id: &str) { self.scheduler.mark_prefill_completed(request_id).await } - /// Free all blocks associated with a request - pub async fn free(&self, request_id: &String) { + pub async fn free(&self, request_id: &str) { self.scheduler.free(request_id).await } - /// Get the block size this router was configured with pub fn block_size(&self) -> u32 { self.block_size } diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 018f6cf84a..7ab4e1372c 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -18,7 +18,7 @@ use std::sync::Once; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics}; use crate::kv_router::KV_METRICS_ENDPOINT; -use crate::kv_router::scheduler::Endpoint; +use crate::kv_router::scoring::Endpoint; use crate::kv_router::ProcessedEndpoints; use dynamo_runtime::component::Component; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; diff --git a/lib/llm/src/kv_router/prefill_counter.rs b/lib/llm/src/kv_router/prefill_counter.rs new file mode 100644 index 0000000000..5052faf752 --- /dev/null +++ b/lib/llm/src/kv_router/prefill_counter.rs @@ -0,0 +1,545 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+use anyhow::Result;
+use dynamo_runtime::component::Component;
+use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
+use futures::StreamExt;
+use std::sync::Arc;
+use uuid::Uuid;
+
+use super::protocols::{PrefillEvent, PrefillEventData};
+use crate::kv_router::PREFILL_SUBJECT;
+use dashmap::DashMap;
+use std::collections::HashMap;
+use std::hash::Hash;
+
+pub fn get_snapshot<K, V>(state: &DashMap<K, V>) -> HashMap<K, V>
+where
+    K: Clone + Hash + Eq,
+    V: Copy,
+{
+    state
+        .iter()
+        .map(|entry| (entry.key().clone(), *entry.value()))
+        .collect()
+}
+
+#[derive(Default)]
+struct PrefillCounterState {
+    tokens_map: HashMap<String, usize>, // Plain HashMap
+    running_sum: usize,                 // Plain usize
+}
+
+impl PrefillCounterState {
+    fn insert(&mut self, key: String, value: usize) -> Option<usize> {
+        // Takes &mut self
+        let old_value = self.tokens_map.insert(key, value);
+
+        if let Some(old) = old_value {
+            self.running_sum -= old;
+            self.running_sum += value;
+        } else {
+            self.running_sum += value;
+        }
+
+        old_value
+    }
+
+    fn remove(&mut self, key: &str) -> Option<usize> {
+        // Takes &mut self
+        let removed = self.tokens_map.remove(key);
+
+        if let Some(value) = removed {
+            self.running_sum -= value;
+        }
+
+        removed
+    }
+
+    fn running_sum(&self) -> usize {
+        self.running_sum
+    }
+}
+
+/// A counter that tracks pending prefill tokens for each request.
+///
+/// This struct maintains a local hashmap of request_id to token count,
+/// and a running sum of all tokens. It no longer handles its own subscriptions.
+#[derive(Default)] // Removed Clone
+pub struct PrefillCounter {
+    state: PrefillCounterState, // No Arc, direct ownership
+}
+
+impl PrefillCounter {
+    // Internal methods for direct state manipulation (no publishing)
+    fn insert_direct(&mut self, request_id: String, tokens: usize) -> Option<usize> {
+        // Takes &mut self
+        self.state.insert(request_id, tokens)
+    }
+
+    fn remove_direct(&mut self, request_id: &str) -> Option<usize> {
+        // Takes &mut self
+        self.state.remove(request_id)
+    }
+
+    #[allow(dead_code)]
+    fn update_direct(&mut self, request_id: String, new_tokens: usize) {
+        // Takes &mut self
+        if let Some(old_tokens) = self.state.tokens_map.get(&request_id).copied() {
+            let delta = new_tokens as isize - old_tokens as isize;
+            self.state.running_sum = (self.state.running_sum as isize + delta) as usize;
+            self.state.tokens_map.insert(request_id, new_tokens);
+        }
+    }
+
+    pub fn get(&self, request_id: &str) -> Option<usize> {
+        self.state.tokens_map.get(request_id).copied()
+    }
+
+    pub fn running_sum(&self) -> usize {
+        self.state.running_sum()
+    }
+
+    pub fn len(&self) -> usize {
+        self.state.tokens_map.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.state.tokens_map.is_empty()
+    }
+}
+
+/// A collection of PrefillCounters for multiple workers with centralized event handling
+pub struct PrefillCountersMultiWorker {
+    pub counters: Arc<DashMap<i64, PrefillCounter>>,
+    pub request_to_workers: Arc<DashMap<String, i64>>,
+    component: Component,
+    router_id: Uuid,
+}
+
+impl PrefillCountersMultiWorker {
+    // Helper function to handle new prefill logic
+    fn handle_new_prefill(
+        counters: &Arc<DashMap<i64, PrefillCounter>>,
+        request_to_workers: &Arc<DashMap<String, i64>>,
+        request_id: &str,
+        worker_id: i64,
+        tokens: usize,
+    ) {
+        // Check if request already exists
+        if let Some(existing_worker_id) = request_to_workers.get(request_id) {
+            tracing::warn!(
+                "Request {} already exists for worker {}, but trying to add to worker {}",
+                request_id,
+                *existing_worker_id,
+                worker_id
+            );
+        }
+
+        // Update mapping
+        request_to_workers.insert(request_id.to_string(), worker_id);
+
+        // Get 
or create counter and insert using get_mut + if let Some(mut counter) = counters.get_mut(&worker_id) { + counter.insert_direct(request_id.to_string(), tokens); + } else { + tracing::warn!( + "Worker {} does not exist, creating new PrefillCounter", + worker_id + ); + let mut new_counter = PrefillCounter::default(); + new_counter.insert_direct(request_id.to_string(), tokens); + counters.insert(worker_id, new_counter); + }; + } + + // Helper function to handle complete prefill logic + fn handle_complete_prefill( + counters: &Arc>, + request_to_workers: &Arc>, + request_id: &str, + ) -> Option { + // Remove from request_to_workers and get the worker_id + let Some((_, worker_id)) = request_to_workers.remove(request_id) else { + tracing::warn!("Request {} not found in request_to_workers", request_id); + return None; + }; + + // Use the worker_id from request_to_workers with get_mut + let Some(mut counter) = counters.get_mut(&worker_id) else { + tracing::warn!( + "No counter found for worker {} for request {}", + worker_id, + request_id + ); + return None; + }; + + let removed_tokens = counter.remove_direct(request_id); + if removed_tokens.is_none() { + tracing::warn!("Attempted to remove non-existent request: {}", request_id); + } + + removed_tokens + } + + pub fn new(component: Component) -> Self { + let counters = Arc::new(DashMap::new()); + let request_to_workers = Arc::new(DashMap::new()); + let router_id = Uuid::new_v4(); + + let multi_worker = Self { + counters: counters.clone(), + request_to_workers: request_to_workers.clone(), + component: component.clone(), + router_id, + }; + + // Start the subscription loop + let counters_clone = counters.clone(); + let request_to_workers_clone = request_to_workers.clone(); + let component_clone = component.clone(); + let router_id_clone = router_id; + + tokio::spawn(async move { + if let Err(e) = Self::subscribe_to_events( + counters_clone, + request_to_workers_clone, + component_clone, + router_id_clone, + ) + .await + { + tracing::error!("Error in prefill events subscription: {}", e); + } + }); + + multi_worker + } + + /// Background task to subscribe to prefill events and update all counters + async fn subscribe_to_events( + counters: Arc>, + request_to_workers: Arc>, + component: Component, + router_id: Uuid, + ) -> Result<()> { + let mut subscriber = component + .subscribe_with_type::(PREFILL_SUBJECT) + .await?; + + while let Some(result) = subscriber.next().await { + let Ok(event) = result else { + tracing::error!("Error receiving prefill event: {}", result.unwrap_err()); + continue; + }; + + // Skip events emitted by itself + if event.router_id == router_id { + continue; + } + + match event.data { + PrefillEventData::NewPrefill(tokens) => { + Self::handle_new_prefill( + &counters, + &request_to_workers, + &event.request_id, + event.worker_id, + tokens, + ); + } + PrefillEventData::UpdatePrefill(_) => { + // Do nothing for now + continue; + } + PrefillEventData::CompletePrefill => { + Self::handle_complete_prefill( + &counters, + &request_to_workers, + &event.request_id, + ); + } + } + } + + Ok(()) + } + + pub async fn add_prefill( + &self, + worker_id: i64, + request_id: String, + new_tokens: usize, + ) -> Result<()> { + let event = PrefillEvent { + request_id: request_id.clone(), + worker_id, + data: PrefillEventData::NewPrefill(new_tokens), + router_id: self.router_id, + }; + self.component.publish(PREFILL_SUBJECT, &event).await?; + + // Use the helper function + Self::handle_new_prefill( + &self.counters, + &self.request_to_workers, + 
&request_id, + worker_id, + new_tokens, + ); + + Ok(()) + } + + pub async fn remove_prefill(&self, request_id: &str) -> Result> { + // Send the event first with dummy worker_id + let event = PrefillEvent { + request_id: request_id.to_string(), + worker_id: 0, // Dummy worker_id + data: PrefillEventData::CompletePrefill, + router_id: self.router_id, + }; + self.component.publish(PREFILL_SUBJECT, &event).await?; + + // Use the helper function + Ok(Self::handle_complete_prefill( + &self.counters, + &self.request_to_workers, + request_id, + )) + } + + /// Get the running sums for all workers as a HashMap + pub async fn running_sums(&self) -> HashMap { + self.counters + .iter() + .map(|entry| (*entry.key(), entry.value().running_sum())) + .collect() + } + + /// Get a specific counter's running sum + pub async fn get_worker_sum(&self, worker_id: i64) -> Option { + self.counters.get(&worker_id).map(|c| c.running_sum()) + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use dynamo_runtime::{DistributedRuntime, Runtime}; + use std::sync::{Arc, Mutex}; + use std::thread; + use tokio::time::Duration; + + #[test] + #[ignore] + fn test_prefill_counter_multiworker_synchronization() -> Result<()> { + // Initialize logging once + dynamo_runtime::logging::init(); + + let worker_id_1 = 1; + let worker_id_2 = 2; + let tokens_per_request = 100; + let requests_per_worker = 10; + + // Shared state for collecting results from both threads + let results1 = Arc::new(Mutex::new(None)); + let results2 = Arc::new(Mutex::new(None)); + let final_results1 = Arc::new(Mutex::new(None)); + let final_results2 = Arc::new(Mutex::new(None)); + + let results1_clone = results1.clone(); + let results2_clone = results2.clone(); + let final_results1_clone = final_results1.clone(); + let final_results2_clone = final_results2.clone(); + + // Thread 1: First distributed runtime with multi_worker1 + let handle1 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and components with same names + let namespace = distributed.namespace("test_prefill_multiworker")?; + let component = namespace + .component("counters")? 
+ .service_builder() + .create() + .await?; + + // Create first PrefillCountersMultiWorker instance + let multi_worker1 = PrefillCountersMultiWorker::new(component); + + // Give some time for subscribers to initialize + tokio::time::sleep(Duration::from_millis(3000)).await; + + // Send requests to multi_worker1's worker + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1 + .add_prefill(worker_id_1, request_id, tokens_per_request) + .await?; + } + + // Wait for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get running sums after additions + let sums1 = multi_worker1.running_sums().await; + *results1_clone.lock().unwrap() = Some(sums1); + + // Wait for other thread to add its requests + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Remove all requests from multi_worker1 + for i in 0..requests_per_worker { + let request_id = format!("mw1_request_{}", i); + multi_worker1.remove_prefill(&request_id).await?; + } + + // Wait for removal synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get final running sums + let final_sums1 = multi_worker1.running_sums().await; + *final_results1_clone.lock().unwrap() = Some(final_sums1); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Thread 2: Second distributed runtime with multi_worker2 + let handle2 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and components with same names + let namespace = distributed.namespace("test_prefill_multiworker")?; + let component = namespace + .component("counters")? 
+ .service_builder() + .create() + .await?; + + // Create second PrefillCountersMultiWorker instance + let multi_worker2 = PrefillCountersMultiWorker::new(component); + + // Give some time for subscribers to initialize + tokio::time::sleep(Duration::from_millis(3000)).await; + + // Wait a bit to ensure multi_worker1 has started + tokio::time::sleep(Duration::from_millis(500)).await; + + // Send requests to multi_worker2's worker + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2 + .add_prefill(worker_id_2, request_id, tokens_per_request) + .await?; + } + + // Wait for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get running sums after additions + let sums2 = multi_worker2.running_sums().await; + *results2_clone.lock().unwrap() = Some(sums2); + + // Wait for other thread to remove its requests + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Remove all requests from multi_worker2 + for i in 0..requests_per_worker { + let request_id = format!("mw2_request_{}", i); + multi_worker2.remove_prefill(&request_id).await?; + } + + // Wait for removal synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Get final running sums + let final_sums2 = multi_worker2.running_sums().await; + *final_results2_clone.lock().unwrap() = Some(final_sums2); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Wait for both threads to complete + handle1.join().unwrap()?; + handle2.join().unwrap()?; + + // Extract results + let sums1 = results1.lock().unwrap().take().unwrap(); + let sums2 = results2.lock().unwrap().take().unwrap(); + let final_sums1 = final_results1.lock().unwrap().take().unwrap(); + let final_sums2 = final_results2.lock().unwrap().take().unwrap(); + + // Verify both multi-workers see all requests + assert_eq!( + sums1.get(&worker_id_1), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker1 should see worker 1's requests" + ); + assert_eq!( + sums1.get(&worker_id_2), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker1 should see worker 2's requests" + ); + assert_eq!( + sums2.get(&worker_id_1), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker2 should see worker 1's requests" + ); + assert_eq!( + sums2.get(&worker_id_2), + Some(&(requests_per_worker * tokens_per_request)), + "MultiWorker2 should see worker 2's requests" + ); + + // Verify both multi-workers show zero sums after removal + assert_eq!( + final_sums1.get(&worker_id_1).copied().unwrap_or(0), + 0, + "MultiWorker1 should show zero for worker 1" + ); + assert_eq!( + final_sums1.get(&worker_id_2).copied().unwrap_or(0), + 0, + "MultiWorker1 should show zero for worker 2" + ); + assert_eq!( + final_sums2.get(&worker_id_1).copied().unwrap_or(0), + 0, + "MultiWorker2 should show zero for worker 1" + ); + assert_eq!( + final_sums2.get(&worker_id_2).copied().unwrap_or(0), + 0, + "MultiWorker2 should show zero for worker 2" + ); + + Ok(()) + } +} diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index e6429f3909..24f33bd0d1 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -1,20 +1,9 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::tokens::Token; + +use crate::tokens::{SequenceHash, Token}; use serde::{Deserialize, Serialize}; +use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterRequest { @@ -128,6 +117,56 @@ impl From for ExternalSequenceBlockHash { } } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PrefillEvent { + pub request_id: String, + pub worker_id: i64, + pub data: PrefillEventData, + pub router_id: Uuid, +} + +/// Represents the different stages of prefilling tokens for a request. +/// +/// Each variant contains a `usize` representing the number of tokens +/// that are pending prefill in the request. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum PrefillEventData { + NewPrefill(usize), + UpdatePrefill(usize), + CompletePrefill, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActiveSequenceEvent { + pub request_id: String, + pub worker_id: i64, + pub data: ActiveSequenceEventData, + pub router_id: Uuid, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum ActiveSequenceEventData { + AddRequest { + token_sequence: Vec, + isl: usize, + overlap: u32, + }, + Free, + MarkPrefillCompleted, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActiveBlockEvent { + pub request_id: String, + pub data: ActiveBlockEventData, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum ActiveBlockEventData { + NewBlock(Vec), + FreeBlock, +} + /// Represents a collection of cache events and a shutdown flag. 
#[derive(Serialize, Deserialize, Debug, Clone)] pub struct KvCacheEvents { diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 2f75b59c06..22c88b08d0 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -16,12 +16,13 @@ use crate::kv_router::{ indexer::{compute_block_hash_for_seq, RouterEvent}, protocols::*, - KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, + scoring::LoadEvent, + KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT, }; use async_trait::async_trait; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; use dynamo_runtime::{ - component::Component, + component::{Component, Namespace}, pipeline::{ network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, @@ -499,9 +500,18 @@ impl WorkerMetricsPublisher { pub async fn create_endpoint(&self, component: Component) -> Result<()> { let mut metrics_rx = self.rx.clone(); - let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); + let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; + // let worker_id = component + // .drt() + // .primary_lease() + // .map(|lease| lease.id()) + // .unwrap_or_else(|| { + // tracing::warn!("Component is static, assuming worker_id of 0"); + // 0 + // }); + component .endpoint(KV_METRICS_ENDPOINT) .endpoint_builder() @@ -513,13 +523,90 @@ impl WorkerMetricsPublisher { .start() .await } + + /// Starts a background task to publish metrics over NATS + /// + /// This task monitors metric changes (specifically kv_active_blocks and num_requests_waiting) + /// and publishes stable metrics to NATS after they've been unchanged for 1ms. + #[allow(dead_code)] + fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: i64) { + let nats_rx = self.rx.clone(); + + tokio::spawn(async move { + let mut rx = nats_rx; + let mut last_kv_active_blocks: Option = None; + let mut last_num_requests_waiting: Option = None; + let mut pending_publish: Option> = None; + let mut publish_timer = + Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(0))); + publish_timer.as_mut().reset(tokio::time::Instant::now()); // Complete immediately + + loop { + tokio::select! 
{ + // Handle metrics changes + result = rx.changed() => { + if result.is_err() { + tracing::debug!( + "Metrics publisher sender dropped, stopping NATS background task" + ); + break; + } + + let metrics = rx.borrow_and_update().clone(); + + // Extract the values we care about + let current_kv_active_blocks = metrics.kv_stats.kv_active_blocks; + let current_num_requests_waiting = + metrics.worker_stats.num_requests_waiting; + + // Check if these specific metrics have changed + let has_changed = match (last_kv_active_blocks, last_num_requests_waiting) { + (Some(last_kv), Some(last_requests)) => { + last_kv != current_kv_active_blocks + || last_requests != current_num_requests_waiting + } + _ => true, // First time, consider it changed + }; + + // If load metrics changed, schedule a publish + if has_changed { + pending_publish = Some(metrics.clone()); + last_kv_active_blocks = Some(current_kv_active_blocks); + last_num_requests_waiting = Some(current_num_requests_waiting); + + // Start the 1ms timer + publish_timer.as_mut().reset( + tokio::time::Instant::now() + tokio::time::Duration::from_millis(1) + ); + } + } + // Timer expired - publish if we have pending metrics + _ = &mut publish_timer => { + if let Some(metrics) = pending_publish.take() { + // Create LoadEvent wrapping the metrics + let load_event = LoadEvent { + worker_id, + data: (*metrics).clone(), + }; + + if let Err(e) = + namespace.publish(KV_METRICS_SUBJECT, &load_event).await + { + tracing::warn!("Failed to publish metrics over NATS: {}", e); + } + } + } + } + } + }); + } } -struct KvLoadEndpoingHander { +struct KvLoadEndpointHandler { metrics_rx: tokio::sync::watch::Receiver>, } -impl KvLoadEndpoingHander { +impl KvLoadEndpointHandler { pub fn new(metrics_rx: tokio::sync::watch::Receiver>) -> Self { Self { metrics_rx } } @@ -527,7 +614,7 @@ impl KvLoadEndpoingHander { #[async_trait] impl AsyncEngine, ManyOut>, Error> - for KvLoadEndpoingHander + for KvLoadEndpointHandler { async fn generate( &self, @@ -880,3 +967,116 @@ mod test_exponential_backoff { assert!(max_calculated <= MAX_BACKOFF_MS); } } + +#[cfg(test)] +mod test_worker_metrics_publisher { + use super::*; + use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats}; + use dynamo_runtime::traits::events::EventSubscriber; // Add this import + use dynamo_runtime::{DistributedRuntime, Runtime}; + use futures::StreamExt; + + #[tokio::test] + #[ignore] // Mark as ignored as requested + async fn test_metrics_publishing_behavior() -> Result<()> { + // Set up runtime and namespace + let rt = Runtime::from_current().unwrap(); + let drt = DistributedRuntime::from_settings(rt.clone()).await?; + let namespace = drt.namespace("test".to_string())?; + + // Create a subscriber for the metrics events using subscribe_with_type + let mut subscriber = namespace + .subscribe_with_type::(KV_METRICS_SUBJECT) + .await + .unwrap(); + + // Create WorkerMetricsPublisher + let publisher = WorkerMetricsPublisher::new().unwrap(); + let worker_id = 1234; + + // Start NATS metrics publishing + publisher.start_nats_metrics_publishing(namespace.clone(), worker_id); + + // Allow some time for the background task to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Test 1: Publish 10 different metrics with 0.5ms intervals + // Only the last one should be published after 1ms of stability + for i in 0..10 { + let metrics = Arc::new(ForwardPassMetrics { + kv_stats: KvStats { + kv_active_blocks: (i * 100) as u64, // Changing load metric + kv_total_blocks: 1000, 
+ gpu_cache_usage_perc: 0.5, + gpu_prefix_cache_hit_rate: 0.8, + }, + worker_stats: WorkerStats { + num_requests_waiting: (i * 10) as u64, // Changing load metric + data_parallel_rank: None, + request_active_slots: 50, + request_total_slots: 100, + }, + spec_decode_stats: None, + }); + + publisher.publish(metrics).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; + } + + // Wait a bit more than 1ms to ensure the last metric is published + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Verify we receive exactly one event with the last metric values + let result = + tokio::time::timeout(tokio::time::Duration::from_millis(500), subscriber.next()) + .await + .unwrap(); + + let event = result.unwrap().unwrap(); // Unwrap the Option and the Result + assert_eq!(event.worker_id, worker_id); + assert_eq!(event.data.kv_stats.kv_active_blocks, 900); // Last value: 9 * 100 + assert_eq!(event.data.worker_stats.num_requests_waiting, 90); // Last value: 9 * 10 + + // Ensure no more events are waiting + let no_msg = + tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await; + assert!(no_msg.is_err(), "Expected no more messages, but found one"); + + // Test 2: Publish 10 more metrics where everything changes EXCEPT the load metrics + for i in 0..10 { + let metrics = Arc::new(ForwardPassMetrics { + kv_stats: KvStats { + kv_active_blocks: 900, // Keep same as last published + kv_total_blocks: 1000 + (i * 100) as u64, // Change other metrics + gpu_cache_usage_perc: 0.3 + (i as f32 * 0.05), // Change other metrics + gpu_prefix_cache_hit_rate: 0.7 + (i as f32 * 0.01), // Change other metrics + }, + worker_stats: WorkerStats { + num_requests_waiting: 90, // Keep same as last published + data_parallel_rank: None, + request_active_slots: 40 + (i * 5) as u64, // Change other metrics + request_total_slots: 100 + (i * 10) as u64, // Change other metrics + }, + spec_decode_stats: None, + }); + + publisher.publish(metrics).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; + } + + // Wait to ensure no events are published + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Verify no events are received + let no_msg = + tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await; + assert!( + no_msg.is_err(), + "Expected no messages when load metrics don't change" + ); + + rt.shutdown(); + + Ok(()) + } +} diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 797f738187..ea95d009a9 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -1,36 +1,22 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-use dynamo_runtime::component::Namespace;
+
+use dynamo_runtime::component::{Component, Instance};
 use dynamo_runtime::traits::events::EventPublisher;
 use rand::Rng;
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
-use tokio::sync::Mutex;
 
+use super::indexer::OverlapScores;
 use super::protocols::WorkerSelectionResult;
+use super::sequence::ActiveSequencesMultiWorker;
+use super::KvRouterConfig;
 use super::WorkerSelector;
-use crate::kv_router::indexer::OverlapScores;
-use crate::kv_router::protocols::LoadMetrics;
-use crate::kv_router::sequence::ActiveSequencesMultiWorker;
-use crate::kv_router::KvRouterConfig;
-use crate::kv_router::KV_HIT_RATE_SUBJECT;
+use super::KV_HIT_RATE_SUBJECT;
+
 use crate::tokens::SequenceHash;
-use dynamo_runtime::component::Instance;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct KVHitRateEvent {
@@ -51,61 +37,44 @@ pub enum KvSchedulerError {
     SubscriberShutdown,
 }
 
-/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
-/// is cleaned (not optional)
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct Endpoint {
-    pub name: String,
-    pub subject: String,
-    pub data: LoadMetrics,
-}
-
-impl Endpoint {
-    pub fn worker_id(&self) -> i64 {
-        i64::from_str_radix(
-            self.subject
-                .split("-")
-                .last()
-                .expect("invalid subject")
-                .to_string()
-                .as_str(),
-            16,
-        )
-        .expect("invalid worker id")
-    }
-}
-
 #[derive(Debug)]
 pub struct SchedulingResponse {
     pub best_worker_id: i64,
-    pub overlap_blocks: u32, // Add this field
-    pub endpoints_changed: Option<Vec<i64>>,
+    pub overlap_blocks: u32,
 }
 
 pub struct SchedulingRequest {
+    pub request_id: String,
+    pub token_seq: Vec<SequenceHash>,
     pub isl_tokens: usize,
     pub overlaps: OverlapScores,
-    pub potential_blocks: HashMap<i64, usize>,
-    pub potential_tokens: HashMap<i64, usize>,
-    resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
+    pub decode_blocks: HashMap<i64, usize>,
+    pub prefill_tokens: HashMap<i64, usize>,
+    resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>, // Changed to Option
 }
 
 impl SchedulingRequest {
-    pub fn respond(self, response: SchedulingResponse) {
-        if self.resp_tx.send(response).is_err() {
-            tracing::error!("failed to send response to requestor");
+    pub fn respond(&mut self, response: SchedulingResponse) {
+        // Changed to &mut self
+        if let Some(tx) = self.resp_tx.take() {
+            // Use take() to extract the sender
+            if tx.send(response).is_err() {
+                tracing::error!("failed to send response to requestor");
+            }
+        } else {
+            tracing::error!("respond called multiple times on same request");
         }
     }
 }
 
 pub struct KvScheduler {
     request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
-    sequences: Arc<Mutex<ActiveSequencesMultiWorker>>,
+    slots: Arc<ActiveSequencesMultiWorker>,
 }
 
 impl KvScheduler {
     pub async fn start(
-        ns: Namespace,
+        component: Component,
         block_size: u32,
         mut instances_rx: tokio::sync::watch::Receiver<Vec<Instance>>, // Changed from ProcessedEndpoints
         selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
@@ -113,93 +82,104 @@ impl KvScheduler {
         let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
         let mut instances: Vec<Instance> = instances_rx.borrow_and_update().clone();
 
-        // Get worker IDs from instances
-        let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
-
         let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
+        let ns_clone = component.namespace().clone();
         tokio::spawn(async move {
             let mut event_rx = event_rx;
             while let Some(event) = event_rx.recv().await {
-                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
+                if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await {
                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                }
            }
        });
 
-        let 
sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new( + let worker_ids: Vec = instances + .iter() + .map(|instance| instance.instance_id) + .collect(); + let slots = Arc::new(ActiveSequencesMultiWorker::new( + component, block_size as usize, worker_ids, - ))); + )); - // Channel to accept new scheduling requests + let slots_clone = slots.clone(); let (request_tx, request_rx) = tokio::sync::mpsc::channel::(1024); // Background task to handle scheduling requests tokio::spawn(async move { - let mut request: SchedulingRequest; let mut request_rx = request_rx; - let mut pending_endpoint_update: Option> = None; tracing::trace!("scheduler background task started"); - 'outer: loop { - request = tokio::select! { - biased; + loop { + // First, check for instance updates (non-blocking) + if instances_rx.has_changed().unwrap_or(false) { + instances = instances_rx.borrow_and_update().clone(); + let worker_ids: Vec = instances + .iter() + .map(|instance| instance.instance_id) + .collect(); + slots_clone.update_workers(worker_ids); + } - _ = instances_rx.changed() => { - instances = instances_rx.borrow_and_update().clone(); - let worker_ids: Vec = instances.iter().map(|i| i.instance_id).collect(); - pending_endpoint_update = Some(worker_ids); - continue 'outer; - } + // Then, wait for a new request + let Some(mut request) = request_rx.recv().await else { + tracing::warn!("scheduler shutdown"); + break; + }; + tracing::trace!("received request to be scheduled"); + + let (decode_blocks, prefill_tokens) = slots_clone + .potential_blocks_and_tokens( + request.token_seq.clone(), + request.isl_tokens, + request.overlaps.clone(), + ) + .await; + request.decode_blocks = decode_blocks; + request.prefill_tokens = prefill_tokens; + + match selector.select_worker(&instances, &request, block_size) { + Ok(selection) => { + if let Err(e) = event_tx.send(KVHitRateEvent { + worker_id: selection.worker_id, + isl_blocks: selection.required_blocks as usize, + overlap_blocks: selection.overlap_blocks, + }) { + tracing::warn!("Failed to send KV hit rate event: {:?}", e); + } - maybe_new_request = request_rx.recv() => { - let Some(new_request) = maybe_new_request else { - tracing::warn!("scheduler shutdown"); - break 'outer; + let response = SchedulingResponse { + best_worker_id: selection.worker_id, + overlap_blocks: selection.overlap_blocks, }; - tracing::trace!("received request to be scheduled"); - new_request + request.respond(response); + + let _ = slots_clone + .add_request( + request.request_id, + request.token_seq, + request.isl_tokens, + selection.overlap_blocks, + selection.worker_id, + ) + .await; + + continue; } - }; - - loop { - // When calling selector.select_worker, we need to adapt - match selector.select_worker(&instances, &request, block_size) { - Ok(selection) => { - if let Err(e) = event_tx.send(KVHitRateEvent { - worker_id: selection.worker_id, - isl_blocks: selection.required_blocks as usize, - overlap_blocks: selection.overlap_blocks, - }) { - tracing::warn!("Failed to send KV hit rate event: {:?}", e); - } - - let response = SchedulingResponse { - best_worker_id: selection.worker_id, - overlap_blocks: selection.overlap_blocks, - endpoints_changed: pending_endpoint_update.take(), - }; - request.respond(response); - continue 'outer; - } - Err(KvSchedulerError::NoEndpoints) => { - tracing::trace!("no endpoints available; waiting for endpoints update"); - instances_rx.changed().await.ok(); - instances = instances_rx.borrow_and_update().clone(); - let worker_ids: Vec = - instances.iter().map(|i| 
i.instance_id).collect(); - pending_endpoint_update = Some(worker_ids); - continue; - } - // TODO: this is not actually hooked up - Err(KvSchedulerError::AllWorkersBusy) => { - tracing::trace!("all workers busy; waiting for more capacity"); - tokio::time::sleep(Duration::from_millis(5)).await; - continue; - } - Err(e) => { - tracing::error!("error scheduling request: {:?}", e); - break 'outer; - } + Err(KvSchedulerError::NoEndpoints) => { + tracing::trace!("no endpoints available; waiting for endpoints update"); + tokio::time::sleep(Duration::from_millis(5)).await; + continue; + } + // TODO: this is not actually hooked up + Err(KvSchedulerError::AllWorkersBusy) => { + tracing::trace!("all workers busy; waiting for more capacity"); + tokio::time::sleep(Duration::from_millis(5)).await; + continue; + } + Err(e) => { + tracing::error!("error scheduling request: {:?}", e); + break; } } } @@ -207,10 +187,7 @@ impl KvScheduler { tracing::trace!("background endpoint subscriber shutting down"); }); - Ok(KvScheduler { - request_tx, - sequences, - }) + Ok(KvScheduler { request_tx, slots }) } pub async fn schedule( @@ -220,19 +197,17 @@ impl KvScheduler { token_seq: Vec, overlaps: OverlapScores, ) -> Result { - let mut sequences = self.sequences.lock().await; - - let (potential_blocks, potential_tokens) = - sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone()); - let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { + request_id, + token_seq, isl_tokens, overlaps, - potential_blocks, - potential_tokens, - resp_tx, + decode_blocks: HashMap::new(), + prefill_tokens: HashMap::new(), + resp_tx: Some(resp_tx), // Wrap in Some() }; + self.request_tx .send(request) .await @@ -241,30 +216,19 @@ impl KvScheduler { .await .map_err(|_| KvSchedulerError::SubscriberShutdown)?; - if let Some(new_worker_ids) = response.endpoints_changed { - sequences.update_workers(new_worker_ids); - } - - sequences.add_request( - request_id, - token_seq, - isl_tokens, - response.overlap_blocks, - response.best_worker_id, - ); - - Ok(response.best_worker_id) + let best_worker_id = response.best_worker_id; + Ok(best_worker_id) } - pub async fn mark_prefill_completed(&self, request_id: &String) { - let mut sequences = self.sequences.lock().await; - sequences.mark_prefill_completed(request_id) + pub async fn mark_prefill_completed(&self, request_id: &str) { + let _ = self + .slots + .mark_prefill_completed(&request_id.to_string()) + .await; } - /// Free all blocks associated with a request - pub async fn free(&self, request_id: &String) { - let mut sequences = self.sequences.lock().await; - sequences.free(request_id) + pub async fn free(&self, request_id: &str) { + let _ = self.slots.free(&request_id.to_string()).await; } } @@ -307,8 +271,9 @@ fn softmax_sample(logits: &HashMap, temperature: f64) -> i64 { let normalized: Vec<_> = values .iter() .map(|&v| { - let norm = v / (max_val - min_val); // Lower is better, so negate + // Note we don't need to do actual min-max norm here, just off by an offset + let norm = v / (max_val - min_val); -norm }) .collect(); @@ -370,10 +335,8 @@ impl WorkerSelector for DefaultWorkerSelector { let request_blocks = isl.div_ceil(block_size as usize); let overlaps = &request.overlaps.scores; - // active blocks for decoding - let potential_active_blocks = &request.potential_blocks; - // active tokens in the batch (processed by the linear layers), mostly prefill tokens - let potential_active_tokens = &request.potential_tokens; + let 
decode_blocks = &request.decode_blocks; + let prefill_tokens = &request.prefill_tokens; let mut worker_logits = HashMap::new(); let mut max_logit = f64::NEG_INFINITY; @@ -381,52 +344,39 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate logits for each worker for instance in workers.iter() { let worker_id = instance.instance_id; - // this is the number of tokens each worker would have if the request were scheduled there - let potential_tokens = *potential_active_tokens.get(&worker_id).unwrap_or_else(|| { - tracing::warn!( - "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet" - ); - &isl - }) as f64; - - // this is the number of blocks each worker would have if the request were scheduled there - let potential_blocks = *potential_active_blocks.get(&worker_id).unwrap_or_else(|| - {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet"); - &request_blocks - }) as f64; - - let potential_prefill_blocks = potential_tokens / (block_size as f64); + let overlap = *overlaps.get(&worker_id).unwrap_or(&0); + + // this is the number of prefill tokens the worker would have if the request were scheduled there + let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&isl); + let potential_prefill_block = (prefill_token as f64) / (block_size as f64); + + // this is the number of decode blocks the worker would have if the request were scheduled there + let decode_block = *decode_blocks + .get(&worker_id) + .unwrap_or(&(potential_prefill_block.floor() as usize)) + as f64; // Calculate logit (lower is better) - let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks - + potential_blocks; + let logit = + self.kv_router_config.overlap_score_weight * potential_prefill_block + decode_block; max_logit = max_logit.max(logit); worker_logits.insert(worker_id, logit); tracing::info!( - "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})", + "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_block:.3} + {decode_block:.3} (cached_blocks: {})", self.kv_router_config.overlap_score_weight, - overlaps.get(&worker_id).unwrap_or(&0), + overlap, ); } - // Normalize by dividing by max value - if max_logit > 0.0 { - for logit in worker_logits.values_mut() { - *logit /= max_logit; - } - } - // Use softmax sampling to select worker let temperature = self.kv_router_config.router_temperature; let best_worker_id = softmax_sample(&worker_logits, temperature); - - let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0); let best_logit = worker_logits[&best_worker_id]; tracing::info!( - "Selected worker: {}, normalized logit: {:.3}", + "Selected worker: {}, logit: {:.3}", best_worker_id, best_logit ); @@ -434,7 +384,7 @@ impl WorkerSelector for DefaultWorkerSelector { Ok(WorkerSelectionResult { worker_id: best_worker_id, required_blocks: request_blocks as u64, - overlap_blocks, + overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0), }) } } diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index c065a613ba..5934716c7c 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -15,10 +15,39 @@ //! Scoring functions for the KV router. 
+use super::protocols::{ForwardPassMetrics, LoadMetrics}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::kv_router::scheduler::Endpoint; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct LoadEvent { + pub worker_id: i64, + pub data: ForwardPassMetrics, +} + +/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' +/// is cleaned (not optional) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Endpoint { + pub name: String, + pub subject: String, + pub data: LoadMetrics, +} + +impl Endpoint { + pub fn worker_id(&self) -> i64 { + i64::from_str_radix( + self.subject + .split("-") + .last() + .expect("invalid subject") + .to_string() + .as_str(), + 16, + ) + .expect("invalid worker id") + } +} #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] pub struct ProcessedEndpoints { diff --git a/lib/llm/src/kv_router/sequence.rs b/lib/llm/src/kv_router/sequence.rs index 060eaf1ed3..9460ff48e8 100644 --- a/lib/llm/src/kv_router/sequence.rs +++ b/lib/llm/src/kv_router/sequence.rs @@ -1,17 +1,5 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. //! KV Cache Sequence Management for LLM Inference //! 
@@ -37,11 +25,20 @@ use crate::kv_router::indexer::OverlapScores;
 use crate::kv_router::indexer::WorkerId;
 use crate::tokens::SequenceHash;
 
+use anyhow::Result;
+use dashmap::DashMap;
 use derive_getters::Getters;
+use dynamo_runtime::component::Component;
+use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
+use dynamo_runtime::traits::DistributedRuntimeProvider;
+use futures::StreamExt;
 use std::collections::{HashMap, HashSet};
-use std::sync::{mpsc, Arc};
-use std::thread;
-use std::time::Duration;
+use std::sync::Arc;
+use uuid::Uuid;
+
+use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
+use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
+use dynamo_runtime::CancellationToken;
 
 // TODO: use the common request_id if it exists in the repo
 pub type RequestId = String;
@@ -134,7 +131,7 @@ impl ActiveSequences {
         if let Some(tokens) = self.prefill_tokens.remove(request_id) {
             self.active_tokens = self
                 .active_tokens
-                .checked_sub(tokens.saturating_sub(1)) // Keep 1 token for decoding
+                .checked_sub(tokens)
                 .expect("active_tokens underflow");
         }
     }
@@ -171,11 +168,7 @@ impl ActiveSequences {
 
     /// Free all blocks associated with a request
     pub fn free(&mut self, request_id: &RequestId) -> usize {
-        // decoding has one active token
-        self.active_tokens = self
-            .active_tokens
-            .checked_sub(self.prefill_tokens.remove(request_id).unwrap_or(1))
-            .expect("active_tokens < 0");
+        self.mark_prefill_completed(request_id);
 
         let Some(token_seq) = self.active_seqs.get(request_id) else {
             tracing::warn!("Trying to free free non-existent request {request_id}");
@@ -207,127 +200,252 @@ enum UpdateSequences {
     },
     NewBlocks {
         token_sequence: Arc<Vec<SequenceHash>>,
-        resp_tx: mpsc::SyncSender<usize>,
+        resp_tx: tokio::sync::oneshot::Sender<usize>,
     },
     PotentialBlocks {
         token_sequence: Arc<Vec<SequenceHash>>,
-        resp_tx: mpsc::SyncSender<usize>,
+        resp_tx: tokio::sync::oneshot::Sender<usize>,
     },
     PotentialBlocksAndTokens {
         token_sequence: Arc<Vec<SequenceHash>>,
         isl: usize,
         overlap: u32,
-        resp_tx: mpsc::SyncSender<(usize, usize)>,
+        resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>,
     },
     ActiveBlocks {
-        resp_tx: mpsc::SyncSender<usize>,
+        resp_tx: tokio::sync::oneshot::Sender<usize>,
     },
     ActiveTokens {
-        resp_tx: mpsc::SyncSender<usize>,
+        resp_tx: tokio::sync::oneshot::Sender<usize>,
     },
     Shutdown,
 }
 
 /// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
 pub struct ActiveSequencesMultiWorker {
-    senders: HashMap<WorkerId, mpsc::Sender<UpdateSequences>>,
-    request_to_worker: HashMap<RequestId, WorkerId>,
-    handles: HashMap<WorkerId, thread::JoinHandle<()>>,
+    senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
+    request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
+    handles: Arc<DashMap<WorkerId, tokio::task::JoinHandle<()>>>,
     block_size: usize,
+    component: Component,
+    router_id: Uuid,
 }
 
 impl ActiveSequencesMultiWorker {
-    pub fn new(block_size: usize, worker_ids: Vec<WorkerId>) -> Self {
+    pub fn new(component: Component, block_size: usize, worker_ids: Vec<WorkerId>) -> Self {
         assert!(block_size > 1, "block_size must be greater than 1");
 
-        let mut senders = HashMap::new();
-        let mut handles = HashMap::new();
+        let senders = Arc::new(DashMap::new());
+        let handles = Arc::new(DashMap::new());
+        let request_to_worker = Arc::new(DashMap::new());
+        let router_id = Uuid::new_v4();
 
         for worker_id in worker_ids {
-            let (sender, handle) = Self::start_worker(block_size);
+            // Create a child cancellation token from the component's runtime
+            let cancel_token = component.drt().runtime().child_token();
+            let (sender, handle) = Self::start_worker(block_size, cancel_token);
             senders.insert(worker_id, sender);
             handles.insert(worker_id, handle);
         }
 
-        Self {
-            senders,
-            request_to_worker: HashMap::new(),
+        let multi_worker = Self {
+            senders: senders.clone(),
+            request_to_worker: 
request_to_worker.clone(), handles, block_size, - } + component: component.clone(), + router_id, + }; + + // Start the subscription loop + let senders_clone = senders.clone(); + let request_to_worker_clone = request_to_worker.clone(); + let component_clone = component.clone(); + let router_id_clone = router_id; + + tokio::spawn(async move { + if let Err(e) = Self::subscribe_to_events( + senders_clone, + request_to_worker_clone, + component_clone, + router_id_clone, + ) + .await + { + tracing::error!("Error in active sequences events subscription: {}", e); + } + }); + + multi_worker } - /// Helper method to start a worker thread - fn start_worker(block_size: usize) -> (mpsc::Sender, thread::JoinHandle<()>) { - let (request_tx, request_rx) = mpsc::channel::(); + /// Helper method to start a worker task + fn start_worker( + block_size: usize, + cancel_token: CancellationToken, // Add cancellation token parameter + ) -> ( + tokio::sync::mpsc::UnboundedSender, + tokio::task::JoinHandle<()>, + ) { + let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel(); - let handle = thread::spawn(move || { + let handle = tokio::spawn(async move { let mut active_sequences = ActiveSequences::new(block_size); - while let Ok(command) = request_rx.recv() { - match command { - UpdateSequences::AddRequest { - request_id, - token_sequence, - isl, - overlap, - } => { - active_sequences.add_request(request_id, token_sequence, isl, overlap); - } - UpdateSequences::Free { request_id } => { - active_sequences.free(&request_id); - } - UpdateSequences::MarkPrefillCompleted { request_id } => { - active_sequences.mark_prefill_completed(&request_id); + loop { + tokio::select! { + // Handle incoming commands + command = request_rx.recv() => { + match command { + Some(command) => { + match command { + UpdateSequences::AddRequest { + request_id, + token_sequence, + isl, + overlap, + } => { + active_sequences.add_request(request_id, token_sequence, isl, overlap); + } + UpdateSequences::Free { request_id } => { + active_sequences.free(&request_id); + } + UpdateSequences::MarkPrefillCompleted { request_id } => { + active_sequences.mark_prefill_completed(&request_id); + } + UpdateSequences::NewBlocks { + token_sequence, + resp_tx, + } => { + let new_blocks = active_sequences.new_blocks(&token_sequence); + let _ = resp_tx.send(new_blocks); + } + UpdateSequences::PotentialBlocks { + token_sequence, + resp_tx, + } => { + let potential_blocks = active_sequences.potential_blocks(&token_sequence); + let _ = resp_tx.send(potential_blocks); + } + UpdateSequences::PotentialBlocksAndTokens { + token_sequence, + isl, + overlap, + resp_tx, + } => { + let potential_tokens = active_sequences.potential_blocks_and_tokens( + &token_sequence, + isl, + overlap, + ); + let _ = resp_tx.send(potential_tokens); + } + UpdateSequences::ActiveBlocks { resp_tx } => { + let active_blocks = active_sequences.active_blocks(); + let _ = resp_tx.send(active_blocks); + } + UpdateSequences::ActiveTokens { resp_tx } => { + let active_tokens = active_sequences.active_tokens(); + let _ = resp_tx.send(active_tokens); + } + UpdateSequences::Shutdown => { + break; + } + } + } + None => { + // Channel closed, exit + break; + } + } } - UpdateSequences::NewBlocks { - token_sequence, - resp_tx, - } => { - let new_blocks = active_sequences.new_blocks(&token_sequence); - let _ = resp_tx.send(new_blocks); - } - UpdateSequences::PotentialBlocks { - token_sequence, - resp_tx, - } => { - let potential_blocks = active_sequences.potential_blocks(&token_sequence); - 
let _ = resp_tx.send(potential_blocks); + // Handle cancellation + _ = cancel_token.cancelled() => { + tracing::debug!("Worker task cancelled"); + break; } - UpdateSequences::PotentialBlocksAndTokens { - token_sequence, - isl, - overlap, - resp_tx, - } => { - let potential_tokens = active_sequences.potential_blocks_and_tokens( - &token_sequence, - isl, - overlap, + } + } + }); + + (request_tx, handle) + } + + /// Background task to subscribe to active sequence events and update all workers + async fn subscribe_to_events( + senders: Arc>>, + request_to_worker: Arc>, + component: Component, + router_id: Uuid, + ) -> Result<()> { + let mut subscriber = component + .subscribe_with_type::(ACTIVE_SEQUENCES_SUBJECT) + .await?; + + while let Some(result) = subscriber.next().await { + let Ok(event) = result else { + tracing::error!( + "Error receiving active sequence event: {}", + result.unwrap_err() + ); + continue; + }; + + // Skip events emitted by itself + if event.router_id == router_id { + continue; + } + + match &event.data { + ActiveSequenceEventData::AddRequest { + token_sequence, + isl, + overlap, + } => { + request_to_worker.insert(event.request_id.clone(), event.worker_id); + + if let Some(sender) = senders.get(&event.worker_id) { + let _ = sender.send(UpdateSequences::AddRequest { + request_id: event.request_id.clone(), + token_sequence: token_sequence.clone(), + isl: *isl, + overlap: *overlap, + }); + } else { + tracing::warn!( + "Worker {} not found, cannot process AddRequest", + event.worker_id ); - let _ = resp_tx.send(potential_tokens); - } - UpdateSequences::ActiveBlocks { resp_tx } => { - let active_blocks = active_sequences.active_blocks(); - let _ = resp_tx.send(active_blocks); } - UpdateSequences::ActiveTokens { resp_tx } => { - let active_tokens = active_sequences.active_tokens(); - let _ = resp_tx.send(active_tokens); + } + ActiveSequenceEventData::Free => { + if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id) { + if let Some(sender) = senders.get(&worker_id) { + let _ = sender.send(UpdateSequences::Free { + request_id: event.request_id.clone(), + }); + } } - UpdateSequences::Shutdown => { - break; + } + ActiveSequenceEventData::MarkPrefillCompleted => { + if let Some(worker_id) = request_to_worker.get(&event.request_id) { + if let Some(sender) = senders.get(&*worker_id) { + let _ = sender.send(UpdateSequences::MarkPrefillCompleted { + request_id: event.request_id.clone(), + }); + } } } } - }); + } - (request_tx, handle) + Ok(()) } /// Update the set of workers, adding and removing as needed - pub fn update_workers(&mut self, new_worker_ids: Vec) -> HashMap { - let current_workers: HashSet = self.senders.keys().copied().collect(); + pub fn update_workers(&self, new_worker_ids: Vec) { + let current_workers: HashSet = + self.senders.iter().map(|entry| *entry.key()).collect(); let new_workers: HashSet = new_worker_ids.into_iter().collect(); let workers_to_remove: Vec = @@ -340,11 +458,11 @@ impl ActiveSequencesMultiWorker { tracing::warn!("Removing worker {}", worker_id); // Send shutdown command to the worker - if let Some(sender) = self.senders.remove(worker_id) { + if let Some((_, sender)) = self.senders.remove(worker_id) { let _ = sender.send(UpdateSequences::Shutdown); } - if let Some(handle) = self.handles.remove(worker_id) { - let _ = handle.join(); + if let Some((_, handle)) = self.handles.remove(worker_id) { + handle.abort(); } } @@ -352,68 +470,122 @@ impl ActiveSequencesMultiWorker { for worker_id in &workers_to_add { tracing::warn!("Adding 
worker {}", worker_id); - let (sender, handle) = Self::start_worker(self.block_size); + let (sender, handle) = Self::start_worker( + self.block_size, + self.component.drt().runtime().child_token(), + ); self.senders.insert(*worker_id, sender); self.handles.insert(*worker_id, handle); } - - // Return active blocks for all workers - self.active_blocks() } - pub fn add_request( - &mut self, + pub async fn add_request( + &self, request_id: RequestId, token_sequence: Vec, isl: usize, overlap: u32, worker_id: WorkerId, - ) { + ) -> Result<()> { if !self.senders.contains_key(&worker_id) { - panic!("Worker ID {worker_id} not found"); + return Err(anyhow::anyhow!("Worker ID {worker_id} not found")); } + // Publish event + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::AddRequest { + token_sequence: token_sequence.clone(), + isl, + overlap, + }, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + + // Update local state self.request_to_worker.insert(request_id.clone(), worker_id); - self.senders[&worker_id] + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::AddRequest { request_id, token_sequence, isl, overlap, }) - .expect("Failed to send add_request command to worker"); + .map_err(|_| anyhow::anyhow!("Failed to send add_request command to worker"))?; + + Ok(()) } - pub fn free(&mut self, request_id: &RequestId) { + pub async fn free(&self, request_id: &RequestId) -> Result<()> { let worker_id = self .request_to_worker .get(request_id) - .copied() - .expect("Request ID not found in request_to_worker mapping"); - - self.senders[&worker_id] + .map(|entry| *entry) + .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?; + + // Publish event + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::Free, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + + // Update local state + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::Free { request_id: request_id.clone(), }) - .expect("Failed to send free command to worker"); + .map_err(|_| anyhow::anyhow!("Failed to send free command to worker"))?; self.request_to_worker.remove(request_id); + + Ok(()) } /// Mark prefill as completed for a request - pub fn mark_prefill_completed(&mut self, request_id: &RequestId) { + pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> { let worker_id = self .request_to_worker .get(request_id) - .copied() - .expect("Request ID not found in request_to_worker mapping"); - - self.senders[&worker_id] + .map(|entry| *entry) + .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?; + + // Publish event + let event = ActiveSequenceEvent { + request_id: request_id.clone(), + worker_id, + data: ActiveSequenceEventData::MarkPrefillCompleted, + router_id: self.router_id, + }; + self.component + .publish(ACTIVE_SEQUENCES_SUBJECT, &event) + .await?; + + // Update local state + self.senders + .get(&worker_id) + .unwrap() .send(UpdateSequences::MarkPrefillCompleted { request_id: request_id.clone(), }) - .expect("Failed to send mark_prefill_completed command to worker"); + .map_err(|_| { + anyhow::anyhow!("Failed to send mark_prefill_completed command to worker") + })?; + + Ok(()) } /// Get the number of workers @@ -422,37 +594,49 @@ impl ActiveSequencesMultiWorker { } /// Generic 
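
Two things change in the mutating methods above: they publish an `ActiveSequenceEvent` before applying the update locally (so a failed publish leaves local state untouched), and the old `panic!`/`expect` lookups become `anyhow` errors. A minimal sketch of the lookup conversion, with an illustrative `Tracker` in place of `ActiveSequencesMultiWorker`:

```rust
// Panic-to-Result conversion for DashMap lookups, as in add_request/free.
use anyhow::{anyhow, Result};
use dashmap::DashMap;

type WorkerId = u64;
type RequestId = String;

struct Tracker {
    request_to_worker: DashMap<RequestId, WorkerId>,
}

impl Tracker {
    fn worker_for(&self, request_id: &RequestId) -> Result<WorkerId> {
        self.request_to_worker
            .get(request_id)
            .map(|entry| *entry) // DashMap guard derefs to the value; copy it out
            .ok_or_else(|| {
                anyhow!("Request ID {request_id} not found in request_to_worker mapping")
            })
    }
}

fn main() -> Result<()> {
    let tracker = Tracker { request_to_worker: DashMap::new() };
    tracker.request_to_worker.insert("r0".to_string(), 1);
    assert_eq!(tracker.worker_for(&"r0".to_string())?, 1);
    assert!(tracker.worker_for(&"missing".to_string()).is_err());
    Ok(())
}
```

Returning `Result` also fits the new `&self` signatures: with `DashMap` the maps no longer need `&mut self`, so callers can share the manager across tasks.
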
method to query all workers with a given command - fn query_workers( + async fn query_workers( &self, token_sequence: Option>, - command_fn: impl Fn(Option>>, mpsc::SyncSender) -> UpdateSequences, - ) -> HashMap { + command_fn: impl Fn( + Option>>, + tokio::sync::oneshot::Sender, + ) -> UpdateSequences, + ) -> HashMap { let mut results = HashMap::new(); let token_sequence_shared = token_sequence.map(Arc::new); let mut receivers = Vec::new(); // Send queries to all workers in parallel - for (worker_id, sender) in &self.senders { - let (resp_tx, resp_rx) = mpsc::sync_channel(0); + for entry in self.senders.iter() { + let worker_id = *entry.key(); + let sender = entry.value(); + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); receivers.push((worker_id, resp_rx)); - sender - .send(command_fn(token_sequence_shared.clone(), resp_tx)) - .expect("Failed to send command to worker"); + if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) { + tracing::error!("Failed to send command to worker {}: {}", worker_id, e); + } } // Collect results from all workers for (worker_id, receiver) in receivers { - let result = receiver - .recv_timeout(Duration::from_secs(1)) - .expect("Failed to receive response from worker"); - results.insert(*worker_id, result); + match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { + Ok(Ok(result)) => { + results.insert(worker_id, result); + } + Ok(Err(_)) => { + tracing::error!("Worker {} dropped response channel", worker_id); + } + Err(_) => { + tracing::error!("Timeout waiting for response from worker {}", worker_id); + } + } } results } /// Query all workers for the number of new blocks that would be added by a token sequence - pub fn new_blocks(&self, token_sequence: Vec) -> HashMap { + pub async fn new_blocks(&self, token_sequence: Vec) -> HashMap { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { Some(ts) => UpdateSequences::NewBlocks { token_sequence: ts, @@ -460,10 +644,14 @@ impl ActiveSequencesMultiWorker { }, None => unreachable!("token_sequence should always be Some for new_blocks"), }) + .await } /// Query all workers for the total number of blocks (new + active) that would be used by a token sequence - pub fn potential_blocks(&self, token_sequence: Vec) -> HashMap { + pub async fn potential_blocks( + &self, + token_sequence: Vec, + ) -> HashMap { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { Some(ts) => UpdateSequences::PotentialBlocks { token_sequence: ts, @@ -471,10 +659,11 @@ impl ActiveSequencesMultiWorker { }, None => unreachable!("token_sequence should always be Some for potential_blocks"), }) + .await } /// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap - pub fn potential_blocks_and_tokens( + pub async fn potential_blocks_and_tokens( &self, token_sequence: Vec, isl: usize, @@ -486,53 +675,68 @@ impl ActiveSequencesMultiWorker { let mut receivers = Vec::new(); // Send queries to all workers in parallel - for (worker_id, sender) in &self.senders { - let (resp_tx, resp_rx) = mpsc::sync_channel(0); + for entry in self.senders.iter() { + let worker_id = *entry.key(); + let sender = entry.value(); + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); receivers.push((worker_id, resp_rx)); - sender - .send(UpdateSequences::PotentialBlocksAndTokens { - token_sequence: token_sequence_shared.clone(), - isl, - overlap: overlaps.scores.get(worker_id).copied().unwrap_or(0), - resp_tx, - }) - 
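
`query_workers` fans a command out to every worker and gathers replies with a one-second `tokio::time::timeout`, so a wedged worker degrades the answer instead of blocking the router. A self-contained sketch of that fan-out/collect shape, with trivial workers that just echo their own id:

```rust
// Fan-out query: one oneshot per worker, collect with a per-worker timeout.
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};

#[tokio::main]
async fn main() {
    // Spawn two trivial workers that answer queries with their own id.
    let mut senders: HashMap<u64, mpsc::UnboundedSender<oneshot::Sender<u64>>> = HashMap::new();
    for worker_id in [0u64, 1] {
        let (tx, mut rx) = mpsc::unbounded_channel::<oneshot::Sender<u64>>();
        senders.insert(worker_id, tx);
        tokio::spawn(async move {
            while let Some(resp_tx) = rx.recv().await {
                let _ = resp_tx.send(worker_id);
            }
        });
    }

    // Send queries to all workers, then collect with a timeout per worker.
    let mut receivers = Vec::new();
    for (worker_id, sender) in &senders {
        let (resp_tx, resp_rx) = oneshot::channel();
        receivers.push((*worker_id, resp_rx));
        if let Err(e) = sender.send(resp_tx) {
            eprintln!("failed to send query to worker {worker_id}: {e}");
        }
    }

    let mut results = HashMap::new();
    for (worker_id, receiver) in receivers {
        match tokio::time::timeout(Duration::from_secs(1), receiver).await {
            Ok(Ok(value)) => {
                results.insert(worker_id, value);
            }
            Ok(Err(_)) => eprintln!("worker {worker_id} dropped its response channel"),
            Err(_) => eprintln!("timed out waiting for worker {worker_id}"),
        }
    }
    assert_eq!(results.len(), 2);
}
```
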
.expect("Failed to send potential_tokens command to worker"); + if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens { + token_sequence: token_sequence_shared.clone(), + isl, + overlap: overlaps.scores.get(&worker_id).copied().unwrap_or(0), + resp_tx, + }) { + tracing::error!( + "Failed to send potential_tokens command to worker {}: {}", + worker_id, + e + ); + } } // Collect results from all workers for (worker_id, receiver) in receivers { - let (blocks, tokens) = receiver - .recv_timeout(Duration::from_secs(1)) - .expect("Failed to receive response from worker"); - potential_blocks.insert(*worker_id, blocks); - potential_tokens.insert(*worker_id, tokens); + match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { + Ok(Ok((blocks, tokens))) => { + potential_blocks.insert(worker_id, blocks); + potential_tokens.insert(worker_id, tokens); + } + Ok(Err(_)) => { + tracing::error!("Worker {} dropped response channel", worker_id); + } + Err(_) => { + tracing::error!("Timeout waiting for response from worker {}", worker_id); + } + } } (potential_blocks, potential_tokens) } /// Query all workers for their current number of active blocks - pub fn active_blocks(&self) -> HashMap { + pub async fn active_blocks(&self) -> HashMap { self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx }) + .await } /// Query all workers for their current number of active tokens - pub fn active_tokens(&self) -> HashMap { + pub async fn active_tokens(&self) -> HashMap { self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx }) + .await } } impl Drop for ActiveSequencesMultiWorker { fn drop(&mut self) { - // Send shutdown command to all workers - for sender in self.senders.values() { - let _ = sender.send(UpdateSequences::Shutdown); + // Send shutdown to all workers + for entry in self.senders.iter() { + let _ = entry.value().send(UpdateSequences::Shutdown); } - // Wait for all threads to finish - for (_, handle) in self.handles.drain() { - let _ = handle.join(); + // Abort all tasks + for entry in self.handles.iter() { + entry.value().abort(); } } } @@ -540,44 +744,248 @@ impl Drop for ActiveSequencesMultiWorker { #[cfg(test)] mod tests { use super::*; + use dynamo_runtime::{DistributedRuntime, Runtime}; + use std::sync::{Arc, Mutex}; + use std::thread; #[test] - fn test_multi_worker_block_sharing() { - // Create multi-worker sequence manager with 3 workers + #[ignore] + fn test_multi_worker_block_sharing() -> Result<()> { + // Initialize logging once + dynamo_runtime::logging::init(); + let block_size = 4; // arbitrary block size - let worker_ids = vec![0, 1, 2]; - let mut seq_manager = ActiveSequencesMultiWorker::new(block_size, worker_ids); - - // Add requests to each worker - // Worker 0: sequence [0, 1, 2] - seq_manager.add_request( - "request_0".to_string(), - vec![0, 1, 2], - 12, // ISL (3 blocks * 4 block_size) - 0, // no overlap - 0, // worker_id - ); - // Worker 1: sequence [3, 4] - seq_manager.add_request( - "request_1".to_string(), - vec![3, 4], - 8, // ISL (2 blocks * 4 block_size) - 0, // no overlap - 1, // worker_id - ); + // Shared state for collecting results from both threads + let active_tokens_after_add = Arc::new(Mutex::new(HashMap::new())); + let potential_blocks_result = Arc::new(Mutex::new(HashMap::new())); + let active_blocks_after_free = Arc::new(Mutex::new(HashMap::new())); + let active_tokens_after_free = Arc::new(Mutex::new(HashMap::new())); + + let active_tokens_after_add_clone = 
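
The new `Drop` impl cannot `await`, so instead of joining threads it sends a best-effort `Shutdown` to each worker and aborts the tokio tasks (which cancels them at their next await point). A minimal sketch of that shape, with illustrative `Manager`/`Cmd` types:

```rust
// Drop behaviour sketch: best-effort Shutdown, then abort the tasks.
use dashmap::DashMap;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

enum Cmd {
    Shutdown,
}

struct Manager {
    senders: DashMap<u64, mpsc::UnboundedSender<Cmd>>,
    handles: DashMap<u64, JoinHandle<()>>,
}

impl Drop for Manager {
    fn drop(&mut self) {
        for entry in self.senders.iter() {
            let _ = entry.value().send(Cmd::Shutdown); // worker may already be gone
        }
        for entry in self.handles.iter() {
            entry.value().abort(); // non-blocking; task is cancelled at its next await
        }
    }
}

#[tokio::main]
async fn main() {
    let manager = Manager { senders: DashMap::new(), handles: DashMap::new() };
    let (tx, mut rx) = mpsc::unbounded_channel();
    manager.senders.insert(0, tx);
    manager.handles.insert(
        0,
        tokio::spawn(async move {
            while let Some(Cmd::Shutdown) = rx.recv().await {
                break;
            }
        }),
    );
    drop(manager); // sends Shutdown and aborts the worker task
}
```
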
active_tokens_after_add.clone(); + let potential_blocks_result_clone = potential_blocks_result.clone(); + let active_blocks_after_free_clone = active_blocks_after_free.clone(); + let active_tokens_after_free_clone = active_tokens_after_free.clone(); + + // Clone again for the second thread + let active_tokens_after_add_clone2 = active_tokens_after_add.clone(); + let potential_blocks_result_clone2 = potential_blocks_result.clone(); + let active_blocks_after_free_clone2 = active_blocks_after_free.clone(); + let active_tokens_after_free_clone2 = active_tokens_after_free.clone(); + + // Thread 1: First runtime with workers 0 and 1 + let handle1 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and component with same names as thread 2 + let namespace = distributed.namespace("test_multiworker_sequences")?; + let component = namespace + .component("sequences")? + .service_builder() + .create() + .await?; + + // Create multi-worker sequence manager with workers 0 and 1 + let worker_ids = vec![0, 1]; + let seq_manager = + ActiveSequencesMultiWorker::new(component, block_size, worker_ids); + + // Give some time for the subscription loop to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Add requests to workers + // Worker 0: sequence [0, 1, 2] + seq_manager + .add_request( + "request_0".to_string(), + vec![0, 1, 2], + 12, // ISL (3 blocks * 4 block_size) + 0, // no overlap + 0, // worker_id + ) + .await?; + + // Worker 1: sequence [3, 4] + seq_manager + .add_request( + "request_1".to_string(), + vec![3, 4], + 8, // ISL (2 blocks * 4 block_size) + 0, // no overlap + 1, // worker_id + ) + .await?; + + // Give some time for the commands to be processed and synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get active tokens from workers 0 and 1 + let tokens = seq_manager.active_tokens().await; + active_tokens_after_add_clone + .lock() + .unwrap() + .insert(0, tokens.get(&0).copied().unwrap_or(0)); + active_tokens_after_add_clone + .lock() + .unwrap() + .insert(1, tokens.get(&1).copied().unwrap_or(0)); + + // Test potential blocks for sequence [0, 1] + let potential = seq_manager.potential_blocks(vec![0, 1]).await; + potential_blocks_result_clone + .lock() + .unwrap() + .insert(0, potential.get(&0).copied().unwrap_or(0)); + potential_blocks_result_clone + .lock() + .unwrap() + .insert(1, potential.get(&1).copied().unwrap_or(0)); + + // Wait for second thread to process its requests + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Free requests from workers 0 and 1 + seq_manager.free(&"request_0".to_string()).await?; + seq_manager.free(&"request_1".to_string()).await?; + + // Give some time for the commands to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get final active blocks and tokens + let blocks = seq_manager.active_blocks().await; + let tokens = seq_manager.active_tokens().await; + + active_blocks_after_free_clone + .lock() + .unwrap() + .insert(0, blocks.get(&0).copied().unwrap_or(0)); + active_blocks_after_free_clone + .lock() + .unwrap() + .insert(1, blocks.get(&1).copied().unwrap_or(0)); + active_tokens_after_free_clone + .lock() + .unwrap() + .insert(0, tokens.get(&0).copied().unwrap_or(0)); + 
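
The rewritten test simulates two independent routers by giving each its own OS thread and its own tokio runtime, so they can only observe each other through the published events. A minimal sketch of that harness shape (the async body here is a placeholder for building the component and driving one router):

```rust
// One router per OS thread, each with its own tokio runtime; errors propagate
// back to the test through the thread's JoinHandle.
use std::thread;

fn main() -> anyhow::Result<()> {
    let handle = thread::spawn(|| {
        let rt = tokio::runtime::Runtime::new().expect("failed to build runtime");
        rt.block_on(async {
            // ... build the component and drive one router here ...
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok::<(), anyhow::Error>(())
        })
    });

    handle.join().expect("router thread panicked")?;
    Ok(())
}
```
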
active_tokens_after_free_clone + .lock() + .unwrap() + .insert(1, tokens.get(&1).copied().unwrap_or(0)); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); - // Worker 2: sequence [0, 1, 2, 3] - seq_manager.add_request( - "request_2".to_string(), - vec![0, 1, 2, 3], - 16, // ISL (4 blocks * 4 block_size) - 0, // no overlap - 2, // worker_id - ); + // Thread 2: Second runtime with worker 2 + let handle2 = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(async { + // Create runtime and distributed runtime + let runtime = Runtime::from_current()?; + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + + // Create namespace and component with same names as thread 1 + let namespace = distributed.namespace("test_multiworker_sequences")?; + let component = namespace + .component("sequences")? + .service_builder() + .create() + .await?; + + // Create multi-worker sequence manager with worker 2 + let worker_ids = vec![2]; + let seq_manager = + ActiveSequencesMultiWorker::new(component, block_size, worker_ids); + + // Give some time for the subscription loop to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Wait a bit to ensure thread 1 has started + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Worker 2: sequence [0, 1, 2, 3] + seq_manager + .add_request( + "request_2".to_string(), + vec![0, 1, 2, 3], + 16, // ISL (4 blocks * 4 block_size) + 0, // no overlap + 2, // worker_id + ) + .await?; + + // Give some time for the commands to be processed and synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get active tokens from worker 2 + let tokens = seq_manager.active_tokens().await; + active_tokens_after_add_clone2 + .lock() + .unwrap() + .insert(2, tokens.get(&2).copied().unwrap_or(0)); + + // Test potential blocks for sequence [0, 1] + let potential = seq_manager.potential_blocks(vec![0, 1]).await; + potential_blocks_result_clone2 + .lock() + .unwrap() + .insert(2, potential.get(&2).copied().unwrap_or(0)); + + // Wait for first thread to free its requests + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Free request from worker 2 + seq_manager.free(&"request_2".to_string()).await?; + + // Give some time for the commands to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Get final active blocks and tokens + let blocks = seq_manager.active_blocks().await; + let tokens = seq_manager.active_tokens().await; + + active_blocks_after_free_clone2 + .lock() + .unwrap() + .insert(2, blocks.get(&2).copied().unwrap_or(0)); + active_tokens_after_free_clone2 + .lock() + .unwrap() + .insert(2, tokens.get(&2).copied().unwrap_or(0)); + + // Keep runtime alive a bit longer for synchronization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Shutdown runtime + runtime.shutdown(); + + Ok::<(), anyhow::Error>(()) + }) + }); + + // Wait for both threads to complete + handle1.join().unwrap()?; + handle2.join().unwrap()?; + + // Extract results + let tokens_after_add = active_tokens_after_add.lock().unwrap(); + let potential_blocks = potential_blocks_result.lock().unwrap(); + let blocks_after_free = active_blocks_after_free.lock().unwrap(); + let tokens_after_free = 
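
Because the assertions run on the test's main thread after both router threads have finished, each thread records its intermediate readings into shared `Arc<Mutex<HashMap<..>>>` maps, cloned once per thread. A trivial sketch of that collection pattern:

```rust
// Collect per-thread readings into shared maps, assert after joining.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let tokens: Arc<Mutex<HashMap<u64, usize>>> = Arc::new(Mutex::new(HashMap::new()));

    let tokens_a = tokens.clone();
    let a = thread::spawn(move || {
        tokens_a.lock().unwrap().insert(0, 12);
        tokens_a.lock().unwrap().insert(1, 8);
    });

    let tokens_b = tokens.clone();
    let b = thread::spawn(move || {
        tokens_b.lock().unwrap().insert(2, 16);
    });

    a.join().unwrap();
    b.join().unwrap();

    let tokens = tokens.lock().unwrap();
    assert_eq!(tokens[&0], 12);
    assert_eq!(tokens[&2], 16);
}
```
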
active_tokens_after_free.lock().unwrap(); // Verify active tokens after adding requests - let tokens_after_add = seq_manager.active_tokens(); assert_eq!( tokens_after_add[&0], 12, "Worker 0 should have 12 active tokens" @@ -592,8 +1000,6 @@ mod tests { ); // Test potential blocks for sequence [0, 1] - let potential_blocks = seq_manager.potential_blocks(vec![0, 1]); - // Worker 0 should return 3 (already has blocks 0, 1, 2, so no new blocks needed for [0, 1]) assert_eq!( potential_blocks[&0], 3, @@ -612,30 +1018,34 @@ mod tests { "Worker 2 should have 4 potential blocks" ); - // Free all original requests - seq_manager.free(&"request_0".to_string()); - seq_manager.free(&"request_1".to_string()); - seq_manager.free(&"request_2".to_string()); - // Verify active blocks are zero for all workers - let active_blocks = seq_manager.active_blocks(); - assert_eq!(active_blocks[&0], 0, "Worker 0 should have 0 active blocks"); - assert_eq!(active_blocks[&1], 0, "Worker 1 should have 0 active blocks"); - assert_eq!(active_blocks[&2], 0, "Worker 2 should have 0 active blocks"); + assert_eq!( + blocks_after_free[&0], 0, + "Worker 0 should have 0 active blocks" + ); + assert_eq!( + blocks_after_free[&1], 0, + "Worker 1 should have 0 active blocks" + ); + assert_eq!( + blocks_after_free[&2], 0, + "Worker 2 should have 0 active blocks" + ); // Verify active tokens are zero for all workers - let final_tokens = seq_manager.active_tokens(); assert_eq!( - final_tokens[&0], 0, + tokens_after_free[&0], 0, "Worker 0 should have 0 active tokens after freeing all" ); assert_eq!( - final_tokens[&1], 0, + tokens_after_free[&1], 0, "Worker 1 should have 0 active tokens after freeing all" ); assert_eq!( - final_tokens[&2], 0, + tokens_after_free[&2], 0, "Worker 2 should have 0 active tokens after freeing all" ); + + Ok(()) } } diff --git a/lib/llm/src/recorder.rs b/lib/llm/src/recorder.rs index cd808b3ee6..67f8e276c7 100644 --- a/lib/llm/src/recorder.rs +++ b/lib/llm/src/recorder.rs @@ -1,17 +1,5 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. use serde::{Deserialize, Serialize}; use std::io; @@ -386,6 +374,12 @@ where } } +impl Drop for Recorder { + fn drop(&mut self) { + self.cancel.cancel(); + } +} + /// Helper function to create a rotated file path with an index suffix fn create_rotated_path(base_path: &Path, index: usize) -> PathBuf { let path_str = base_path.to_string_lossy();
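
The `recorder.rs` hunk adds a `Drop` impl that cancels the recorder's token, so its background task shuts down when the `Recorder` goes out of scope instead of outliving it. A minimal sketch of that owner-drops-token pattern, assuming the `tokio-util` `CancellationToken` (the real Recorder's background task writes recorded events to disk):

```rust
// Dropping the owner cancels a token that the background task watches.
use tokio_util::sync::CancellationToken;

struct Recorder {
    cancel: CancellationToken,
}

impl Recorder {
    fn new() -> Self {
        let cancel = CancellationToken::new();
        let watch = cancel.clone();
        tokio::spawn(async move {
            watch.cancelled().await;
            println!("recorder background task shutting down");
        });
        Self { cancel }
    }
}

impl Drop for Recorder {
    fn drop(&mut self) {
        self.cancel.cancel();
    }
}

#[tokio::main]
async fn main() {
    let recorder = Recorder::new();
    drop(recorder);
    // Yield briefly so the background task can observe the cancellation.
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
```
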