2020from argparse import Namespace
2121from typing import AsyncIterator , Tuple
2222
23+ import numpy as np # Add numpy import
2324from components .worker import VllmWorker
2425from utils .check_worker import check_required_workers
25- from utils .protocol import Tokens
26+ from utils .protocol import LocalBlockHashes
2627from utils .vllm import RouterType
2728
2829from dynamo .llm import AggregatedMetrics , KvIndexer , KvMetricsAggregator , OverlapScores
3536logger = logging .getLogger (__name__ )
3637
3738
def softmax_sample_from_logits(
    logits: dict[str, float], temperature: float = 1.0, lower_is_better: bool = True
) -> str:
    """Sample one key from ``logits`` according to a softmax distribution.

    The logit values are scaled by their range, optionally negated so that
    lower raw values become more likely, divided by ``temperature``, and
    passed through a numerically-stable softmax; a key is then sampled from
    the resulting probability distribution.

    Args:
        logits: Mapping from candidate key (e.g. a worker id) to its score.
        temperature: Softmax temperature; must be > 0. Smaller values make
            sampling greedier, larger values make it closer to uniform.
        lower_is_better: If True, smaller logit values receive higher
            probability (scores are treated as costs); if False, larger
            values receive higher probability.

    Returns:
        One sampled key (as returned by ``numpy.random.choice``, so it may be
        a numpy scalar wrapping the original key).

    Raises:
        ValueError: If ``logits`` is empty or ``temperature`` is not positive.
    """
    if not logits:
        raise ValueError("Empty logits dictionary")
    if temperature <= 0:
        # Guard against divide-by-zero (t == 0) and an inverted distribution
        # (t < 0) in the scaling step below.
        raise ValueError("temperature must be positive")

    keys = list(logits.keys())
    values = np.array(list(logits.values()))

    min_val = np.min(values)
    max_val = np.max(values)

    if min_val == max_val:
        # All values are the same: uniform probability (also avoids the
        # 0/0 range division below).
        probabilities = np.ones(len(keys)) / len(keys)
    else:
        # Scale by the value range. Softmax is invariant to additive shifts,
        # so dividing by (max - min) without subtracting min produces the
        # same distribution as full min-max normalization.
        normalized = values / (max_val - min_val)
        if lower_is_better:
            normalized = -1 * normalized

        scaled = normalized / temperature

        # Subtract the max before exponentiating for numerical stability.
        exp_values = np.exp(scaled - np.max(scaled))
        probabilities = exp_values / np.sum(exp_values)

    # Sample from the probability distribution
    return np.random.choice(keys, p=probabilities)
3868def parse_args (service_name , prefix ) -> Namespace :
3969 parser = argparse .ArgumentParser ()
40- parser .add_argument (
41- "--min-workers" ,
42- type = int ,
43- default = 1 ,
44- help = "Minimum number of workers required before proceeding" ,
45- )
4670 parser .add_argument (
4771 "--model" ,
4872 type = str ,
4973 default = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" ,
5074 help = "Model that is being served" ,
5175 )
76+ parser .add_argument (
77+ "--min-workers" ,
78+ type = int ,
79+ default = 1 ,
80+ help = "Minimum number of workers required before proceeding" ,
81+ )
5282 # TODO: Read block size
5383 parser .add_argument (
5484 "--block-size" ,
@@ -68,6 +98,12 @@ def parse_args(service_name, prefix) -> Namespace:
6898 default = "kv" ,
6999 help = "The router type" ,
70100 )
101+ parser .add_argument (
102+ "--softmax-sample" ,
103+ type = bool ,
104+ default = False ,
105+ help = "Whether to do softmax sampling based on worker logits (default is to pick smallest)" ,
106+ )
71107 config = ServiceConfig .get_instance ()
72108 config_args = config .as_args (service_name , prefix = prefix )
73109 args = parser .parse_args (config_args )
@@ -93,8 +129,10 @@ def __init__(self):
93129 self .args = parse_args (self .__class__ .__name__ , "" )
94130
95131 self .default_metrics = {
96- "gpu_cache_usage_perc" : 0.0 ,
132+ "kv_active_blocks" : 0 ,
133+ "kv_total_blocks" : 1 ,
97134 "num_requests_waiting" : 0.0 ,
135+ "gpu_cache_usage_perc" : 0.0 ,
98136 "gpu_prefix_cache_hit_rate" : 0.0 ,
99137 }
100138
@@ -117,8 +155,36 @@ async def async_init(self):
117155 if self .router_type == RouterType .KV :
118156 self .indexer = KvIndexer (kv_listener , self .args .block_size )
119157 self .metrics_aggregator = KvMetricsAggregator (kv_listener )
158+
159+ self .active_blocks_dict = {}
160+ worker_ids = self .workers_client .instance_ids ()
161+ for worker_id in worker_ids :
162+ # [old_value, predictive_value]
163+ self .active_blocks_dict [worker_id ] = [0 , 0 ]
164+
120165 logger .info ("KV Router initialized" )
121166
    def _update_and_get_active_blocks(self, worker_id: str, polled_value: int) -> int:
        """Update the active-blocks dict and return the effective block count.

        This method implements a predictive mechanism for tracking each
        worker's KV active blocks:
        - If a new polled value is detected (different from the stored old
          value), it updates both the stored old and predictive values to this
          new measurement and returns it.
        - If no change is detected (polled value equals the stored old value),
          it returns the predictive value, which is incremented elsewhere when
          a request is routed to this worker.

        This allows the router to account for blocks of requests that have
        been dispatched but are not yet reflected in the polled metrics.

        Args:
            worker_id: Key into self.active_blocks_dict. NOTE(review): the
                annotation says str, but worker ids produced by
                instance_ids() appear to be ints elsewhere in this file —
                confirm the actual key type.
            polled_value: Latest kv_active_blocks value polled from the
                worker's metrics.

        Returns:
            The active-block count to use for this worker in cost scoring.
        """
        old_value, predictive_value = self.active_blocks_dict[worker_id]

        # Check if polled value is different from old value
        if polled_value != old_value:
            # Fresh measurement from the worker: reset both entries and
            # trust the polled value directly.
            self.active_blocks_dict[worker_id] = [polled_value, polled_value]
            return polled_value
        else:
            # Poll is unchanged (possibly stale): use the locally-adjusted
            # prediction instead.
            return predictive_value
122188 def _cost_function (
123189 self ,
124190 scores : OverlapScores | None ,
@@ -142,66 +208,79 @@ def _cost_function(
142208 (str, float): The best worker id and the corresponding score.
143209 """
144210
145- worker_scores = {}
211+ # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
212+ # and we want all workers to be considered in the logit calculation
213+ worker_ids = self .workers_client .instance_ids ()
214+ request_blocks = (
215+ token_length + self .args .block_size - 1
216+ ) // self .args .block_size
217+
218+ overlap_blocks_dict = {worker_id : 0 for worker_id in worker_ids }
219+ new_blocks_dict = {worker_id : request_blocks for worker_id in worker_ids }
220+
146221 if scores :
147222 for worker_id , score in scores .scores .items ():
148223 # score is number of matching blocks we multiply by block_size to get tokens
149224 # and compare to token_length. The larger the cache hit the better
150- worker_scores [worker_id ] = (
151- score * self .indexer .block_size () / token_length
152- )
225+ overlap_blocks_dict [worker_id ] = score
226+ new_blocks_dict [worker_id ] = request_blocks - score
153227 else :
154228 logger .warning ("Cannot get KV scores" )
155229
156230 worker_metrics = {}
157- max_waiting = 0.0
158231 if metrics :
159232 for endpoint in metrics .endpoints :
160233 worker_id = endpoint .worker_id
161234 worker_metrics [worker_id ] = {
162235 key : getattr (endpoint , key , self .default_metrics [key ])
163236 for key in self .default_metrics .keys ()
164237 }
165- max_waiting = max (
166- max_waiting , worker_metrics [worker_id ]["num_requests_waiting" ]
238+
239+ # Update waiting value using helper routine
240+ polled_active_blocks = int (
241+ worker_metrics [worker_id ]["kv_active_blocks" ]
167242 )
243+ worker_metrics [worker_id ][
244+ "kv_active_blocks"
245+ ] = self ._update_and_get_active_blocks (worker_id , polled_active_blocks )
168246 else :
169247 logger .warning ("Cannot get metrics" )
170248
171- # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
172- # and we want all workers to be considered in the logit calculation
173- worker_ids = self .workers_client .instance_ids ()
174-
175249 worker_logits = {}
176250 for worker_id in worker_ids :
177251 # Use default values if worker not in scores or metrics
178- score = worker_scores .get (worker_id , 0.0 )
179252 metrics_dict = worker_metrics .get (worker_id , self .default_metrics )
180- gpu_cache_usage = metrics_dict ["gpu_cache_usage_perc " ]
253+ kv_total_blocks = metrics_dict ["kv_total_blocks " ]
181254
182- normalized_waiting = (
183- metrics_dict ["num_requests_waiting" ] / max_waiting
184- if max_waiting > 0
185- else 0.0
186- )
255+ new_blocks = new_blocks_dict [worker_id ]
256+ normalized_new_blocks = new_blocks / kv_total_blocks
257+ gpu_cache_usage = metrics_dict ["kv_active_blocks" ] / kv_total_blocks
258+
259+ # Use raw waiting value without normalization
260+ num_requests_waiting = metrics_dict ["num_requests_waiting" ]
187261
188262 # Have 1 metric that weights towards cache hit
189263 # 2 metrics that penalize overloaded worker and queuing
190- worker_logits [worker_id ] = 2 * score - gpu_cache_usage - normalized_waiting
264+ worker_logits [worker_id ] = (
265+ normalized_new_blocks + gpu_cache_usage + num_requests_waiting
266+ )
191267 logger .info (
192- f"Formula for { worker_id } : { worker_logits [worker_id ]:.3f} = 2.0 * { score :.3f} - { gpu_cache_usage :.3f} - { normalized_waiting :.3f} "
268+ f"Formula for { worker_id } : { worker_logits [worker_id ]:.3f} = { normalized_new_blocks :.3f} + { gpu_cache_usage :.3f} + { num_requests_waiting :.3f} "
193269 )
194270
195271 if not worker_logits or not any (worker_logits .values ()):
196272 logger .warning (f"All worker logits are zero. { fallback_msg } ." )
197273 return "" , 0.0
198274
199275 # Select the worker with the highest logit
200- max_logit = max (worker_logits .values ())
201- best_workers = [
202- wid for wid , logit in worker_logits .items () if logit == max_logit
203- ]
204- best_worker_id = random .choice (best_workers )
276+ if self .args .softmax_sample :
277+ best_worker_id = int (softmax_sample_from_logits (worker_logits ))
278+ else :
279+ min_logit = min (worker_logits .values ())
280+ best_workers = [
281+ wid for wid , logit in worker_logits .items () if logit == min_logit
282+ ]
283+ best_worker_id = random .choice (best_workers )
205284
206285 # Log the metrics for the selected worker
207286 if best_worker_id :
@@ -212,15 +291,23 @@ def _cost_function(
212291 f"Selected worker: { best_worker_id } , logit: { worker_logits [best_worker_id ]:.3f} " ,
213292 f"Score: { scores .scores .get (best_worker_id , 0.0 ) if scores else 0.0 :.3f} " ,
214293 f"GPU Cache Hit Rate: { metrics_dict ['gpu_prefix_cache_hit_rate' ]:.3f} " ,
215- f"GPU Cache Usage: { metrics_dict ['gpu_cache_usage_perc ' ]:.3f} " ,
294+ f"GPU Cache Usage: { metrics_dict ['kv_active_blocks' ] / metrics_dict [ 'kv_total_blocks ' ]:.3f} " ,
216295 f"Requests Waiting: { metrics_dict ['num_requests_waiting' ]} " ,
217296 ]
218297
219298 # Log to vllm_logger
220299 for message in log_messages :
221300 logger .info (message )
222301
223- return best_worker_id , worker_scores .get (best_worker_id , 0.0 )
302+ # Increment predictive waiting for the selected worker before returning
303+ self .active_blocks_dict [best_worker_id ][1 ] += new_blocks_dict [
304+ best_worker_id
305+ ]
306+
307+ return (
308+ best_worker_id ,
309+ overlap_blocks_dict [best_worker_id ] * self .args .block_size / token_length ,
310+ )
224311
225312 def _get_underloaded_worker (self , metrics : AggregatedMetrics | None ):
226313 if not metrics :
@@ -248,7 +335,9 @@ def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
248335 return best_worker_id , kv_load [best_worker_id ]
249336
250337 @endpoint ()
251- async def generate (self , request : Tokens ) -> AsyncIterator [Tuple [WorkerId , float ]]:
338+ async def generate (
339+ self , request : LocalBlockHashes
340+ ) -> AsyncIterator [Tuple [WorkerId , float ]]:
252341 metrics = await self .metrics_aggregator .get_metrics ()
253342
254343 # Quick return for KV_LOAD mode
@@ -263,19 +352,16 @@ async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float
263352 return
264353
265354 # Existing KV routing logic
266- lora_id = 0
267355 try :
268- scores = await self .indexer .find_matches_for_request (
269- request .tokens , lora_id
270- )
356+ scores = await self .indexer .find_matches (request .hashes )
271357 except Exception as e :
272358 scores = {}
273359 logger .exception (f"Error finding matches: { e } . { fallback_msg } " )
274360 yield "" , 0.0
275361 return
276362
277363 worker_id , prefix_hit_rate = self ._cost_function (
278- scores , metrics , len ( request .tokens )
364+ scores , metrics , request .num_tokens
279365 )
280366
281367 if worker_id :
0 commit comments