Skip to content

Commit ae894d0

Browse files
PeaBranedillon-cullinan
authored and committed
feat: don't modify kv scheduler states on query + more python binding (#2798)
Signed-off-by: PeaBrane <yanrpei@gmail.com>
1 parent 7ba9cf2 commit ae894d0

File tree

5 files changed

+339
-25
lines changed

5 files changed

+339
-25
lines changed

docs/architecture/kv_cache_routing.md

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,4 +292,70 @@ if __name__ == "__main__":
292292
asyncio.run(main())
293293
```
294294

295+
### Additional Routing Features
296+
297+
The `KvPushRouter` provides additional methods for fine-grained control:
298+
299+
- **`best_worker_id()`**: Query which worker would be selected for the given tokens without routing the request or updating router state. Takes a `context_id` and `token_ids`, and returns `(worker_id, overlap_blocks)`.
300+
- **`get_potential_loads()`**: Get load estimates for all workers for the given tokens. Returns one entry per worker containing `worker_id`, `potential_prefill_tokens`, and `potential_decode_blocks`.
301+
- **`worker_id` parameter in `generate()`**: Force routing to a specific worker by passing `worker_id=<id>` to bypass the automatic KV-aware selection.
302+
295303
The `router_config_override` parameter allows you to adjust routing behavior per request without recreating the router. This is useful for implementing different routing strategies based on request characteristics.
304+
305+
### Custom Routing Example: Minimizing TTFT
306+
307+
Here's an example of using `get_potential_loads()` to implement custom routing that minimizes Time To First Token (TTFT) by selecting the worker with the least prefill work:
308+
309+
```python
310+
import asyncio
311+
from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig
312+
313+
async def minimize_ttft_routing():
314+
# Setup router
315+
runtime = DistributedRuntime.detached()
316+
namespace = runtime.namespace("inference")
317+
component = namespace.component("vllm")
318+
endpoint = component.endpoint("generate")
319+
320+
router = KvPushRouter(
321+
endpoint=endpoint,
322+
block_size=16,
323+
kv_router_config=KvRouterConfig()
324+
)
325+
326+
# Your input tokens
327+
token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
328+
329+
# Get potential loads for all workers
330+
potential_loads = await router.get_potential_loads(token_ids)
331+
332+
# Find worker with minimum prefill tokens (best for TTFT)
333+
best_worker = min(potential_loads, key=lambda x: x['potential_prefill_tokens'])
334+
335+
print(f"Worker loads: {potential_loads}")
336+
print(f"Selected worker {best_worker['worker_id']} with {best_worker['potential_prefill_tokens']} prefill tokens")
337+
338+
# Route directly to the selected worker
339+
stream = await router.generate(
340+
token_ids=token_ids,
341+
model="meta-llama/Llama-2-7b-hf",
342+
worker_id=best_worker['worker_id'], # Force routing to optimal worker
343+
stop_conditions={"max_tokens": 20}
344+
)
345+
346+
# Process response
347+
async for response in stream:
348+
if isinstance(response, dict) and "token_ids" in response:
349+
print(f"Generated tokens: {response['token_ids']}")
350+
351+
if __name__ == "__main__":
352+
asyncio.run(minimize_ttft_routing())
353+
```
354+
355+
This approach gives you complete control over routing decisions, allowing you to optimize for different metrics based on your specific requirements. For example:
356+
357+
- **Minimize TTFT**: Select worker with lowest `potential_prefill_tokens`
358+
- **Maximize cache reuse**: Use `best_worker_id()`, which considers both prefill and decode loads
359+
- **Balance load**: Consider both `potential_prefill_tokens` and `potential_decode_blocks` together
360+
361+
See [KV Router Architecture](../components/router/README.md) for performance tuning details.

lib/bindings/python/rust/llm/kv.rs

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@ impl KvPushRouter {
909909
}
910910

911911
#[allow(clippy::too_many_arguments)]
912-
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None))]
912+
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None))]
913913
fn generate<'p>(
914914
&self,
915915
py: Python<'p>,
@@ -919,6 +919,7 @@ impl KvPushRouter {
919919
sampling_options: Option<PyObject>,
920920
output_options: Option<PyObject>,
921921
router_config_override: Option<PyObject>,
922+
worker_id: Option<i64>,
922923
) -> PyResult<Bound<'p, PyAny>> {
923924
// Depythonize the options with defaults
924925
let (stop_conditions, sampling_options, output_options, router_config_override) =
@@ -957,15 +958,22 @@ impl KvPushRouter {
957958
})?;
958959

959960
// Build the PreprocessedRequest
960-
let request = llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder()
961+
let mut request_builder =
962+
llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
963+
request_builder
961964
.model(model)
962965
.token_ids(token_ids)
963966
.stop_conditions(stop_conditions)
964967
.sampling_options(sampling_options)
965968
.output_options(output_options)
966-
.router_config_override(router_config_override)
967-
.build()
968-
.map_err(to_pyerr)?;
969+
.router_config_override(router_config_override);
970+
971+
// Set backend_instance_id if worker_id is provided
972+
if let Some(worker_id) = worker_id {
973+
request_builder.backend_instance_id(Some(worker_id));
974+
}
975+
976+
let request = request_builder.build().map_err(to_pyerr)?;
969977

970978
let inner = self.inner.clone();
971979

@@ -1010,6 +1018,59 @@ impl KvPushRouter {
10101018
})
10111019
}
10121020

1021+
#[pyo3(signature = (context_id, token_ids, router_config_override=None))]
1022+
fn best_worker_id<'p>(
1023+
&self,
1024+
py: Python<'p>,
1025+
context_id: String,
1026+
token_ids: Vec<u32>,
1027+
router_config_override: Option<PyObject>,
1028+
) -> PyResult<Bound<'p, PyAny>> {
1029+
let router_config_override = if let Some(obj) = router_config_override {
1030+
Python::with_gil(|py| {
1031+
let override_config: llm_rs::kv_router::RouterConfigOverride =
1032+
depythonize(obj.bind(py)).map_err(to_pyerr)?;
1033+
Ok::<_, PyErr>(Some(override_config))
1034+
})?
1035+
} else {
1036+
None
1037+
};
1038+
1039+
let inner = self.inner.clone();
1040+
1041+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
1042+
let (worker_id, overlap_blocks) = inner
1043+
.find_best_match(&context_id, &token_ids, router_config_override.as_ref())
1044+
.await
1045+
.map_err(to_pyerr)?;
1046+
1047+
// Return a tuple of (worker_id, overlap_blocks)
1048+
Ok((worker_id, overlap_blocks))
1049+
})
1050+
}
1051+
1052+
fn get_potential_loads<'p>(
1053+
&self,
1054+
py: Python<'p>,
1055+
token_ids: Vec<u32>,
1056+
) -> PyResult<Bound<'p, PyAny>> {
1057+
let inner = self.inner.clone();
1058+
1059+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
1060+
let loads = inner
1061+
.get_potential_loads(&token_ids)
1062+
.await
1063+
.map_err(to_pyerr)?;
1064+
1065+
// Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
1066+
Python::with_gil(|py| {
1067+
pythonize(py, &loads)
1068+
.map(|obj| obj.unbind())
1069+
.map_err(to_pyerr)
1070+
})
1071+
})
1072+
}
1073+
10131074
/// Dump all events from the KV router's indexer as a JSON string
10141075
fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
10151076
let inner = self.inner.clone();

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,7 @@ class KvPushRouter:
12271227
sampling_options: Optional[JsonLike] = None,
12281228
output_options: Optional[JsonLike] = None,
12291229
router_config_override: Optional[JsonLike] = None,
1230+
worker_id: Optional[int] = None,
12301231
) -> AsyncIterator[JsonLike]:
12311232
"""
12321233
Generate text using the KV-aware router.
@@ -1238,9 +1239,56 @@ class KvPushRouter:
12381239
sampling_options: Optional sampling configuration
12391240
output_options: Optional output configuration
12401241
router_config_override: Optional router configuration override
1242+
worker_id: Optional worker ID to route to directly. If set, the request
1243+
will be sent to this specific worker and router states will be
1244+
updated accordingly.
12411245
12421246
Returns:
12431247
An async iterator yielding generation responses
1248+
1249+
Note:
1250+
- If worker_id is set, the request bypasses KV matching and routes directly
1251+
to the specified worker while still updating router states.
1252+
- This is different from query_instance_id which doesn't route the request.
1253+
"""
1254+
...
1255+
1256+
async def best_worker_id(
1257+
self,
1258+
context_id: str,
1259+
token_ids: List[int],
1260+
router_config_override: Optional[JsonLike] = None,
1261+
) -> Tuple[int, int]:
1262+
"""
1263+
Find the best matching worker for the given tokens without updating states.
1264+
1265+
Args:
1266+
context_id: String identifier for the request
1267+
token_ids: List of token IDs to find matches for
1268+
router_config_override: Optional router configuration override
1269+
1270+
Returns:
1271+
A tuple of (worker_id, overlap_blocks) where:
1272+
- worker_id: The ID of the best matching worker
1273+
- overlap_blocks: The number of overlapping blocks found
1274+
"""
1275+
...
1276+
1277+
async def get_potential_loads(
1278+
self,
1279+
token_ids: List[int],
1280+
) -> List[Dict[str, int]]:
1281+
"""
1282+
Get potential prefill and decode loads for all workers.
1283+
1284+
Args:
1285+
token_ids: List of token IDs to evaluate
1286+
1287+
Returns:
1288+
A list of dictionaries, each containing:
1289+
- worker_id: The worker ID
1290+
- potential_prefill_tokens: Number of tokens that would need prefill
1291+
- potential_decode_blocks: Number of blocks currently in decode phase
12441292
"""
12451293
...
12461294

0 commit comments

Comments
 (0)