diff --git a/components/metrics/src/bin/mock_worker.rs b/components/metrics/src/bin/mock_worker.rs index 6278de73ce..a2238ea5b1 100644 --- a/components/metrics/src/bin/mock_worker.rs +++ b/components/metrics/src/bin/mock_worker.rs @@ -14,7 +14,7 @@ // limitations under the License. use dynamo_llm::kv_router::{ - protocols::ForwardPassMetrics, scheduler::KVHitRateEvent, KV_HIT_RATE_SUBJECT, + protocols::ForwardPassMetrics, protocols::KVHitRateEvent, KV_HIT_RATE_SUBJECT, }; use dynamo_runtime::{ component::{service::EndpointStats, Namespace}, @@ -89,7 +89,7 @@ async fn mock_event_publisher(namespace: Namespace) { let overlap_blocks = rand::rng().random_range(0..=isl_blocks); let event = KVHitRateEvent { - worker_id, + worker: worker_id, isl_blocks, overlap_blocks, }; diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index b928938490..678db50ec2 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -84,8 +84,7 @@ use std::net::SocketAddr; use std::time::Duration as StdDuration; use dynamo_llm::kv_router::protocols::ForwardPassMetrics; -use dynamo_llm::kv_router::scheduler::Endpoint; -use dynamo_llm::kv_router::scoring::ProcessedEndpoints; +use dynamo_llm::kv_router::scoring::{Endpoint, ProcessedEndpoints}; use dynamo_runtime::{ distributed::Component, error, service::EndpointInfo, utils::Duration, Result, @@ -451,6 +450,7 @@ impl PrometheusMetrics { let worker_id = worker_id.to_string(); let metrics = endpoint.data.clone(); + // to not change the existing behavior self.set_worker_gauge( &self.kv_blocks_active, config, diff --git a/components/metrics/src/main.rs b/components/metrics/src/main.rs index fa8186d07a..c6a0996280 100644 --- a/components/metrics/src/main.rs +++ b/components/metrics/src/main.rs @@ -27,7 +27,7 @@ //! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events //! 
- Overlap Blocks: Cumulative count of blocks that were already in the KV cache
 
 use clap::Parser;
-use dynamo_llm::kv_router::scheduler::KVHitRateEvent;
+use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerDp};
 use dynamo_llm::kv_router::KV_HIT_RATE_SUBJECT;
 use dynamo_runtime::{
     error, logging,
@@ -180,14 +180,15 @@ async fn app(runtime: Runtime) -> Result<()> {
     tracing::debug!("Successfully subscribed to KV hit rate events");
 
     while let Some(msg) = subscriber.next().await {
-        match serde_json::from_slice::<KVHitRateEvent>(&msg.payload) {
+        match serde_json::from_slice::<KVHitRateEvent<WorkerDp>>(&msg.payload) {
             Ok(event) => {
                 // TODO: Lower to debug
                 let cache_hit_pct =
                     (event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0;
                 tracing::debug!(
-                    "Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
-                    event.worker_id,
+                    "Received KV hit rate event: worker_id={}, dp_rank={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
+                    event.worker.worker_id,
+                    event.worker.dp_rank.unwrap_or(0),
                     event.isl_blocks,
                     event.overlap_blocks,
                     cache_hit_pct
@@ -197,7 +198,8 @@ async fn app(runtime: Runtime) -> Result<()> {
                 let mut metrics = metrics_collector_clone.lock().await;
                 metrics.update_kv_hit_rate(
                     &config_clone,
-                    event.worker_id,
+                    // TODO: this will not take care of dp ranks
+                    event.worker.worker_id,
                     event.isl_blocks,
                     event.overlap_blocks,
                 );
diff --git a/components/router/src/main.rs b/components/router/src/main.rs
index 3546a9bb30..e8e7d08c72 100644
--- a/components/router/src/main.rs
+++ b/components/router/src/main.rs
@@ -25,7 +25,7 @@ use std::sync::Arc;
 use clap::Parser;
 
 use dynamo_llm::kv_router::{
-    protocols::WorkerSelectionResult,
+    protocols::{WorkerDp, WorkerSelectionResult},
     scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
     scoring::ProcessedEndpoints,
     KvRouter, WorkerSelector,
@@ -89,7 +89,7 @@ impl WorkerSelector for CustomWorkerSelector {
         workers: &ProcessedEndpoints,
         request: &SchedulingRequest,
         block_size: usize,
-    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
+    ) -> Result<WorkerSelectionResult<WorkerDp>, KvSchedulerError> {
         // customize logic here
         // F12 into [DefaultWorkerSelector] to see the original logic
         self.0.select_worker(workers, request, block_size)
diff --git a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py
index 04732e11f5..3cd5318a34 100644
--- a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py
+++ b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py
@@ -23,7 +23,7 @@ import uvloop
 from vllm.config import VllmConfig
-from vllm.distributed.kv_events import KVEventsConfig
+from vllm.distributed.kv_events import KVEventsConfig, ZmqEventPublisher
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import TokensPrompt
 from vllm.sampling_params import SamplingParams
@@ -68,7 +68,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
     def __init__(self, component: Component, dp_rank: int) -> None:
         self.inner = WorkerMetricsPublisher()
-        self.inner.create_endpoint(component)
+        self.inner.create_endpoint(component, dp_rank=dp_rank)
         self.dp_rank = dp_rank
 
     def record(
@@ -246,12 +246,33 @@ async def init(runtime: DistributedRuntime, config: Config):
     )
     logger.info("VllmWorker has been initialized")
 
+    base_zmq_endpoint = "tcp://127.0.0.1:5557"
+    dp_rank_size = vllm_config.parallel_config.data_parallel_size
-    zmq_config = ZmqKvEventPublisherConfig(
-        worker_id=endpoint.lease_id(), kv_block_size=engine_args.block_size
-    )
+    # Store references to prevent garbage collection
+    kv_publishers = []
+
+    for dp_rank in range(dp_rank_size):
+        
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( + base_zmq_endpoint, data_parallel_rank=dp_rank + ) + zmq_config = ZmqKvEventPublisherConfig( + worker_id=endpoint.lease_id(), + kv_block_size=engine_args.block_size, + zmq_endpoint=zmq_endpoint, + ) - _ = ZmqKvEventPublisher(component=component, config=zmq_config) + try: + publisher = ZmqKvEventPublisher(component=component, config=zmq_config) + kv_publishers.append(publisher) + except Exception as e: + logger.error( + f"Failed to create ZmqKvEventPublisher for dp_rank {dp_rank}: {e}" + ) + + logger.debug( + f"Successfully created {len(kv_publishers)} ZmqKvEventPublishers out of {dp_rank_size} expected" + ) handler = RequestHandler(component, engine_client, default_sampling_params) @@ -313,7 +334,7 @@ def cmd_line_args(): endpoint_str = args.endpoint.replace("dyn://", "", 1) endpoint_parts = endpoint_str.split(".") if len(endpoint_parts) != 3: - logging.error( + logger.error( f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'." ) sys.exit(1) diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index 1c50f4aa8e..742815230d 100644 --- a/lib/bindings/c/src/lib.rs +++ b/lib/bindings/c/src/lib.rs @@ -14,6 +14,7 @@ // limitations under the License. use async_once_cell::OnceCell as AsyncOnceCell; +use dynamo_llm::kv_router::publisher::KvCacheEventWithDp; use libc::c_char; use once_cell::sync::OnceCell; use std::ffi::CStr; @@ -284,7 +285,12 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( }; let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size()); - match publisher.publish(event) { + // NOTE: dummy dp_rank for now + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank: None, + }; + match publisher.publish(event_with_dp) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing stored kv event {:?}", e); @@ -301,7 +307,12 @@ pub extern "C" fn dynamo_kv_event_publish_removed( ) -> DynamoLlmResult { let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks); - match publisher.publish(event) { + // NOTE: dummy dp_rank for now + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank: None, + }; + match publisher.publish(event_with_dp) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing removed kv event {:?}", e); diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 39cc1ea46e..da04c5b5bf 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -61,6 +61,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 2d7b3d92b5..5b2b655360 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -22,7 +22,34 @@ use rs::traits::events::EventSubscriber; use tracing; use llm_rs::kv_router::protocols::*; -use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig}; +use llm_rs::kv_router::publisher::{create_stored_blocks, KvCacheEventWithDp, KvEventSourceConfig}; + +#[pyclass] +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct WorkerDp { + #[pyo3(get, set)] + pub worker_id: i64, + #[pyo3(get, set)] + pub dp_rank: Option, +} + 
+#[pymethods] +impl WorkerDp { + #[new] + #[pyo3(signature = (worker_id, dp_rank = None))] + pub fn new(worker_id: i64, dp_rank: Option) -> Self { + Self { worker_id, dp_rank } + } +} + +impl From for WorkerDp { + fn from(value: llm_rs::kv_router::protocols::WorkerDp) -> Self { + Self { + worker_id: value.worker_id, + dp_rank: value.dp_rank, + } + } +} #[pyclass] pub(crate) struct KvRouter { @@ -57,7 +84,7 @@ impl KvRouter { .schedule(&token_ids, lora_id) .await .map_err(to_pyerr)?; - Ok(worker_id) + Ok(WorkerDp::from(worker_id)) }) } } @@ -78,17 +105,21 @@ impl WorkerMetricsPublisher { }) } - #[pyo3(signature = (component))] + #[pyo3(signature = (component, dp_rank = None))] fn create_endpoint<'p>( &self, py: Python<'p>, component: Component, + dp_rank: Option, ) -> PyResult> { let rs_publisher = self.inner.clone(); let rs_component = component.inner.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { rs_publisher - .create_endpoint(rs_component) + .create_endpoint( + rs_component, + dp_rank.as_ref().map(|v| v.to_string()).as_deref(), + ) .await .map_err(to_pyerr)?; Ok(()) @@ -107,7 +138,7 @@ impl WorkerMetricsPublisher { num_requests_waiting: u64, gpu_cache_usage_perc: f32, gpu_prefix_cache_hit_rate: f32, - data_parallel_rank: u32, + data_parallel_rank: DpRank, ) -> PyResult<()> { self.inner .publish( @@ -218,7 +249,7 @@ impl KvEventPublisher { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None))] + #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, dp_rank=None))] fn publish_stored( &mut self, _py: Python, @@ -228,6 +259,7 @@ impl KvEventPublisher { block_hashes: Vec, lora_id: u64, parent_hash: Option, + dp_rank: Option, ) -> PyResult<()> { let event = KvCacheEvent { event_id, @@ -243,11 +275,22 @@ impl KvEventPublisher { ), }), }; + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank, + }; - self.inner.publish(event).map_err(to_pyerr) + self.inner.publish(event_with_dp).map_err(to_pyerr) } - fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec) -> PyResult<()> { + #[pyo3(signature = (event_id, block_hashes, dp_rank=None))] + fn publish_removed( + &self, + _py: Python, + event_id: u64, + block_hashes: Vec, + dp_rank: Option, + ) -> PyResult<()> { let block_hashes: Vec = block_hashes .iter() .map(|&h| ExternalSequenceBlockHash::from(h)) @@ -256,22 +299,30 @@ impl KvEventPublisher { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), }; + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank, + }; - self.inner.publish(event).map_err(to_pyerr) + self.inner.publish(event_with_dp).map_err(to_pyerr) } } #[pyclass] #[derive(Clone)] pub(crate) struct OverlapScores { - inner: llm_rs::kv_router::indexer::OverlapScores, + inner: llm_rs::kv_router::indexer::OverlapScores, } #[pymethods] impl OverlapScores { #[getter] - fn scores(&self) -> HashMap { - self.inner.scores.clone() + fn scores(&self) -> HashMap { + self.inner + .scores + .iter() + .map(|(k, v)| (WorkerDp::from(*k), *v)) + .collect() } #[getter] @@ -282,7 +333,7 @@ impl OverlapScores { #[pyclass] pub(crate) struct KvIndexer { - inner: Arc, + inner: Arc>, } #[pymethods] @@ -291,12 +342,13 @@ impl KvIndexer { fn new(component: Component, kv_block_size: usize) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { - let inner: Arc = - 
llm_rs::kv_router::indexer::KvIndexer::new( - component.inner.drt().runtime().child_token(), - kv_block_size, - ) - .into(); + let inner: Arc< + llm_rs::kv_router::indexer::KvIndexer, + > = llm_rs::kv_router::indexer::KvIndexer::new( + component.inner.drt().runtime().child_token(), + kv_block_size, + ) + .into(); // [gluo TODO] try subscribe_with_type::, // error checking below will be different. let mut kv_events_rx = component @@ -310,8 +362,9 @@ impl KvIndexer { // should have been made to a trait and implemented here? i.e. AsyncEngine style tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::indexer::RouterEvent = - serde_json::from_slice(&event.payload).unwrap(); + let event: llm_rs::kv_router::protocols::RouterEvent< + llm_rs::kv_router::protocols::WorkerDp, + > = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("received kv event: {:?}", event); if let Err(e) = kv_events_tx.send(event).await { tracing::trace!( @@ -354,6 +407,8 @@ pub(crate) struct EndpointKvMetrics { #[pyo3(get, set)] pub worker_id: i64, #[pyo3(get, set)] + pub dp_rank: Option, + #[pyo3(get, set)] pub request_active_slots: u64, #[pyo3(get, set)] pub request_total_slots: u64, @@ -407,8 +462,9 @@ impl KvMetricsAggregator { let endpoint_kv_metrics = endpoints .endpoints .iter() - .map(|(worker_id, x)| EndpointKvMetrics { - worker_id: *worker_id, + .map(|(worker_dp, x)| EndpointKvMetrics { + worker_id: worker_dp.worker_id, + dp_rank: worker_dp.dp_rank, request_active_slots: x.data.request_active_slots, request_total_slots: x.data.request_total_slots, kv_active_blocks: x.data.kv_active_blocks, @@ -430,7 +486,7 @@ impl KvMetricsAggregator { #[pyclass] pub(crate) struct KvRecorder { - inner: Arc, + inner: Arc>, } #[pymethods] @@ -481,8 +537,9 @@ impl KvRecorder { // Spawn a task to forward events to the recorder tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::indexer::RouterEvent = - serde_json::from_slice(&event.payload).unwrap(); + let event: llm_rs::kv_router::protocols::RouterEvent< + llm_rs::kv_router::protocols::WorkerDp, + > = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("KvRecorder received kv event: {:?}", event); if let Err(e) = event_tx.send(event).await { tracing::trace!( diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 424496fe41..cdf021ea97 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -356,7 +356,7 @@ class WorkerMetricsPublisher: Create a `WorkerMetricsPublisher` object """ - def create_service(self, component: Component) -> None: + def create_endpoint(self, component: Component, dp_rank: int) -> None: """ Similar to Component.create_service, but only service created through this method will interact with KV router of the same component. @@ -420,6 +420,24 @@ class OverlapScores: ... +class WorkerDp: + """ + Worker data parallel information containing worker ID and optional DP rank. + """ + + worker_id: int + dp_rank: Optional[int] + + def __init__(self, worker_id: int, dp_rank: Optional[int] = None) -> None: + """ + Create a WorkerDp instance. + + Args: + worker_id: The worker ID + dp_rank: Optional data parallel rank + """ + ... + class KvIndexer: """ A KV Indexer that tracks KV Events emitted by workers. Events include add_block and remove_block. 
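Aside (illustrative, not part of the patch): the bindings above extend publish_stored/publish_removed with an optional dp_rank keyword. A minimal worker-side sketch, assuming an already-constructed dynamo.llm KvEventPublisher named `publisher`; the event ids, hashes and token ids below are placeholder values.

# Illustrative sketch, not part of the diff. `publisher` is assumed to be an
# existing KvEventPublisher; hashes/token ids are placeholders.
def publish_block_lifecycle(publisher, dp_rank: int = 0):
    # Store one block attributed to this data-parallel rank
    # (parent_hash=None marks a root block).
    publisher.publish_stored(
        event_id=1,
        token_ids=[10, 11, 12, 13],
        num_block_tokens=[4],
        block_hashes=[0xABCD],
        lora_id=0,
        parent_hash=None,
        dp_rank=dp_rank,
    )
    # Later, evict the same block; the removal carries the same dp_rank so the
    # indexer can attribute it to the correct (worker, dp_rank) pair.
    publisher.publish_removed(event_id=2, block_hashes=[0xABCD], dp_rank=dp_rank)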
diff --git a/lib/bindings/python/src/dynamo/llm/__init__.py b/lib/bindings/python/src/dynamo/llm/__init__.py index 673e2da7ae..dbd158a5f1 100644 --- a/lib/bindings/python/src/dynamo/llm/__init__.py +++ b/lib/bindings/python/src/dynamo/llm/__init__.py @@ -33,6 +33,7 @@ from dynamo._core import KvRouter as KvRouter from dynamo._core import ModelType as ModelType from dynamo._core import OverlapScores as OverlapScores +from dynamo._core import WorkerDp as WorkerDp from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher from dynamo._core import ZmqKvEventPublisher as ZmqKvEventPublisher from dynamo._core import ZmqKvEventPublisherConfig as ZmqKvEventPublisherConfig diff --git a/lib/bindings/python/tests/test_kv_bindings.py b/lib/bindings/python/tests/test_kv_bindings.py index d08b1aa808..c4a71d2481 100644 --- a/lib/bindings/python/tests/test_kv_bindings.py +++ b/lib/bindings/python/tests/test_kv_bindings.py @@ -89,13 +89,19 @@ async def test_event_handler(distributed_runtime): await asyncio.sleep(1) scores = await indexer.find_matches_for_request(test_token, lora_id) assert scores.scores - assert worker_id in scores.scores - assert scores.scores[worker_id] == 1 + + assert len(scores.scores) == 1 # should only be one worker_dp + worker_dp = list(scores.scores.keys())[0] + assert worker_dp.worker_id == worker_id + assert worker_dp.dp_rank is None + # should overlap perfectly + assert list(scores.scores.values())[0] == 1 # remove event event_publisher.remove_event() await asyncio.sleep(1) scores = await indexer.find_matches_for_request(test_token, lora_id) + # indexer is empty, no scores expected assert not scores.scores diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 535a428984..0ab6e84bd5 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -14,6 +14,7 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; +use protocols::WorkerDp; pub mod indexer; pub mod metrics_aggregator; @@ -25,9 +26,11 @@ pub mod scoring; use crate::{ kv_router::{ - indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, + indexer::{KvIndexer, KvIndexerInterface}, metrics_aggregator::KvMetricsAggregator, - protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, + protocols::{ + LocalBlockHash, RouterEvent, RouterRequest, RouterResponse, WorkerSelectionResult, + }, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, }, @@ -51,7 +54,7 @@ pub trait WorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result; + ) -> Result, KvSchedulerError>; } /// KV Router configuration parameters @@ -102,7 +105,7 @@ impl KvRouterConfig { /// A KvRouter only decides which worker you should use. It doesn't send you there. /// TODO: Rename this to indicate it only selects a worker, it does not route. 
pub struct KvRouter { - indexer: KvIndexer, + indexer: KvIndexer, scheduler: KvScheduler, block_size: usize, } @@ -137,7 +140,7 @@ impl KvRouter { tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: RouterEvent = match serde_json::from_slice(&event.payload) { + let event: RouterEvent = match serde_json::from_slice(&event.payload) { Ok(event) => event, Err(e) => { tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); @@ -160,7 +163,7 @@ impl KvRouter { } // [TODO] indexer needs to take 'lora_id' as parameter - pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { + pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { // Extracting part of the code in KvRouter::generate() for only // the decision making part, routing is done by the caller let isl_tokens = token_ids.len(); @@ -175,7 +178,7 @@ impl KvRouter { /// Give these tokens, find the worker with the best match in it's KV cache. /// Returned overlap amount is in number of blocks. - async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> { + async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(WorkerDp, u32)> { let isl_tokens = tokens.len(); let block_size = self.block_size; @@ -202,15 +205,17 @@ impl KvRouter { } #[async_trait] -impl AsyncEngine, ManyOut>, Error> for KvRouter { +impl AsyncEngine, ManyOut>>, Error> + for KvRouter +{ async fn generate( &self, request: SingleIn, - ) -> Result>> { + ) -> Result>>> { let (request, ctx) = request.into_parts(); - let (worker_id, _) = self.find_best_match(&request.tokens).await?; + let (best_match, _) = self.find_best_match(&request.tokens).await?; - let response = RouterResponse { worker_id }; + let response = RouterResponse { worker: best_match }; let response = Annotated::from_data(response); let stream = stream::iter(vec![response]); Ok(ResponseStream::new(Box::pin(stream), ctx.context())) @@ -247,8 +252,11 @@ impl AsyncEngine, ManyOut>; +type SharedRadixBlock = Rc>>; pub fn compute_hash(data: &[u8]) -> u64 { xxh3::xxh3_64_with_seed(data, XXH3_SEED) @@ -133,43 +130,18 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec Self { - Self { worker_id, event } - } -} - /// A block in the Radix Tree. #[derive(Debug)] -struct RadixBlock { +struct RadixBlock { /// A map of child blocks, keyed by their local block hash. - children: HashMap, + children: HashMap>, /// A set of worker IDs associated with this block. - workers: HashSet, + workers: HashSet, /// A buffer of times that this block was last traversed recent_uses: VecDeque, } -impl RadixBlock { +impl RadixBlock { /// Create a new `RadixBlock`. /// /// ### Returns @@ -184,10 +156,10 @@ impl RadixBlock { } } -pub struct RadixTree { +pub struct RadixTree { /// This is the root of the radix/prefix tree /// This will only contain root blocks - root: SharedRadixBlock, + root: SharedRadixBlock, /// This is a global lookup table for all blocks which will let you jump into /// the radix tree at any point @@ -197,18 +169,18 @@ pub struct RadixTree { /// Transitioning to a radix tree only would require a change in the messaging structure /// as the entire prefix would need to be sent. Alternatively, we could use block_depth /// integers to indicate how many blocks to skip and use a radix/prefix tree at each level. 
- lookup: HashMap>, + lookup: HashMap>>, /// The time buffer the radix tree should check when considering frequence of block accesses expiration_duration: Option, } -impl Default for RadixTree { +impl Default for RadixTree { fn default() -> Self { Self::new() } } -impl RadixTree { +impl RadixTree { /// Create a new `RadixTree`. /// /// ### Returns @@ -236,7 +208,11 @@ impl RadixTree { /// ### Returns /// /// An `OverlapScores` representing the match scores. - pub fn find_matches(&self, sequence: Vec, early_exit: bool) -> OverlapScores { + pub fn find_matches( + &self, + sequence: Vec, + early_exit: bool, + ) -> OverlapScores { let mut scores = OverlapScores::new(); let mut current = self.root.clone(); let now = Instant::now(); @@ -280,12 +256,12 @@ impl RadixTree { /// ### Arguments /// /// * `event` - The `RouterEvent` to apply. - pub fn apply_event(&mut self, event: RouterEvent) { - let (worker_id, event) = (event.worker_id, event.event); + pub fn apply_event(&mut self, event: RouterEvent) { + let (worker_id, event) = (event.worker, event.event); let (id, op) = (event.event_id, event.data); tracing::trace!(id, "Store operation: {:?}", op); - let worker_lookup = self.lookup.entry(worker_id).or_default(); + let worker_lookup = self.lookup.entry(worker_id.clone()).or_default(); match op { KvCacheEventData::Stored(op) => { @@ -301,7 +277,7 @@ impl RadixTree { Some(current) => current.clone(), None => { tracing::warn!( - worker_id = worker_id.to_string(), + worker_id = ?worker_id, id, parent_hash = ?op.parent_hash, "Failed to find parent block; skipping store operation" @@ -331,7 +307,7 @@ impl RadixTree { }; // add our worker_id to the block - block.borrow_mut().workers.insert(worker_id); + block.borrow_mut().workers.insert(worker_id.clone()); // add the block to the worker_id lookup table worker_lookup.insert(block_id.block_hash, block.clone()); @@ -355,7 +331,7 @@ impl RadixTree { Some(entry) => entry.clone(), None => { tracing::warn!( - worker_id = worker_id.to_string(), + worker_id = ?worker_id, id, "Failed to find block to remove; skipping remove operation" ); @@ -379,7 +355,7 @@ impl RadixTree { } } - pub fn remove_worker(&mut self, worker: WorkerId) { + pub fn remove_worker(&mut self, worker: T) { if let Some((_, blocks)) = self.lookup.remove_entry(&worker) { blocks.iter().for_each(|(_, block)| { block.borrow_mut().workers.remove(&worker); @@ -387,7 +363,7 @@ impl RadixTree { } } - pub fn clear_all_blocks(&mut self, worker: WorkerId) { + pub fn clear_all_blocks(&mut self, worker: T) { // Check if the worker has any blocks to clear if let Some(blocks) = self.lookup.get(&worker) { let blocks_to_clear: Vec<_> = blocks.values().collect(); @@ -407,20 +383,20 @@ impl RadixTree { /// Scores representing the overlap of workers. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OverlapScores { +pub struct OverlapScores { // map of worker_id to score - pub scores: HashMap, + pub scores: HashMap, // List of frequencies that the blocks have been accessed. Entries with value 0 are omitted. pub frequencies: Vec, } -impl Default for OverlapScores { +impl Default for OverlapScores { fn default() -> Self { Self::new() } } -impl OverlapScores { +impl OverlapScores { /// Create a new `OverlapScores`. /// /// ### Returns @@ -437,10 +413,10 @@ impl OverlapScores { /// /// ### Arguments /// - /// * `workers` - A reference to a `HashSet` of `WorkerId`s. - pub fn update_scores(&mut self, workers: &HashSet) { + /// * `workers` - A reference to a `HashSet` of worker IDs. 
+ pub fn update_scores(&mut self, workers: &HashSet) { for worker in workers { - let score = self.scores.entry(*worker).or_insert(0); + let score = self.scores.entry(worker.clone()).or_insert(0); *score += 1; } } @@ -457,17 +433,17 @@ impl OverlapScores { } /// A request to find matches in the Radix Tree. -pub struct MatchRequest { +pub struct MatchRequest { /// A vector of `LocalBlockHash` representing the sequence to match. sequence: Vec, /// A boolean indicating whether to exit early if a single match is found. early_exit: bool, /// A channel sender to send the `OverlapScores` response. - resp: oneshot::Sender, + resp: oneshot::Sender>, } #[async_trait] -pub trait KvIndexerInterface { +pub trait KvIndexerInterface { /// Find matches for a given sequence of `LocalBlockHash`es. /// /// ### Arguments @@ -480,7 +456,7 @@ pub trait KvIndexerInterface { async fn find_matches( &self, sequence: Vec, - ) -> Result; + ) -> Result, KvRouterError>; /// Find matches for a given sequence of tokens. /// @@ -494,43 +470,43 @@ pub trait KvIndexerInterface { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result; + ) -> Result, KvRouterError>; /// Apply a `RouterEvent` to the KV store. /// /// ### Arguments /// /// * `event` - The `RouterEvent` to apply. - async fn apply_event(&mut self, event: RouterEvent); + async fn apply_event(&mut self, event: RouterEvent); /// Remove a worker's entries from the trie. /// /// ### Arguments /// /// * `worker` - The worker to remove from the trie. - async fn remove_worker(&mut self, worker: WorkerId); + async fn remove_worker(&mut self, worker: T); /// Shutdown the KV Indexer. fn shutdown(&mut self); } /// The KV Indexer, managing the KV store and handling events and match requests. -pub struct KvIndexer { +pub struct KvIndexer { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// A sender for `RouterEvent`s. - event_tx: mpsc::Sender, + event_tx: mpsc::Sender>, /// A sender for `MatchRequest`s. - match_tx: mpsc::Sender, + match_tx: mpsc::Sender>, /// A sender for remove worker requests. - remove_worker_tx: mpsc::Sender, + remove_worker_tx: mpsc::Sender, /// A handle to the background task managing the KV store. task: OnceLock>, /// The size of the KV block this indexer can handle. kv_block_size: usize, } -impl KvIndexer { +impl KvIndexer { /// Create a new `KvIndexer`. /// /// ### Arguments @@ -546,9 +522,9 @@ impl KvIndexer { expiration_duration: Option, kv_block_size: usize, ) -> Self { - let (event_tx, event_rx) = mpsc::channel::(2048); - let (match_tx, match_rx) = mpsc::channel::(128); - let (remove_worker_tx, remove_worker_rx) = mpsc::channel::(16); + let (event_tx, event_rx) = mpsc::channel::>(2048); + let (match_tx, match_rx) = mpsc::channel::>(128); + let (remove_worker_tx, remove_worker_rx) = mpsc::channel::(16); let cancel_clone = token.clone(); let task = std::thread::spawn(move || { // create a new tokio runtime which will only perform work on a single thread @@ -624,17 +600,17 @@ impl KvIndexer { /// ### Returns /// /// A `mpsc::Sender` for `RouterEvent`s. 
- pub fn event_sender(&self) -> mpsc::Sender { + pub fn event_sender(&self) -> mpsc::Sender> { self.event_tx.clone() } } #[async_trait] -impl KvIndexerInterface for KvIndexer { +impl KvIndexerInterface for KvIndexer { async fn find_matches( &self, sequence: Vec, - ) -> Result { + ) -> Result, KvRouterError> { let (resp_tx, resp_rx) = oneshot::channel(); let req = MatchRequest { sequence, @@ -658,7 +634,7 @@ impl KvIndexerInterface for KvIndexer { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result { + ) -> Result, KvRouterError> { tracing::debug!( "Finding matches for request tokens: {:?} / len: {}", tokens, @@ -669,11 +645,11 @@ impl KvIndexerInterface for KvIndexer { self.find_matches(sequence).await } - async fn apply_event(&mut self, event: RouterEvent) { + async fn apply_event(&mut self, event: RouterEvent) { self.event_tx.send(event).await.unwrap(); } - async fn remove_worker(&mut self, worker: WorkerId) { + async fn remove_worker(&mut self, worker: T) { self.remove_worker_tx.send(worker).await.unwrap(); } @@ -686,28 +662,28 @@ impl KvIndexerInterface for KvIndexer { } #[derive(Debug, Clone)] -pub struct ShardedMatchRequest { +pub struct ShardedMatchRequest { sequence: Vec, early_exit: bool, - resp: mpsc::Sender, + resp: mpsc::Sender>, } /// The KV Indexer, managing the KV store and handling events and match requests. -pub struct KvIndexerSharded { +pub struct KvIndexerSharded { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// The size of the KV block this indexer can handle. kv_block_size: usize, - worker_assignments: HashMap, + worker_assignments: HashMap, worker_counts: Vec, - event_tx: Vec>, - request_broadcast_tx: broadcast::Sender, - remove_worker_tx: Vec>, + event_tx: Vec>>, + request_broadcast_tx: broadcast::Sender>, + remove_worker_tx: Vec>, tasks: Vec>, } -impl KvIndexerSharded { +impl KvIndexerSharded { /// Create a new `KvIndexerSharded`. 
/// /// ### Arguments @@ -725,19 +701,18 @@ impl KvIndexerSharded { expiration_duration: Option, kv_block_size: usize, ) -> Self { - let worker_assignments: HashMap = HashMap::new(); + let worker_assignments: HashMap = HashMap::new(); let worker_counts: Vec = vec![0; num_shards]; let mut event_tx = Vec::new(); let mut remove_worker_tx = Vec::new(); let mut tasks = Vec::new(); - let (request_broadcast_tx, _) = broadcast::channel::(1048576); + let (request_broadcast_tx, _) = broadcast::channel::>(1048576); for _ in 0..num_shards { - let (shard_event_tx, mut shard_event_rx) = mpsc::channel::(2048); - let (shard_remove_worker_tx, mut shard_remove_worker_rx) = - mpsc::channel::(16); + let (shard_event_tx, mut shard_event_rx) = mpsc::channel::>(2048); + let (shard_remove_worker_tx, mut shard_remove_worker_rx) = mpsc::channel::(16); let mut shard_broadcast_rx = request_broadcast_tx.subscribe(); let cancel = token.clone(); @@ -812,11 +787,11 @@ impl KvIndexerSharded { } #[async_trait] -impl KvIndexerInterface for KvIndexerSharded { +impl KvIndexerInterface for KvIndexerSharded { async fn find_matches( &self, sequence: Vec, - ) -> Result { + ) -> Result, KvRouterError> { 'match_loop: loop { let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len()); self.request_broadcast_tx @@ -863,14 +838,14 @@ impl KvIndexerInterface for KvIndexerSharded { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result { + ) -> Result, KvRouterError> { let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size); self.find_matches(sequence).await } - async fn apply_event(&mut self, event: RouterEvent) { + async fn apply_event(&mut self, event: RouterEvent) { #[allow(clippy::map_entry)] - if !self.worker_assignments.contains_key(&event.worker_id) { + if !self.worker_assignments.contains_key(&event.worker) { // Get the shard with the smallest amount of workers. 
let selected_shard = self .worker_counts @@ -881,17 +856,17 @@ impl KvIndexerInterface for KvIndexerSharded { .0; self.worker_assignments - .insert(event.worker_id, selected_shard); + .insert(event.worker.clone(), selected_shard); self.worker_counts[selected_shard] += 1; } - self.event_tx[self.worker_assignments[&event.worker_id]] + self.event_tx[self.worker_assignments[&event.worker]] .send(event) .await .unwrap(); } - async fn remove_worker(&mut self, worker: WorkerId) { + async fn remove_worker(&mut self, worker: T) { if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) { self.worker_counts[shard] -= 1; self.remove_worker_tx[shard].send(worker).await.unwrap(); @@ -909,13 +884,15 @@ impl KvIndexerInterface for KvIndexerSharded { #[cfg(test)] mod tests { - use super::*; use rstest::rstest; use rstest_reuse::{self, *}; use tokio::time; use tokio_util::sync::CancellationToken; + // Use u64 as a simple WorkerIdTrait implementation for tests + type TestWorkerId = u64; + fn setup() { dynamo_runtime::logging::init(); } @@ -941,13 +918,13 @@ mod tests { } fn create_store_event( - worker_id: WorkerId, + worker_id: TestWorkerId, event_id: u64, hashes: Vec, parent: Option, - ) -> RouterEvent { + ) -> RouterEvent { RouterEvent { - worker_id, + worker: worker_id, event: KvCacheEvent { event_id, data: add_blocks(hashes, parent), @@ -955,9 +932,13 @@ mod tests { } } - fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec) -> RouterEvent { + fn create_remove_event( + worker_id: TestWorkerId, + event_id: u64, + hashes: Vec, + ) -> RouterEvent { RouterEvent { - worker_id, + worker: worker_id, event: KvCacheEvent { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { @@ -1338,7 +1319,7 @@ mod tests { token: &CancellationToken, num_shards: usize, kv_block_size: usize, - ) -> Box { + ) -> Box> { if num_shards == 1 { Box::new(KvIndexer::new(token.clone(), kv_block_size)) } else { @@ -1423,7 +1404,7 @@ mod tests { const ONE_MILLIS: Duration = Duration::from_millis(1); setup(); - let mut kv_indexer: Box; + let mut kv_indexer: Box>; let token = CancellationToken::new(); let expiration = Duration::from_millis(50); @@ -1534,7 +1515,7 @@ mod tests { }; let router_event = RouterEvent::new(worker_id, kv_cache_event); - assert_eq!(router_event.worker_id, worker_id); + assert_eq!(router_event.worker, worker_id); assert_eq!(router_event.event.event_id, 1); if let KvCacheEventData::Stored(store_op) = &router_event.event.data { assert_eq!(store_op.blocks.len(), 1); @@ -1551,7 +1532,7 @@ mod tests { #[test] fn test_radix_tree_default() { setup(); - let radix_tree: RadixTree = Default::default(); + let radix_tree: RadixTree = Default::default(); assert!(radix_tree.root.borrow().children.is_empty()); assert!(radix_tree.root.borrow().workers.is_empty()); assert!(radix_tree.lookup.is_empty()); @@ -1560,7 +1541,7 @@ mod tests { #[test] fn test_overlap_scores_default() { setup(); - let overlap_scores: OverlapScores = Default::default(); + let overlap_scores: OverlapScores = Default::default(); assert!(overlap_scores.scores.is_empty()); } } diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 156d1dfb02..fc2e451314 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -18,8 +18,7 @@ use std::sync::Once; pub use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::KV_METRICS_ENDPOINT; -use crate::kv_router::scheduler::Endpoint; -use 
crate::kv_router::ProcessedEndpoints; +use crate::kv_router::scoring::{Endpoint, ProcessedEndpoints}; use dynamo_runtime::component::Component; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; use tokio::sync::watch; @@ -134,11 +133,16 @@ pub async fn collect_endpoints_task( .collect(); tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len()); - let processed = ProcessedEndpoints::new(endpoints); + // Only create and send ProcessedEndpoints if we have valid endpoints + if !endpoints.is_empty() { + let processed = ProcessedEndpoints::new(endpoints); - if watch_tx.send(processed).is_err() { - tracing::trace!("failed to send processed endpoints; shutting down"); - break; + if watch_tx.send(processed).is_err() { + tracing::trace!("failed to send processed endpoints; shutting down"); + break; + } + } else { + tracing::trace!("No valid endpoints found, skipping ProcessedEndpoints creation"); } } } diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 8131a54f72..730be4929d 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -14,7 +14,40 @@ // limitations under the License. use crate::tokens::Token; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use std::cmp::Eq; +use std::fmt::Debug; +use std::hash::Hash; + +pub type WorkerId = i64; +pub type DpRank = u32; + +#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize, Default)] +pub struct WorkerDp { + pub worker_id: WorkerId, + pub dp_rank: Option, +} + +impl std::fmt::Display for WorkerDp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.dp_rank { + Some(dp_rank) => write!(f, "{}_{}", self.worker_id, dp_rank), + None => write!(f, "{}", self.worker_id), + } + } +} + +// Cannot add DeserializedOwned otherwise compiler will complain +pub trait WorkerGeneral: + Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize +{ +} + +impl WorkerGeneral for T where + T: Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize + DeserializeOwned +{ +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterRequest { @@ -22,14 +55,14 @@ pub struct RouterRequest { } #[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct RouterResponse { - pub worker_id: i64, +pub struct RouterResponse { + pub worker: T, } #[derive(Debug)] -pub struct WorkerSelectionResult { +pub struct WorkerSelectionResult { /// The worker id of the selected worker - pub worker_id: i64, + pub worker: T, /// The total number of blocks required to prefill the request pub required_blocks: u64, @@ -58,14 +91,14 @@ pub struct ForwardPassMetrics { /// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional /// lora_id of a block. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct LocalBlockHash(pub u64); /// A sequence aware hash of a block where the hash is computed from the tokens_ids, extra_token_ids /// and the optional lora_id of a block, PLUS the hash of the parent block. /// /// In this case, the hashing function is external and unknown. 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ExternalSequenceBlockHash(pub u64); // Implement From trait for convenient conversion @@ -137,6 +170,38 @@ pub struct KvCacheRemoveData { pub block_hashes: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KVHitRateEvent { + pub worker: T, + pub isl_blocks: usize, + pub overlap_blocks: usize, +} + +/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`]. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterEvent { + /// The ID of the worker emitting the event. + pub worker: T, + /// The cache event associated with the worker. + pub event: KvCacheEvent, +} + +impl RouterEvent { + /// Create a new `RouterEvent`. + /// + /// ### Arguments + /// + /// * `worker_id` - The ID of the worker emitting the event. + /// * `event` - The cache event. + /// + /// ### Returns + /// + /// A new `RouterEvent`. + pub fn new(worker: T, event: KvCacheEvent) -> Self { + Self { worker, event } + } +} + impl Serialize for LocalBlockHash { fn serialize(&self, serializer: S) -> Result where diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index d4bf56e0d8..da8be84b7b 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -14,9 +14,7 @@ // limitations under the License. use crate::kv_router::{ - indexer::{compute_block_hash_for_seq, RouterEvent}, - protocols::*, - KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, + indexer::compute_block_hash_for_seq, protocols::*, KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, }; use async_trait::async_trait; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; @@ -45,6 +43,14 @@ use zeromq::{Socket, SocketRecv, SubSocket}; // KV Event Publishers ----------------------------------------------------- // ------------------------------------------------------------------------- +/// Represents a single cache event with an ID and associated data. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct KvCacheEventWithDp { + pub kv_cache_event: KvCacheEvent, + #[serde(skip_serializing_if = "Option::is_none")] + pub dp_rank: Option, +} + /// Configure the source of KV events. /// Currently, only ZMQ is supported. pub enum KvEventSourceConfig { @@ -65,7 +71,7 @@ impl KvEventSource { kv_block_size: usize, source_config: KvEventSourceConfig, cancellation_token: CancellationToken, - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, ) -> Result { match source_config { KvEventSourceConfig::Zmq { endpoint, topic } => { @@ -97,7 +103,6 @@ impl KvEventSource { /// A publisher of KV events. pub struct KvEventPublisher { - /// The size of the KV block. kv_block_size: usize, /// The source of KV events. /// Can be `None` if all events provided through [`KvEventPublisher::publish`]. @@ -105,19 +110,19 @@ pub struct KvEventPublisher { /// The cancellation token. cancellation_token: CancellationToken, /// The channel to send events to. 
- tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, } impl KvEventPublisher { pub fn new( component: Component, - worker_id: i64, + worker_id: WorkerId, kv_block_size: usize, source_config: Option, ) -> Result { let cancellation_token = CancellationToken::new(); - let (tx, rx) = mpsc::unbounded_channel::(); + let (tx, rx) = mpsc::unbounded_channel::(); // Create our event source (if any) let mut source = None; @@ -150,7 +155,10 @@ impl KvEventPublisher { }) } - pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError> { + pub fn publish( + &self, + event: KvCacheEventWithDp, + ) -> Result<(), mpsc::error::SendError> { tracing::trace!("Publish event: {:?}", event); self.tx.send(event) } @@ -178,30 +186,36 @@ impl Drop for KvEventPublisher { async fn start_event_processor( publisher: P, - worker_id: i64, + worker_id: WorkerId, cancellation_token: CancellationToken, - mut rx: mpsc::UnboundedReceiver, + mut rx: mpsc::UnboundedReceiver, ) { + tracing::debug!("KV Event processor starting for worker_id: {}", worker_id); + loop { tokio::select! { _ = cancellation_token.cancelled() => { - tracing::info!("KV Event source received cancellation signal"); + tracing::debug!("KV Event processor received cancellation signal for worker_id: {}", worker_id); break; } - event = rx.recv() => { - let Some(event) = event else { - tracing::debug!("Event processor channel closed."); + maybe_data = rx.recv() => { + let Some(data) = maybe_data else { + tracing::debug!("KV Event processor channel closed for worker_id: {}", worker_id); break; }; // Encapsulate in a router event and publish. - let router_event = RouterEvent::new(worker_id, event); + let event = data.kv_cache_event; + let dp_rank = data.dp_rank; + + let router_event = RouterEvent::new((worker_id, dp_rank), event); if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await { - tracing::error!("Failed to publish event: {}", e); + tracing::error!("Failed to publish event for worker_id: {}, dp_rank: {:?}, error: {}", worker_id, dp_rank, e); } } } } + tracing::debug!("KV Event processor exiting for worker_id: {}", worker_id); } // Error handling configuration for ZMQ operations @@ -221,12 +235,12 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 { async fn start_zmq_listener( zmq_endpoint: String, zmq_topic: String, - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, cancellation_token: CancellationToken, kv_block_size: usize, ) { tracing::debug!( - "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')", + "ZMQ listener starting - connecting to endpoint: {}, topic: '{}'", zmq_endpoint, zmq_topic ); @@ -237,15 +251,25 @@ async fn start_zmq_listener( // Subscribe to the requested topic (empty string == all topics) if let Err(e) = socket.subscribe(&zmq_topic).await { - tracing::error!("Failed to subscribe on ZMQ socket: {}", e); + tracing::error!( + "Failed to subscribe on ZMQ socket for {}: {}", + zmq_endpoint, + e + ); return; } if let Err(e) = socket.connect(&zmq_endpoint).await { - tracing::error!("Failed to connect ZMQ SUB socket: {}", e); + tracing::error!( + "Failed to connect ZMQ SUB socket to {}: {}", + zmq_endpoint, + e + ); return; } + tracing::debug!("ZMQ listener successfully connected to {}", zmq_endpoint); + let mut consecutive_errors = 0u32; loop { @@ -254,7 +278,7 @@ async fn start_zmq_listener( // Check for cancellation _ = cancellation_token.cancelled() => { - tracing::info!("ZMQ listener received cancellation signal"); + tracing::debug!("ZMQ listener received cancellation 
signal for {}", zmq_endpoint); break; } @@ -268,6 +292,7 @@ async fn start_zmq_listener( tracing::error!( error=%e, consecutive_errors=%consecutive_errors, + endpoint=%zmq_endpoint, "Too many consecutive ZMQ errors, terminating listener" ); break; @@ -280,6 +305,7 @@ async fn start_zmq_listener( error=%e, consecutive_errors=%consecutive_errors, backoff_ms=%backoff_ms, + endpoint=%zmq_endpoint, "Error reading from ZMQ socket, applying exponential backoff" ); @@ -293,7 +319,7 @@ async fn start_zmq_listener( let mut frames: Vec> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect(); if frames.len() != 3 { - tracing::warn!(expected=3, actual=%frames.len(), "Received unexpected ZMQ frame count"); + tracing::warn!(expected=3, actual=%frames.len(), endpoint=%zmq_endpoint, "Received unexpected ZMQ frame count"); continue; } @@ -302,7 +328,7 @@ async fn start_zmq_listener( let seq_bytes = frames.pop().unwrap(); if seq_bytes.len() != 8 { - tracing::warn!(expected=8, actual=%seq_bytes.len(), "Invalid sequence number byte length"); + tracing::warn!(expected=8, actual=%seq_bytes.len(), endpoint=%zmq_endpoint, "Invalid sequence number byte length"); continue; } @@ -312,22 +338,25 @@ async fn start_zmq_listener( let batch_result = rmps::from_slice::(&payload); let Ok(batch) = batch_result else { let e = batch_result.unwrap_err(); - tracing::warn!(error=%e, "Failed to decode KVEventBatch msgpack"); + tracing::warn!(error=%e, endpoint=%zmq_endpoint, "Failed to decode KVEventBatch msgpack"); continue; }; + tracing::trace!("ZMQ listener decoded batch with {} events, dp_rank: {:?} from {}", batch.events.len(), batch.data_parallel_rank, zmq_endpoint); + // For each of our events, convert them to [`KvCacheEvent`] and send to the event_processor. + let dp_rank = batch.data_parallel_rank; for raw_event in batch.events.into_iter() { - let event = convert_event(raw_event, seq, kv_block_size, &warning_count); - if tx.send(event).is_err() { - tracing::warn!("Failed to send message to channel - receiver dropped"); + let kv_cache_event = convert_event(raw_event, seq, kv_block_size, &warning_count); + if tx.send(KvCacheEventWithDp { kv_cache_event, dp_rank }).is_err() { + tracing::warn!("Failed to send message to channel - receiver dropped for {}", zmq_endpoint); return; } } } } - tracing::debug!("ZMQ listener exiting"); } + tracing::debug!("ZMQ listener exiting for {}", zmq_endpoint); } /// Convert a raw event coming from the ZMQ channel into the internal @@ -438,6 +467,8 @@ pub fn create_stored_blocks( struct KvEventBatch { ts: f64, events: Vec, + #[serde(alias = "dp_rank")] + data_parallel_rank: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -479,13 +510,18 @@ impl WorkerMetricsPublisher { self.tx.send(metrics) } - pub async fn create_endpoint(&self, component: Component) -> Result<()> { + pub async fn create_endpoint(&self, component: Component, suffix: Option<&str>) -> Result<()> { let mut metrics_rx = self.rx.clone(); let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; + let endpoint_name = match suffix { + Some(s) => format!("{}-{}", KV_METRICS_ENDPOINT, s), + None => KV_METRICS_ENDPOINT.to_string(), + }; + component - .endpoint(KV_METRICS_ENDPOINT) + .endpoint(&endpoint_name) .endpoint_builder() .stats_handler(move |_| { let metrics = metrics_rx.borrow_and_update().clone(); @@ -705,15 +741,20 @@ mod tests_startup_helpers { async fn test_start_event_processor() { let (component, published) = MockComponent::new(); - let event = 
KvCacheEvent { + let kv_cache_event = KvCacheEvent { event_id: 1, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)], }), }; + let event = KvCacheEventWithDp { + kv_cache_event, + dp_rank: None, + }; + let token = CancellationToken::new(); - let (tx, rx) = mpsc::unbounded_channel::(); + let (tx, rx) = mpsc::unbounded_channel::(); tx.send(event).unwrap(); drop(tx); @@ -737,7 +778,7 @@ mod tests_startup_helpers { #[tokio::test] async fn test_start_zmq_listener_pushes_to_channel() { // Prepare channel that listener should fill - let (tx, mut rx) = mpsc::unbounded_channel::(); + let (tx, mut rx) = mpsc::unbounded_channel::(); // ZMQ TCP endpoint using localhost with fixed port let endpoint = "tcp://127.0.0.1:15555"; @@ -770,7 +811,11 @@ mod tests_startup_helpers { lora_id: None, }]; - let batch = KvEventBatch { ts: 0.0, events }; + let batch = KvEventBatch { + ts: 0.0, + events, + data_parallel_rank: None, + }; let payload = Bytes::from(rmps::to_vec(&batch).unwrap()); @@ -795,7 +840,7 @@ mod tests_startup_helpers { let KvCacheEventData::Stored(KvCacheStoreData { parent_hash, blocks, - }) = event.data + }) = event.kv_cache_event.data else { panic!("expected KvCacheStoreData"); }; diff --git a/lib/llm/src/kv_router/recorder.rs b/lib/llm/src/kv_router/recorder.rs index 17c66c7925..40cdbffcbe 100644 --- a/lib/llm/src/kv_router/recorder.rs +++ b/lib/llm/src/kv_router/recorder.rs @@ -13,23 +13,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::kv_router::indexer::RouterEvent; +use crate::kv_router::protocols::*; use crate::recorder::Recorder; -// Type alias for backward compatibility -pub type KvRecorder = Recorder; +// Type alias for backward compatibility, now generic +pub type KvRecorder = Recorder>; #[cfg(test)] mod tests { use super::*; use crate::kv_router::indexer::KvIndexer; - use crate::kv_router::indexer::WorkerId; - use crate::kv_router::protocols::*; use std::time::Duration; use tempfile::tempdir; use tokio::fs; use tokio_util::sync::CancellationToken; + // Use i64 for tests + type TestWorkerId = i64; + fn make_blocks(hashes: Vec) -> Vec { hashes .iter() @@ -51,11 +52,11 @@ mod tests { } fn create_store_event( - worker_id: WorkerId, + worker_id: TestWorkerId, event_id: u64, hashes: Vec, parent: Option, - ) -> RouterEvent { + ) -> RouterEvent { RouterEvent::new( worker_id, KvCacheEvent { @@ -65,7 +66,11 @@ mod tests { ) } - fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec) -> RouterEvent { + fn create_remove_event( + worker_id: TestWorkerId, + event_id: u64, + hashes: Vec, + ) -> RouterEvent { RouterEvent::new( worker_id, KvCacheEvent { @@ -88,7 +93,7 @@ mod tests { // Part 1: Record events to a file let token = CancellationToken::new(); - let recorder = KvRecorder::new(token.clone(), &file_path, None, None, None) + let recorder = KvRecorder::::new(token.clone(), &file_path, None, None, None) .await .unwrap(); let event_tx = recorder.event_sender(); @@ -128,13 +133,19 @@ mod tests { // Part 2: Now create a KvIndexer and load the events from the file let indexer_token = CancellationToken::new(); let kv_block_size = 32; // Default block size for testing - let indexer = KvIndexer::new(indexer_token.clone(), kv_block_size); + let indexer = KvIndexer::::new(indexer_token.clone(), kv_block_size); let indexer_event_tx = indexer.event_sender(); // Use the send_events method to load events from file to indexer - let count = 
KvRecorder::send_events(&file_path, &indexer_event_tx, false, None, None) - .await - .unwrap(); + let count = KvRecorder::::send_events( + &file_path, + &indexer_event_tx, + false, + None, + None, + ) + .await + .unwrap(); assert_eq!(count, 2, "Expected to send 2 events from file to indexer"); } } diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index edf85d3198..f9afafc555 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -16,11 +16,9 @@ use dynamo_runtime::component::Namespace; use dynamo_runtime::traits::events::EventPublisher; use rand::Rng; -use serde::{Deserialize, Serialize}; use std::borrow::BorrowMut; use std::collections::HashMap; -use super::protocols::WorkerSelectionResult; use super::WorkerSelector; use crate::kv_router::indexer::OverlapScores; pub use crate::kv_router::protocols::ForwardPassMetrics; @@ -28,12 +26,7 @@ use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::KvRouterConfig; use crate::kv_router::KV_HIT_RATE_SUBJECT; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct KVHitRateEvent { - pub worker_id: i64, - pub isl_blocks: usize, - pub overlap_blocks: usize, -} +use super::protocols::{KVHitRateEvent, WorkerDp, WorkerSelectionResult}; #[derive(Debug, thiserror::Error)] pub enum KvSchedulerError { @@ -47,39 +40,15 @@ pub enum KvSchedulerError { SubscriberShutdown, } -/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' -/// is cleaned (not optional) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Endpoint { - pub name: String, - pub subject: String, - pub data: ForwardPassMetrics, -} - -impl Endpoint { - pub fn worker_id(&self) -> i64 { - i64::from_str_radix( - self.subject - .split("-") - .last() - .expect("invalid subject") - .to_string() - .as_str(), - 16, - ) - .expect("invalid worker id") - } -} - pub struct SchedulingRequest { pub isl_tokens: usize, - pub overlap: OverlapScores, - resp_tx: tokio::sync::oneshot::Sender, + pub overlap: OverlapScores, + resp_tx: tokio::sync::oneshot::Sender, } impl SchedulingRequest { - pub fn respond(self, worker_id: i64) { - if self.resp_tx.send(worker_id).is_err() { + pub fn respond(self, identifier: WorkerDp) { + if self.resp_tx.send(identifier).is_err() { tracing::trace!("failed to send response to requestor"); } } @@ -100,7 +69,8 @@ impl KvScheduler { let mut endpoints_rx = endpoints_rx; let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone(); - let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (event_tx, event_rx) = + tokio::sync::mpsc::unbounded_channel::>(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { @@ -178,9 +148,9 @@ impl KvScheduler { pub async fn schedule( &self, - overlap: OverlapScores, + overlap: OverlapScores, isl_tokens: usize, - ) -> Result { + ) -> Result { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { isl_tokens, @@ -201,12 +171,12 @@ impl KvScheduler { // This becomes the driver function that handles the selection result pub fn process_worker_selection( workers: &mut ProcessedEndpoints, - selection: WorkerSelectionResult, - event_tx: &tokio::sync::mpsc::UnboundedSender, -) -> i64 { + selection: WorkerSelectionResult, + event_tx: &tokio::sync::mpsc::UnboundedSender>, +) -> WorkerDp { let worker = workers .endpoints - .get_mut(&selection.worker_id) + .get_mut(&selection.worker) .expect("worker not found"); // Update worker 
state predictively @@ -220,14 +190,14 @@ pub fn process_worker_selection( // Emit event if let Err(e) = event_tx.send(KVHitRateEvent { - worker_id: selection.worker_id, + worker: selection.worker, isl_blocks: selection.required_blocks as usize, overlap_blocks: selection.overlap_blocks, }) { tracing::warn!("Failed to send KV hit rate event: {:?}", e); } - selection.worker_id + selection.worker } // Default implementation matching the Python _cost_function @@ -250,7 +220,7 @@ impl WorkerSelector for DefaultWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result { + ) -> Result, KvSchedulerError> { assert!(request.isl_tokens > 0); if workers.endpoints.is_empty() { @@ -261,13 +231,11 @@ impl WorkerSelector for DefaultWorkerSelector { let mut max_waiting = 0.0; // Calculate worker scores and find max waiting requests - for (worker_id, ep) in workers.endpoints.iter() { - // Calculate score similar to Python version - if let Some(score) = request.overlap.scores.get(worker_id) { + for (worker_dp, ep) in workers.endpoints.iter() { + if let Some(score) = request.overlap.scores.get(worker_dp) { let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; - worker_scores.insert(worker_id, score); + worker_scores.insert(worker_dp, score); } - // Track max waiting requests max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64); } @@ -278,13 +246,11 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate logits for each worker let mut best_logit = f64::NEG_INFINITY; - let mut best_workers = Vec::new(); - - for (worker_id, ep) in workers.endpoints.iter() { - let worker_id = *worker_id; + let mut best_worker_dps = Vec::new(); + for (worker_dp, ep) in workers.endpoints.iter() { // Get score or default to 0.0 - let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0); + let score = worker_scores.get(worker_dp).copied().unwrap_or(0.0); // Calculate normalized metrics let gpu_cache_usage = ep.data.gpu_cache_usage_perc as f64; @@ -300,7 +266,7 @@ impl WorkerSelector for DefaultWorkerSelector { - self.kv_router_config.waiting_requests_weight * normalized_waiting; tracing::trace!( - "Formula for {worker_id}: {logit:.3} = {:.1} * {score:.3} - {:.1} * {gpu_cache_usage:.3} - {:.1} * {normalized_waiting:.3}", + "Formula for {worker_dp:?}: {logit:.3} = {:.3} * {score:.3} - {:.3} * {gpu_cache_usage:.3} - {:.3} * {normalized_waiting:.3}", self.kv_router_config.overlap_score_weight, self.kv_router_config.gpu_cache_usage_weight, self.kv_router_config.waiting_requests_weight, @@ -310,40 +276,45 @@ impl WorkerSelector for DefaultWorkerSelector { match logit.partial_cmp(&best_logit) { Some(std::cmp::Ordering::Greater) => { best_logit = logit; - best_workers.clear(); - best_workers.push(worker_id); + best_worker_dps.clear(); + best_worker_dps.push(worker_dp); } Some(std::cmp::Ordering::Equal) => { - best_workers.push(worker_id); + best_worker_dps.push(worker_dp); } _ => {} } } // Return early if no valid workers found - if best_workers.is_empty() { + if best_worker_dps.is_empty() { return Err(KvSchedulerError::NoEndpoints); } else if best_logit == 0.0 { tracing::debug!("best worker logit is 0"); } - let worker_id = if best_workers.len() == 1 { - best_workers[0] + let best_worker_dp = if best_worker_dps.len() == 1 { + best_worker_dps[0] } else { // Randomly select from best workers let mut rng = rand::rng(); - best_workers[rng.random_range(0..best_workers.len())] + best_worker_dps[rng.random_range(0..best_worker_dps.len())] }; 
         // Lower to trace level eventually. Nice to see KV routing working for now.
-        tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}");
+        tracing::debug!("Selected worker: {best_worker_dp:?}, logit: {best_logit:.3}");

         // Log selection metrics
         let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64;
-        let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
+        let overlap_blocks = request
+            .overlap
+            .scores
+            .get(best_worker_dp)
+            .copied()
+            .unwrap_or(0) as usize;

         Ok(WorkerSelectionResult {
-            worker_id,
+            worker: *best_worker_dp,
             required_blocks: total_blocks,
             overlap_blocks,
         })
diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs
index c663c22b5a..b301e73109 100644
--- a/lib/llm/src/kv_router/scoring.rs
+++ b/lib/llm/src/kv_router/scoring.rs
@@ -18,11 +18,46 @@
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;

-use crate::kv_router::scheduler::Endpoint;
+use crate::kv_router::protocols::{DpRank, ForwardPassMetrics, WorkerDp};
+
+/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
+/// is cleaned (not optional)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Endpoint {
+    pub name: String,
+    // contains dp
+    pub subject: String,
+    // one set of metrics for each dp worker
+    pub data: ForwardPassMetrics,
+}
+
+impl Endpoint {
+    pub fn worker_id(&self) -> i64 {
+        i64::from_str_radix(
+            self.subject
+                .split("-")
+                .last()
+                .expect("invalid subject")
+                .to_string()
+                .as_str(),
+            16,
+        )
+        .expect("invalid worker id")
+    }
+
+    pub fn dp_rank(&self) -> Option<DpRank> {
+        let parts: Vec<&str> = self.subject.split("-").collect();
+        if parts.len() < 3 {
+            return None;
+        }
+        let second_to_last = parts[parts.len() - 2];
+        second_to_last.parse::<DpRank>().ok()
+    }
+}

 #[derive(Debug, Default, Serialize, Deserialize, Clone)]
 pub struct ProcessedEndpoints {
-    pub endpoints: HashMap<i64, Endpoint>,
+    pub endpoints: HashMap<WorkerDp, Endpoint>,
     pub load_avg: f64,
     pub load_std: f64,
 }
@@ -32,8 +67,11 @@ impl ProcessedEndpoints {
         // compute some basic statistics
         let load_values: Vec<f64> = endpoints
             .iter()
-            .map(|x| x.data.kv_active_blocks as f64)
+            .map(|endpoint| endpoint.data.kv_active_blocks as f64)
             .collect();
+        if load_values.is_empty() {
+            panic!("No endpoints to process!")
+        };
         let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
         let variance = load_values
             .iter()
@@ -42,7 +80,19 @@ impl ProcessedEndpoints {
             / load_values.len() as f64;
         let load_std = variance.sqrt();

-        let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect();
+        // pass in (worker_id, dp_rank) tuple
+        let endpoints = endpoints
+            .into_iter()
+            .map(|e| {
+                (
+                    WorkerDp {
+                        worker_id: e.worker_id(),
+                        dp_rank: e.dp_rank(),
+                    },
+                    e,
+                )
+            })
+            .collect();

         ProcessedEndpoints {
             endpoints,
diff --git a/lib/llm/src/kv_router/worker.rs b/lib/llm/src/kv_router/worker.rs
deleted file mode 100644
index fc44624f85..0000000000
--- a/lib/llm/src/kv_router/worker.rs
+++ /dev/null
@@ -1,105 +0,0 @@
-// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::sync::Arc;
-
-pub use crate::kv_router::protocols::ForwardPassMetrics;
-
-use anyhow::Result;
-use derive_builder::Builder;
-use dynamo_runtime::pipeline::network::{
-    ingress::push_endpoint::PushEndpoint,
-    PushWorkHandler,
-};
-
-use dynamo_runtime::transports::nats::{self, ServiceExt};
-
-use tokio::sync::watch;
-use tokio_util::sync::CancellationToken;
-use tracing as log;
-
-#[derive(Builder)]
-pub struct KvRoutedIngress {
-    #[builder(setter(into))]
-    pub service_name: String,
-
-    #[builder(setter(into))]
-    pub worker_id: String,
-
-    pub nats: nats::Client,
-    pub service_handler: Arc<dyn PushWorkHandler>,
-    pub metrics_rx: watch::Receiver<Arc<ForwardPassMetrics>>,
-    pub cancellation_token: CancellationToken,
-}
-
-/// version of crate
-pub const VERSION: &str = env!("CARGO_PKG_VERSION");
-
-impl KvRoutedIngress {
-    pub fn builder() -> KvRoutedIngressBuilder {
-        KvRoutedIngressBuilder::default()
-    }
-
-    pub async fn start(self) -> Result<()> {
-        let worker_id = self.worker_id;
-
-        log::trace!(
-            worker_id,
-            "Starting nats service: {}:{}",
-            self.service_name,
-            VERSION
-        );
-
-        let mut metrics_rx = self.metrics_rx;
-        let worker_id_clone = worker_id.clone();
-
-        let service = self
-            .nats
-            .client()
-            .service_builder()
-            .description("A handy min max service")
-            .stats_handler(move |name, stats| {
-                log::debug!(
-                    worker_id = worker_id_clone.as_str(),
-                    "[IN worker?] Stats for service {}: {:?}",
-                    name,
-                    stats
-                );
-                let metrics = metrics_rx.borrow_and_update().clone();
-                serde_json::to_value(&*metrics).unwrap()
-            })
-            .start(self.service_name.as_str(), VERSION)
-            .await
-            .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
-
-        let group = service.group(self.service_name.as_str());
-
-        log::trace!(worker_id, "Starting endpoint: {}", worker_id);
-
-        // creates an endpoint for the service
-        let service_endpoint = group
-            .endpoint(worker_id.clone())
-            .await
-            .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
-
-        let push_endpoint = PushEndpoint::builder()
-            .service_handler(self.service_handler)
-            .cancellation_token(self.cancellation_token)
-            .build()
-            .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
-
-        push_endpoint.start(service_endpoint).await
-    }
-}
diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs
index 6b3be76069..151b9dbfcc 100644
--- a/lib/llm/src/protocols/common/preprocessor.rs
+++ b/lib/llm/src/protocols/common/preprocessor.rs
@@ -51,6 +51,10 @@ pub struct PreprocessedRequest {
     /// Estimated number of prefix hit tokens (only used in kv aware routing)
     #[builder(default)]
     pub estimated_prefix_hit_num_blocks: Option<u32>,
+
+    // The dp_rank to route to
+    #[builder(default)]
+    pub dp_rank: Option<DpRank>,
 }

 impl PreprocessedRequest {
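For reviewers, here is a minimal standalone sketch (not part of the patch) of how the new `Endpoint::worker_id()` / `Endpoint::dp_rank()` subject parsing and the `WorkerDp` key used by `ProcessedEndpoints` are expected to behave. The `DpRank` alias, the `WorkerDp` field layout, and the sample subject strings are assumptions for illustration only; the real definitions live in `kv_router::protocols` and are not shown in this diff.

```rust
// Standalone sketch mirroring Endpoint::worker_id()/dp_rank() from scoring.rs.
// ASSUMPTIONS: `DpRank` is a small unsigned integer, `WorkerDp` pairs a worker
// id with an optional dp rank, and the subjects below are hypothetical.
type DpRank = u32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct WorkerDp {
    worker_id: i64,
    dp_rank: Option<DpRank>,
}

/// Last '-'-separated segment of the subject, parsed as a hex worker id.
fn worker_id(subject: &str) -> i64 {
    i64::from_str_radix(subject.split('-').last().expect("invalid subject"), 16)
        .expect("invalid worker id")
}

/// Second-to-last segment parsed as a dp rank; None if there is no dp segment.
fn dp_rank(subject: &str) -> Option<DpRank> {
    let parts: Vec<&str> = subject.split('-').collect();
    if parts.len() < 3 {
        return None;
    }
    parts[parts.len() - 2].parse::<DpRank>().ok()
}

fn main() {
    let with_dp = "generate-2-5f3759df"; // hypothetical "<endpoint>-<dp_rank>-<worker_id_hex>"
    let without_dp = "generate-5f3759df"; // no dp rank segment

    let key = WorkerDp {
        worker_id: worker_id(with_dp),
        dp_rank: dp_rank(with_dp),
    };
    // key.worker_id == 0x5f3759df (parsed as hex); key.dp_rank == Some(2)
    println!("{key:?}");
    assert_eq!(dp_rank(without_dp), None);
}
```

Note that a subject with fewer than three `-`-separated segments yields `dp_rank == None`, so workers that do not publish a dp rank keep a distinct `WorkerDp` key and the pre-DP routing behavior.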