2 changes: 2 additions & 0 deletions docs/architecture/kv_cache_routing.md
@@ -31,6 +31,8 @@ The main KV-aware routing arguments:

> [!NOTE]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
>
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is passed, no KvIndexer is launched to drain and process KV events. In these cases it is recommended to configure your backend workers not to relay events through `KvEventPublisher`, so that events do not accumulate in JetStream. Work is in progress to allow disabling KV event publishing entirely.
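
For orientation, the router-side indexer selection added in this PR looks roughly like the sketch below (simplified from `lib/llm/src/kv_router.rs`; constructor arguments are elided, and the `ApproxKvIndexer` branch is inferred from the surrounding code):

```rust
// Simplified sketch of the selection in KvRouter::new (arguments elided).
// With overlap_score_weight == 0.0 no indexer is created, so KV events are
// never drained from JetStream.
let indexer = if kv_router_config.overlap_score_weight == 0.0 {
    Indexer::None
} else if kv_router_config.use_kv_events {
    Indexer::KvIndexer(KvIndexer::new(/* ... */))
} else {
    Indexer::ApproxKvIndexer(ApproxKvIndexer::new(/* ... */))
};
```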

## Architecture

8 changes: 8 additions & 0 deletions launch/dynamo-run/src/flags.rs
@@ -99,6 +99,13 @@ pub struct Flags {
#[arg(long)]
pub router_replica_sync: Option<bool>,

/// KV Router: Whether to track active blocks in the router for memory management.
/// When false, the router will not maintain state about which blocks are active,
/// reducing memory overhead but potentially affecting scheduling decisions.
/// Default: true
#[arg(long)]
pub router_track_active_blocks: Option<bool>,

/// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
@@ -228,6 +235,7 @@ impl Flags {
self.router_temperature,
self.use_kv_events,
self.router_replica_sync,
self.router_track_active_blocks,
self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
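
A usage note for the new flag (a hedged sketch, not part of this diff): because `router_track_active_blocks` is an `Option<bool>` with `#[arg(long)]`, clap expects an explicit value on the command line and leaves the option as `None` when the flag is omitted, in which case `KvRouterConfig` falls back to its default of `true`.

```rust
use clap::Parser;

// Hedged sketch: parsing the new flag. Assumes the remaining Flags fields are
// optional or defaulted, and that clap derives the kebab-case name
// `--router-track-active-blocks` from the field name.
let flags = Flags::parse_from(["dynamo-run", "--router-track-active-blocks", "false"]);
assert_eq!(flags.router_track_active_blocks, Some(false));

// Omitting the flag leaves the option unset; KvRouterConfig::new then applies
// its default of true.
let flags = Flags::parse_from(["dynamo-run"]);
assert_eq!(flags.router_track_active_blocks, None);
```
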
4 changes: 3 additions & 1 deletion lib/bindings/python/rust/llm/entrypoint.rs
@@ -42,12 +42,13 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_snapshot_threshold=10000, router_reset_states=false))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=10000, router_reset_states=false))]
fn new(
overlap_score_weight: f64,
router_temperature: f64,
use_kv_events: bool,
router_replica_sync: bool,
router_track_active_blocks: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
) -> Self {
@@ -57,6 +58,7 @@ impl KvRouterConfig {
router_temperature,
use_kv_events,
router_replica_sync,
router_track_active_blocks,
router_snapshot_threshold,
router_reset_states,
..Default::default()
96 changes: 75 additions & 21 deletions lib/llm/src/kv_router.rs
@@ -23,7 +23,6 @@ use serde::{Deserialize, Serialize};
pub mod approx;
pub mod indexer;
pub mod metrics_aggregator;
pub mod prefill_counter;
pub mod protocols;
pub mod publisher;
pub mod recorder;
@@ -102,6 +101,9 @@ pub struct KvRouterConfig {

pub router_replica_sync: bool,

/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,

// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
@@ -120,6 +122,7 @@ impl Default for KvRouterConfig {
router_temperature: 0.0,
use_kv_events: true,
router_replica_sync: false,
router_track_active_blocks: true,
max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000),
router_reset_states: false,
@@ -130,11 +133,13 @@ impl KvRouterConfig {
impl KvRouterConfig {
/// Create a new KvRouterConfig with optional weight values.
/// If a weight is None, the default value will be used.
#[allow(clippy::too_many_arguments)]
pub fn new(
overlap_score_weight: Option<f64>,
temperature: Option<f64>,
use_kv_events: Option<bool>,
replica_sync: Option<bool>,
track_active_blocks: Option<bool>,
max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
Expand All @@ -145,6 +150,8 @@ impl KvRouterConfig {
router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold
@@ -157,8 +164,17 @@ impl KvRouterConfig {
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
/// Updates itself based on KV events emitted by backend workers.
/// Has the ability to persist and snapshot states.
KvIndexer(KvIndexer),

/// Predicts which blocks are cached based on incoming requests, expiring entries on a TTL basis.
/// Currently does not persist or snapshot state (work in progress to enable that).
ApproxKvIndexer(ApproxKvIndexer),

/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None,
}

impl Indexer {
@@ -169,13 +185,22 @@ impl Indexer {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
}),
}
}

async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
Indexer::None => {
panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
);
}
}
}
}
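
For intuition, a rough sketch (assumed, not part of this diff) of what the `None` variant means for callers: `find_matches` returns an empty `OverlapScores`, so every candidate worker is treated as having zero cached-prefix overlap and routing falls back to load-based scoring, while `dump_events` must never be reached in this mode.

```rust
// Hedged sketch of the behavior implemented above (names as in this file).
let indexer = Indexer::None;
let overlap = indexer.find_matches(block_hashes).await?;
assert!(overlap.scores.is_empty());      // no worker has any recorded overlap
assert!(overlap.frequencies.is_empty()); // no block frequencies are tracked
// Calling indexer.dump_events().await here would panic: there is no indexer
// state to snapshot.
```
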
@@ -189,6 +214,8 @@ pub struct KvRouter {
scheduler: KvScheduler,

block_size: u32,

kv_router_config: KvRouterConfig,
}

impl KvRouter {
@@ -234,7 +261,10 @@ impl KvRouter {
.await?;
let runtime_configs_rx = runtime_configs_watcher.receiver();

let indexer = if kv_router_config.use_kv_events {
let indexer = if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(),
@@ -257,6 +287,7 @@
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
consumer_uuid.clone(),
)
.await?;

@@ -282,6 +313,7 @@
indexer,
scheduler,
block_size,
kv_router_config,
})
}

@@ -302,12 +334,25 @@

let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;

// Determine who needs seq_hashes
let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

// Optimize cloning: only clone if both need it, otherwise move
let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
match (approx_indexer_needs_it, scheduler_needs_it) {
(true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
(true, false) => (Some(seq_hashes), None),
(false, true) => (None, Some(seq_hashes)),
(false, false) => (None, None),
};

let best_worker_id = self
.scheduler
.schedule(
context_id.to_string(),
isl_tokens,
seq_hashes.clone(),
maybe_seq_hashes_2,
overlap_scores.clone(),
router_config_override,
update_states,
@@ -316,7 +361,7 @@

if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision(best_worker_id, block_hashes, seq_hashes)
.process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
.await
.unwrap();
};
Expand All @@ -337,25 +382,28 @@ impl KvRouter {
worker_id: i64,
) {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);

let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});

self.scheduler
.add_request(
request_id,
seq_hashes,
maybe_seq_hashes,
isl_tokens,
overlap_blocks,
worker_id,
)
.await;
}

pub async fn mark_prefill_completed(&self, request_id: &str) {
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
self.scheduler.mark_prefill_completed(request_id).await
}

pub async fn free(&self, request_id: &str) {
pub async fn free(&self, request_id: &str) -> Result<()> {
self.scheduler.free(request_id).await
}

@@ -367,12 +415,16 @@
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;

let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});

Ok(self
.scheduler
.get_potential_loads(seq_hashes, isl_tokens, overlap_scores)
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
.await)
}

Expand Down Expand Up @@ -404,14 +456,12 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
overlap_blocks,
}
}
RouterRequest::MarkPrefill => {
self.mark_prefill_completed(&context_id).await;
RouterResponse::PrefillMarked { success: true }
}
RouterRequest::MarkFree => {
self.free(&context_id).await;
RouterResponse::FreeMarked { success: true }
}
RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
success: self.mark_prefill_completed(&context_id).await.is_ok(),
},
RouterRequest::MarkFree => RouterResponse::FreeMarked {
success: self.free(&context_id).await.is_ok(),
},
};

let response = Annotated::from_data(response);
@@ -541,15 +591,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu

let wrapped_stream = Box::pin(async_stream::stream! {
if let Some(first_item) = response_stream.next().await {
chooser.mark_prefill_completed(&context_id).await;
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
}
yield first_item;
}

while let Some(item) = response_stream.next().await {
yield item;
}

chooser.free(&context_id).await;
if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id}: {e:?}");
}
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}