@@ -28,7 +28,8 @@ use crate::{
2828 indexer:: { KvIndexer , KvIndexerInterface } ,
2929 metrics_aggregator:: KvMetricsAggregator ,
3030 protocols:: {
31- LocalBlockHash , RouterEvent , RouterRequest , RouterResponse , WorkerSelectionResult ,
31+ DpRank , LocalBlockHash , RouterEvent , RouterRequest , RouterResponse , WorkerId ,
32+ WorkerSelectionResult ,
3233 } ,
3334 scheduler:: { KvScheduler , KvSchedulerError , SchedulingRequest } ,
3435 scoring:: ProcessedEndpoints ,
@@ -53,13 +54,13 @@ pub trait WorkerSelector {
5354 workers : & ProcessedEndpoints ,
5455 request : & SchedulingRequest ,
5556 block_size : usize ,
56- ) -> Result < WorkerSelectionResult , KvSchedulerError > ;
57+ ) -> Result < WorkerSelectionResult < ( WorkerId , DpRank ) > , KvSchedulerError > ;
5758}
5859
5960/// A KvRouter only decides which worker you should use. It doesn't send you there.
6061/// TODO: Rename this to indicate it only selects a worker, it does not route.
6162pub struct KvRouter {
62- indexer : KvIndexer ,
63+ indexer : KvIndexer < ( WorkerId , DpRank ) > ,
6364 scheduler : KvScheduler ,
6465 block_size : usize ,
6566}
@@ -94,15 +95,16 @@ impl KvRouter {
9495
9596 tokio:: spawn ( async move {
9697 while let Some ( event) = kv_events_rx. next ( ) . await {
97- let event: RouterEvent = match serde_json:: from_slice ( & event. payload ) {
98- Ok ( event) => event,
99- Err ( e) => {
100- tracing:: warn!( "Failed to deserialize RouterEvent: {:?}" , e) ;
101- // Choosing warn and continue to process other events from other workers
102- // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
103- continue ;
104- }
105- } ;
98+ let event: RouterEvent < ( WorkerId , DpRank ) > =
99+ match serde_json:: from_slice ( & event. payload ) {
100+ Ok ( event) => event,
101+ Err ( e) => {
102+ tracing:: warn!( "Failed to deserialize RouterEvent: {:?}" , e) ;
103+ // Choosing warn and continue to process other events from other workers
104+ // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
105+ continue ;
106+ }
107+ } ;
106108 if let Err ( e) = kv_events_tx. send ( event) . await {
107109 tracing:: debug!( "failed to send kv event to indexer; shutting down: {:?}" , e) ;
108110 }
@@ -117,7 +119,11 @@ impl KvRouter {
117119 }
118120
119121 // [TODO] indexer needs to take 'lora_id' as parameter
120- pub async fn schedule ( & self , token_ids : & Vec < u32 > , _lora_id : u64 ) -> Result < i64 > {
122+ pub async fn schedule (
123+ & self ,
124+ token_ids : & Vec < u32 > ,
125+ _lora_id : u64 ,
126+ ) -> Result < ( WorkerId , DpRank ) > {
121127 // Extracting part of the code in KvRouter::generate() for only
122128 // the decision making part, routing is done by the caller
123129 let isl_tokens = token_ids. len ( ) ;
@@ -132,7 +138,7 @@ impl KvRouter {
132138
133139 /// Given these tokens, find the worker with the best match in its KV cache.
134140 /// Returned overlap amount is in number of blocks.
135- async fn find_best_match ( & self , tokens : & [ u32 ] ) -> anyhow:: Result < ( i64 , u32 ) > {
141+ async fn find_best_match ( & self , tokens : & [ u32 ] ) -> anyhow:: Result < ( ( WorkerId , DpRank ) , u32 ) > {
136142 let isl_tokens = tokens. len ( ) ;
137143 let block_size = self . block_size ;
138144
@@ -159,11 +165,17 @@ impl KvRouter {
159165}
160166
161167#[ async_trait]
162- impl AsyncEngine < SingleIn < RouterRequest > , ManyOut < Annotated < RouterResponse > > , Error > for KvRouter {
168+ impl
169+ AsyncEngine <
170+ SingleIn < RouterRequest > ,
171+ ManyOut < Annotated < RouterResponse < ( WorkerId , DpRank ) > > > ,
172+ Error ,
173+ > for KvRouter
174+ {
163175 async fn generate (
164176 & self ,
165177 request : SingleIn < RouterRequest > ,
166- ) -> Result < ManyOut < Annotated < RouterResponse > > > {
178+ ) -> Result < ManyOut < Annotated < RouterResponse < ( WorkerId , DpRank ) > > > > {
167179 let ( request, ctx) = request. into_parts ( ) ;
168180 let ( worker_id, _) = self . find_best_match ( & request. tokens ) . await ?;
169181
@@ -205,7 +217,8 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er
205217 let ( mut backend_input, context) = request. into_parts ( ) ;
206218 backend_input. estimated_prefix_hit_num_blocks = Some ( overlap_amount) ;
207219 let updated_request = context. map ( |_| backend_input) ;
208- self . inner . direct ( updated_request, instance_id) . await
220+ // TODO: this does not do dp routing
221+ self . inner . direct ( updated_request, instance_id. 0 ) . await
209222 }
210223 }
211224 }
0 commit comments