
Commit 8392e7a

feat: Unnormalize waiting requests + predictive load updates for Python router (mirroring Rust) + softmax sampling to reduce thrashing (#1638)
1 parent e53a759 commit 8392e7a

8 files changed: +414, -228 lines

examples/llm/components/kv_router.py

Lines changed: 127 additions & 41 deletions
@@ -20,9 +20,10 @@
 from argparse import Namespace
 from typing import AsyncIterator, Tuple
 
+import numpy as np  # Add numpy import
 from components.worker import VllmWorker
 from utils.check_worker import check_required_workers
-from utils.protocol import Tokens
+from utils.protocol import LocalBlockHashes
 from utils.vllm import RouterType
 
 from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
@@ -35,20 +36,49 @@
 logger = logging.getLogger(__name__)
 
 
+def softmax_sample_from_logits(
+    logits: dict[str, float], temperature: float = 1.0, lower_is_better: bool = True
+) -> str:
+    if not logits:
+        raise ValueError("Empty logits dictionary")
+
+    keys = list(logits.keys())
+    values = np.array(list(logits.values()))
+
+    min_val = np.min(values)
+    max_val = np.max(values)
+
+    if min_val == max_val:
+        # All values are the same, uniform probability
+        probabilities = np.ones(len(keys)) / len(keys)
+    else:
+        normalized = values / (max_val - min_val)
+        if lower_is_better:
+            normalized = -1 * normalized
+
+        scaled = normalized / temperature
+
+        exp_values = np.exp(scaled - np.max(scaled))
+        probabilities = exp_values / np.sum(exp_values)
+
+    # Sample from the probability distribution
+    return np.random.choice(keys, p=probabilities)
+
+
 def parse_args(service_name, prefix) -> Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--min-workers",
-        type=int,
-        default=1,
-        help="Minimum number of workers required before proceeding",
-    )
     parser.add_argument(
         "--model",
         type=str,
         default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         help="Model that is being served",
     )
+    parser.add_argument(
+        "--min-workers",
+        type=int,
+        default=1,
+        help="Minimum number of workers required before proceeding",
+    )
     # TODO: Read block size
     parser.add_argument(
         "--block-size",
@@ -68,6 +98,12 @@ def parse_args(service_name, prefix) -> Namespace:
         default="kv",
         help="The router type",
     )
+    parser.add_argument(
+        "--softmax-sample",
+        type=bool,
+        default=False,
+        help="Whether to do softmax sampling based on worker logits (default is to pick smallest)",
+    )
     config = ServiceConfig.get_instance()
     config_args = config.as_args(service_name, prefix=prefix)
     args = parser.parse_args(config_args)
@@ -93,8 +129,10 @@ def __init__(self):
         self.args = parse_args(self.__class__.__name__, "")
 
         self.default_metrics = {
-            "gpu_cache_usage_perc": 0.0,
+            "kv_active_blocks": 0,
+            "kv_total_blocks": 1,
             "num_requests_waiting": 0.0,
+            "gpu_cache_usage_perc": 0.0,
             "gpu_prefix_cache_hit_rate": 0.0,
         }
 
@@ -117,8 +155,36 @@ async def async_init(self):
         if self.router_type == RouterType.KV:
             self.indexer = KvIndexer(kv_listener, self.args.block_size)
         self.metrics_aggregator = KvMetricsAggregator(kv_listener)
+
+        self.active_blocks_dict = {}
+        worker_ids = self.workers_client.instance_ids()
+        for worker_id in worker_ids:
+            # [old_value, predictive_value]
+            self.active_blocks_dict[worker_id] = [0, 0]
+
         logger.info("KV Router initialized")
 
+    def _update_and_get_active_blocks(self, worker_id: str, polled_value: int) -> int:
+        """Helper routine to update waiting dict and return the desired waiting value.
+
+        This method implements a predictive mechanism for tracking waiting requests:
+        - If a new polled value is detected (different from the stored old value),
+          it updates both the old and predictive values to this new measurement and returns it
+        - If no change is detected (polled value equals old value), it returns the
+          predictive value which has been incremented based on previous routing decisions
+
+        This allows the router to account for requests that have been dispatched but
+        not yet reflected in the polled metrics.
+        """
+        old_value, predictive_value = self.active_blocks_dict[worker_id]
+
+        # Check if polled value is different from old value
+        if polled_value != old_value:
+            self.active_blocks_dict[worker_id] = [polled_value, polled_value]
+            return polled_value
+        else:
+            return predictive_value
+
     def _cost_function(
         self,
         scores: OverlapScores | None,
@@ -142,66 +208,79 @@ def _cost_function(
             (str, float): The best worker id and the corresponding score.
         """
 
-        worker_scores = {}
+        # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
+        # and we want all workers to be considered in the logit calculation
+        worker_ids = self.workers_client.instance_ids()
+        request_blocks = (
+            token_length + self.args.block_size - 1
+        ) // self.args.block_size
+
+        overlap_blocks_dict = {worker_id: 0 for worker_id in worker_ids}
+        new_blocks_dict = {worker_id: request_blocks for worker_id in worker_ids}
+
         if scores:
             for worker_id, score in scores.scores.items():
                 # score is number of matching blocks we multiply by block_size to get tokens
                 # and compare to token_length. The larger the cache hit the better
-                worker_scores[worker_id] = (
-                    score * self.indexer.block_size() / token_length
-                )
+                overlap_blocks_dict[worker_id] = score
+                new_blocks_dict[worker_id] = request_blocks - score
         else:
             logger.warning("Cannot get KV scores")
 
         worker_metrics = {}
-        max_waiting = 0.0
         if metrics:
             for endpoint in metrics.endpoints:
                 worker_id = endpoint.worker_id
                 worker_metrics[worker_id] = {
                     key: getattr(endpoint, key, self.default_metrics[key])
                     for key in self.default_metrics.keys()
                 }
-                max_waiting = max(
-                    max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
+
+                # Update waiting value using helper routine
+                polled_active_blocks = int(
+                    worker_metrics[worker_id]["kv_active_blocks"]
                 )
+                worker_metrics[worker_id][
+                    "kv_active_blocks"
+                ] = self._update_and_get_active_blocks(worker_id, polled_active_blocks)
         else:
             logger.warning("Cannot get metrics")
 
-        # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
-        # and we want all workers to be considered in the logit calculation
-        worker_ids = self.workers_client.instance_ids()
-
         worker_logits = {}
         for worker_id in worker_ids:
             # Use default values if worker not in scores or metrics
-            score = worker_scores.get(worker_id, 0.0)
             metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
-            gpu_cache_usage = metrics_dict["gpu_cache_usage_perc"]
+            kv_total_blocks = metrics_dict["kv_total_blocks"]
 
-            normalized_waiting = (
-                metrics_dict["num_requests_waiting"] / max_waiting
-                if max_waiting > 0
-                else 0.0
-            )
+            new_blocks = new_blocks_dict[worker_id]
+            normalized_new_blocks = new_blocks / kv_total_blocks
+            gpu_cache_usage = metrics_dict["kv_active_blocks"] / kv_total_blocks
+
+            # Use raw waiting value without normalization
+            num_requests_waiting = metrics_dict["num_requests_waiting"]
 
             # Have 1 metric that weights towards cache hit
             # 2 metrics that penalize overloaded worker and queuing
-            worker_logits[worker_id] = 2 * score - gpu_cache_usage - normalized_waiting
+            worker_logits[worker_id] = (
+                normalized_new_blocks + gpu_cache_usage + num_requests_waiting
+            )
             logger.info(
-                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}"
+                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = {normalized_new_blocks:.3f} + {gpu_cache_usage:.3f} + {num_requests_waiting:.3f}"
             )
 
         if not worker_logits or not any(worker_logits.values()):
             logger.warning(f"All worker logits are zero. {fallback_msg}.")
             return "", 0.0
 
         # Select the worker with the highest logit
-        max_logit = max(worker_logits.values())
-        best_workers = [
-            wid for wid, logit in worker_logits.items() if logit == max_logit
-        ]
-        best_worker_id = random.choice(best_workers)
+        if self.args.softmax_sample:
+            best_worker_id = int(softmax_sample_from_logits(worker_logits))
+        else:
+            min_logit = min(worker_logits.values())
+            best_workers = [
+                wid for wid, logit in worker_logits.items() if logit == min_logit
+            ]
+            best_worker_id = random.choice(best_workers)
 
         # Log the metrics for the selected worker
         if best_worker_id:
@@ -212,15 +291,23 @@ def _cost_function(
                 f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
                 f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}",
                 f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
-                f"GPU Cache Usage: {metrics_dict['gpu_cache_usage_perc']:.3f}",
+                f"GPU Cache Usage: {metrics_dict['kv_active_blocks'] / metrics_dict['kv_total_blocks']:.3f}",
                 f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
             ]
 
             # Log to vllm_logger
             for message in log_messages:
                 logger.info(message)
 
-        return best_worker_id, worker_scores.get(best_worker_id, 0.0)
+        # Increment predictive waiting for the selected worker before returning
+        self.active_blocks_dict[best_worker_id][1] += new_blocks_dict[
+            best_worker_id
+        ]
+
+        return (
+            best_worker_id,
+            overlap_blocks_dict[best_worker_id] * self.args.block_size / token_length,
+        )
 
     def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
         if not metrics:
@@ -248,7 +335,9 @@ def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
         return best_worker_id, kv_load[best_worker_id]
 
     @endpoint()
-    async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float]]:
+    async def generate(
+        self, request: LocalBlockHashes
+    ) -> AsyncIterator[Tuple[WorkerId, float]]:
         metrics = await self.metrics_aggregator.get_metrics()
 
         # Quick return for KV_LOAD mode
@@ -263,19 +352,16 @@ async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float
             return
 
         # Existing KV routing logic
-        lora_id = 0
         try:
-            scores = await self.indexer.find_matches_for_request(
-                request.tokens, lora_id
-            )
+            scores = await self.indexer.find_matches(request.hashes)
         except Exception as e:
             scores = {}
             logger.exception(f"Error finding matches: {e}. {fallback_msg}")
             yield "", 0.0
             return
 
         worker_id, prefix_hit_rate = self._cost_function(
-            scores, metrics, len(request.tokens)
+            scores, metrics, request.num_tokens
        )
 
         if worker_id:
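
To get a feel for the new --softmax-sample behavior, the standalone sketch below copies the softmax_sample_from_logits helper added above (lower logits are better under the new cost function) and samples repeatedly from a toy logits dict. The worker names and logit values are invented for illustration.

import numpy as np
from collections import Counter


def softmax_sample_from_logits(
    logits: dict[str, float], temperature: float = 1.0, lower_is_better: bool = True
) -> str:
    # Same logic as the router helper: lower logits get higher probability.
    keys = list(logits.keys())
    values = np.array(list(logits.values()))
    if values.min() == values.max():
        probabilities = np.ones(len(keys)) / len(keys)
    else:
        normalized = values / (values.max() - values.min())
        if lower_is_better:
            normalized = -normalized
        scaled = normalized / temperature
        exp_values = np.exp(scaled - np.max(scaled))
        probabilities = exp_values / np.sum(exp_values)
    return np.random.choice(keys, p=probabilities)


# Toy logits (lower is better): worker "a" looks least loaded.
toy_logits = {"a": 0.2, "b": 0.5, "c": 0.9}
picks = Counter(softmax_sample_from_logits(toy_logits) for _ in range(10_000))
print(picks)  # "a" wins most often, but "b" and "c" still receive some traffic

Because load differences become probabilities instead of a hard argmin, a burst of near-simultaneous requests no longer all lands on the momentarily best worker, which is the thrashing the commit message refers to.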

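The predictive load update can be exercised in isolation as well. The sketch below mimics the [old_value, predictive_value] bookkeeping of _update_and_get_active_blocks plus the post-routing increment, using a plain module-level dict instead of the router class; the worker names and block counts are made up.

# Minimal stand-in for the router's predictive active-block tracking.
# Each worker maps to [last_polled_value, predictive_value].
active_blocks = {"worker-0": [0, 0], "worker-1": [0, 0]}


def update_and_get_active_blocks(worker_id: str, polled_value: int) -> int:
    old_value, predictive_value = active_blocks[worker_id]
    if polled_value != old_value:
        # Fresh measurement arrived: trust it and reset the prediction.
        active_blocks[worker_id] = [polled_value, polled_value]
        return polled_value
    # Metrics are stale: fall back to the locally incremented prediction.
    return predictive_value


# Poll says worker-0 holds 10 active blocks.
assert update_and_get_active_blocks("worker-0", 10) == 10

# We route a request that will allocate 4 new blocks there,
# so bump the predictive slot before the next (stale) poll.
active_blocks["worker-0"][1] += 4

# Same stale poll value again: the prediction (14) is used instead.
assert update_and_get_active_blocks("worker-0", 10) == 14

# A genuinely new poll value overrides the prediction.
assert update_and_get_active_blocks("worker-0", 12) == 12
print("predictive update behaves as expected")
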
examples/llm/components/processor.py

Lines changed: 8 additions & 4 deletions
@@ -24,14 +24,14 @@
 from transformers import AutoTokenizer
 from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
 from utils.check_worker import check_required_workers
-from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
+from utils.protocol import LocalBlockHashes, MyRequestOutput, vLLMGenerateRequest
 from utils.vllm import RouterType, parse_vllm_args
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
 from vllm.outputs import RequestOutput
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-from dynamo.llm import KvMetricsAggregator
+from dynamo.llm import KvMetricsAggregator, compute_block_hash_for_seq_py
 from dynamo.runtime import EtcdKvCache
 from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
 
@@ -242,9 +242,13 @@ async def process_and_stream():
 
             prefix_hit_rate = 0.0  # Default value
             if self.use_router:
+                token_ids = engine_prompt["prompt_token_ids"]
                 router_generator = await self.router_client.generate(
-                    Tokens(
-                        tokens=engine_prompt["prompt_token_ids"]
+                    LocalBlockHashes(
+                        hashes=compute_block_hash_for_seq_py(
+                            token_ids, self.engine_args.block_size
+                        ),
+                        num_tokens=len(token_ids),
                     ).model_dump_json()
                 )
                 decision = await router_generator.__anext__()
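
With this change the processor ships per-block hashes plus the raw token count instead of the token ids themselves. The sketch below only illustrates the shape of that payload: fake_block_hashes is a hypothetical stand-in for dynamo.llm.compute_block_hash_for_seq_py (whose actual hashing scheme is not part of this diff), and the assumption that only complete blocks are hashed is mine.

import hashlib
import json


def fake_block_hashes(token_ids: list[int], block_size: int) -> list[int]:
    # Illustrative stand-in only: real hashes come from
    # compute_block_hash_for_seq_py and will not match these values.
    hashes = []
    num_full_blocks = len(token_ids) // block_size
    for i in range(num_full_blocks):
        block = token_ids[i * block_size : (i + 1) * block_size]
        digest = hashlib.sha256(json.dumps(block).encode()).digest()
        hashes.append(int.from_bytes(digest[:8], "little"))
    return hashes


token_ids = list(range(100))  # pretend prompt_token_ids
payload = {
    "hashes": fake_block_hashes(token_ids, block_size=64),
    "num_tokens": len(token_ids),
}
print(json.dumps(payload))  # only the first full 64-token block is hashed here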

examples/llm/configs/agg_router.yaml

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ Processor:
 
 Router:
   min-workers: 1
+  softmax_sample: true
   common-configs: [model, block-size, router]
 
 VllmWorker:

examples/llm/utils/protocol.py

Lines changed: 5 additions & 0 deletions
@@ -36,6 +36,11 @@ class Tokens(BaseModel):
     tokens: list[int]
 
 
+class LocalBlockHashes(BaseModel):
+    hashes: list[int]
+    num_tokens: int
+
+
 class PrefillRequest(Request):
     request_id: str
 
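A minimal round-trip of the new model, assuming pydantic v2 (which the existing model_dump_json call in the processor implies); the model is re-declared inline to keep the snippet self-contained and the hash values are placeholders:

from pydantic import BaseModel


class LocalBlockHashes(BaseModel):
    hashes: list[int]
    num_tokens: int


request = LocalBlockHashes(hashes=[0x1A2B, 0x3C4D], num_tokens=100)
wire = request.model_dump_json()
print(wire)  # {"hashes":[6699,15437],"num_tokens":100}

# The router side parses it back before calling indexer.find_matches(request.hashes).
parsed = LocalBlockHashes.model_validate_json(wire)
assert parsed.num_tokens == 100
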
lib/bindings/python/rust/llm/kv.rs

Lines changed: 18 additions & 0 deletions
@@ -481,6 +481,24 @@ impl KvIndexer {
         self.inner.block_size()
     }
 
+    fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {
+        let indexer = self.inner.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            let local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash> = sequence
+                .into_iter()
+                .map(llm_rs::kv_router::protocols::LocalBlockHash)
+                .collect();
+
+            let rs_overlap_scores = indexer
+                .find_matches(local_block_hashes)
+                .await
+                .map_err(to_pyerr)?;
+            Ok(OverlapScores {
+                inner: rs_overlap_scores,
+            })
+        })
+    }
+
     fn find_matches_for_request<'p>(
         &self,
         py: Python<'p>,

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 12 additions & 0 deletions
@@ -527,6 +527,18 @@ class KvIndexer:
         Create a `KvIndexer` object
         """
 
+    def find_matches(self, sequence: List[int]) -> OverlapScores:
+        """
+        Find prefix matches for the given sequence of block hashes.
+
+        Args:
+            sequence: List of block hashes to find matches for
+
+        Returns:
+            OverlapScores containing worker matching scores and frequencies
+        """
+        ...
+
     def find_matches_for_request(
         self, token_ids: List[int], lora_id: int
     ) -> OverlapScores:
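
Note that while the stub types find_matches as returning OverlapScores directly, the Rust binding hands back an awaitable and the Python router awaits it. A minimal usage sketch, assuming an already-constructed KvIndexer and block hashes produced by compute_block_hash_for_seq_py (setup elided):

async def route_by_overlap(indexer, block_hashes: list[int]):
    # find_matches takes block hashes directly (no lora_id), unlike
    # find_matches_for_request, which takes raw token ids.
    overlap = await indexer.find_matches(block_hashes)
    # overlap.scores maps worker_id -> number of matching blocks, which
    # _cost_function turns into overlap_blocks_dict / new_blocks_dict.
    return overlap.scores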

lib/llm/src/kv_router.rs

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ pub struct KvRouterConfig {
 impl Default for KvRouterConfig {
     fn default() -> Self {
         Self {
-            overlap_score_weight: 2.0,
+            overlap_score_weight: 1.0,
             gpu_cache_usage_weight: 1.0,
             waiting_requests_weight: 1.0,
         }
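
With this change the Rust defaults give the overlap, cache-usage, and waiting terms equal weight, and the Python router above likewise adds its three terms unweighted and prefers the smallest logit. A small worked example with two hypothetical workers (all numbers invented):

# Hypothetical snapshot for a request needing 6 blocks:
#   worker A: 4 of those blocks already cached, 20/100 blocks active, 0 requests waiting
#   worker B: no blocks cached, 10/100 blocks active, 1 request waiting
workers = {
    "A": {"new_blocks": 2, "kv_active_blocks": 20, "kv_total_blocks": 100, "waiting": 0},
    "B": {"new_blocks": 6, "kv_active_blocks": 10, "kv_total_blocks": 100, "waiting": 1},
}

logits = {}
for wid, m in workers.items():
    normalized_new_blocks = m["new_blocks"] / m["kv_total_blocks"]
    gpu_cache_usage = m["kv_active_blocks"] / m["kv_total_blocks"]
    # Same shape as the router's logit: new-block cost + current load + queue depth.
    logits[wid] = normalized_new_blocks + gpu_cache_usage + m["waiting"]

# A ≈ 0.22, B ≈ 1.16: A is picked outright, or simply favored under --softmax-sample.
print(logits)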
