Skip to content

Commit da355f6

Browse files
committed
refactor into background.rs
Signed-off-by: PeaBrane <yanrpei@gmail.com>
1 parent b7f5185 commit da355f6

File tree

4 files changed

+156
-80
lines changed

4 files changed

+156
-80
lines changed

lib/bindings/python/rust/llm/kv.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ impl KvIndexer {
417417
.into();
418418

419419
// Use the shared start_event_consumer function instead of duplicating the logic
420-
llm_rs::kv_router::start_event_consumer(
420+
llm_rs::kv_router::background::start_event_consumer(
421421
component.inner.clone(),
422422
consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
423423
inner.event_sender(),

lib/llm/src/kv_router.rs

Lines changed: 32 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ use dynamo_runtime::{
1515
},
1616
prelude::*,
1717
protocols::annotated::Annotated,
18-
transports::nats::NatsQueue,
1918
};
2019
use futures::stream::{self, StreamExt};
2120
use serde::{Deserialize, Serialize};
2221

2322
pub mod approx;
23+
pub mod background;
2424
pub mod indexer;
2525
pub mod metrics_aggregator;
2626
pub mod prefill_counter;
@@ -35,8 +35,9 @@ use crate::{
3535
discovery::{MODEL_ROOT_PATH, ModelEntry},
3636
kv_router::{
3737
approx::ApproxKvIndexer,
38+
background::{start_event_consumer, start_radix_uploader},
3839
indexer::{
39-
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
40+
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RadixUploader,
4041
compute_block_hash_for_seq, compute_seq_hash_for_block,
4142
},
4243
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
@@ -48,10 +49,6 @@ use crate::{
4849
protocols::common::llm_backend::LLMEngineOutput,
4950
};
5051

51-
use dynamo_runtime::traits::events::EventPublisher;
52-
use tokio::sync::mpsc;
53-
use tokio_util::sync::CancellationToken;
54-
5552
// [gluo TODO] shouldn't need to be public
5653
// this should be discovered from the component
5754

@@ -165,74 +162,9 @@ pub struct KvRouter {
165162
scheduler: KvScheduler,
166163

167164
block_size: u32,
168-
}
169165

170-
/// Start a background task to consume events from NatsQueue and forward them to the indexer
171-
pub async fn start_event_consumer(
172-
component: Component,
173-
consumer_uuid: String,
174-
kv_events_tx: mpsc::Sender<RouterEvent>,
175-
cancellation_token: CancellationToken,
176-
) -> Result<()> {
177-
let stream_name =
178-
format!("{}.{}", component.subject(), KV_EVENT_SUBJECT).replace(['/', '\\', '.', '_'], "-");
179-
let nats_server =
180-
std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
181-
let mut nats_queue = NatsQueue::new_with_consumer(
182-
stream_name,
183-
nats_server,
184-
std::time::Duration::from_secs(300), // Very long timeout (5 minutes)
185-
consumer_uuid,
186-
);
187-
188-
nats_queue.connect().await?;
189-
190-
tokio::spawn(async move {
191-
loop {
192-
tokio::select! {
193-
_ = cancellation_token.cancelled() => {
194-
tracing::info!("Event consumer received cancellation signal");
195-
break;
196-
}
197-
result = nats_queue.dequeue_task(Some(std::time::Duration::from_secs(300))) => {
198-
match result {
199-
Ok(Some(bytes)) => {
200-
let event: RouterEvent = match serde_json::from_slice(&bytes) {
201-
Ok(event) => event,
202-
Err(e) => {
203-
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
204-
continue;
205-
}
206-
};
207-
208-
// Forward the RouterEvent to the indexer
209-
if let Err(e) = kv_events_tx.send(event).await {
210-
tracing::warn!(
211-
"failed to send kv event to indexer; shutting down: {:?}",
212-
e
213-
);
214-
break;
215-
}
216-
},
217-
Ok(None) => {
218-
tracing::trace!("Dequeue timeout, continuing");
219-
},
220-
Err(e) => {
221-
tracing::error!("Failed to dequeue task: {:?}", e);
222-
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
223-
}
224-
}
225-
}
226-
}
227-
}
228-
229-
// Clean up the queue and remove the durable consumer
230-
if let Err(e) = nats_queue.shutdown().await {
231-
tracing::warn!("Failed to shutdown NatsQueue: {}", e);
232-
}
233-
});
234-
235-
Ok(())
166+
/// Optional radix tree uploader for snapshot persistence
167+
radix_uploader: Option<RadixUploader>,
236168
}
237169

238170
impl KvRouter {
@@ -302,22 +234,35 @@ impl KvRouter {
302234
)
303235
.await?;
304236

305-
// Start event consumer if using KvIndexer
306-
if let Indexer::KvIndexer(ref kv_indexer) = indexer {
237+
// Start event consumer and radix uploader if using KvIndexer
238+
let radix_uploader = if let Indexer::KvIndexer(ref kv_indexer) = indexer {
307239
start_event_consumer(
308240
component.clone(),
309241
consumer_uuid,
310242
kv_indexer.event_sender(),
311243
cancellation_token.clone(),
312244
)
313245
.await?;
314-
}
246+
247+
// Start radix uploader for snapshot persistence
248+
let uploader = start_radix_uploader(
249+
component.clone(),
250+
kv_indexer.snapshot_event_sender(),
251+
Duration::from_secs(30),
252+
cancellation_token.clone(),
253+
)
254+
.await?;
255+
Some(uploader)
256+
} else {
257+
None
258+
};
315259

316260
tracing::info!("KV Routing initialized");
317261
Ok(Self {
318262
indexer,
319263
scheduler,
320264
block_size,
265+
radix_uploader,
321266
})
322267
}
323268

@@ -374,6 +319,17 @@ impl KvRouter {
374319
pub fn block_size(&self) -> u32 {
375320
self.block_size
376321
}
322+
323+
/// Upload a snapshot of the radix tree to NATS object store
324+
pub async fn upload_snapshot(&self) -> Result<(), KvRouterError> {
325+
match &self.radix_uploader {
326+
Some(uploader) => uploader.upload_snapshot().await,
327+
None => {
328+
tracing::warn!("Radix uploader not available (likely using ApproxKvIndexer)");
329+
Ok(())
330+
}
331+
}
332+
}
377333
}
378334

379335
// NOTE: this would not be usable for now, should deprecate

lib/llm/src/kv_router/approx.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ use tokio_util::sync::CancellationToken;
2525

2626
use crate::tokens::{SequenceHash, TokenBlockSequence};
2727

28-
use crate::kv_router::RouterEvent;
2928
use crate::kv_router::indexer::{
30-
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, WorkerId,
31-
compute_block_hash_for_seq,
29+
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
30+
WorkerId, compute_block_hash_for_seq,
3231
};
3332
use crate::kv_router::protocols::{
3433
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Background processes for the KV Router including event consumption and snapshot uploads.
5+
6+
use std::time::Duration;
7+
8+
use anyhow::Result;
9+
use dynamo_runtime::{
10+
component::Component, traits::events::EventPublisher, transports::nats::NatsQueue,
11+
};
12+
use tokio::sync::mpsc;
13+
use tokio_util::sync::CancellationToken;
14+
15+
use crate::kv_router::{
16+
KV_EVENT_SUBJECT,
17+
indexer::{DumpRequest, RadixUploader, RouterEvent},
18+
};
19+
20+
/// Start a background task to consume events from NatsQueue and forward them to the indexer
21+
pub async fn start_event_consumer(
22+
component: Component,
23+
consumer_uuid: String,
24+
kv_events_tx: mpsc::Sender<RouterEvent>,
25+
cancellation_token: CancellationToken,
26+
) -> Result<()> {
27+
let stream_name =
28+
format!("{}.{}", component.subject(), KV_EVENT_SUBJECT).replace(['/', '\\', '.', '_'], "-");
29+
let nats_server =
30+
std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
31+
let mut nats_queue = NatsQueue::new_with_consumer(
32+
stream_name,
33+
nats_server,
34+
std::time::Duration::from_secs(300), // Very long timeout (5 minutes)
35+
consumer_uuid,
36+
);
37+
38+
nats_queue.connect().await?;
39+
40+
tokio::spawn(async move {
41+
loop {
42+
tokio::select! {
43+
_ = cancellation_token.cancelled() => {
44+
tracing::info!("Event consumer received cancellation signal");
45+
break;
46+
}
47+
result = nats_queue.dequeue_task(Some(std::time::Duration::from_secs(300))) => {
48+
match result {
49+
Ok(Some(bytes)) => {
50+
let event: RouterEvent = match serde_json::from_slice(&bytes) {
51+
Ok(event) => event,
52+
Err(e) => {
53+
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
54+
continue;
55+
}
56+
};
57+
58+
// Forward the RouterEvent to the indexer
59+
if let Err(e) = kv_events_tx.send(event).await {
60+
tracing::warn!(
61+
"failed to send kv event to indexer; shutting down: {:?}",
62+
e
63+
);
64+
break;
65+
}
66+
},
67+
Ok(None) => {
68+
tracing::trace!("Dequeue timeout, continuing");
69+
},
70+
Err(e) => {
71+
tracing::error!("Failed to dequeue task: {:?}", e);
72+
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
73+
}
74+
}
75+
}
76+
}
77+
}
78+
79+
// Clean up the queue and remove the durable consumer
80+
if let Err(e) = nats_queue.shutdown().await {
81+
tracing::warn!("Failed to shutdown NatsQueue: {}", e);
82+
}
83+
});
84+
85+
Ok(())
86+
}
87+
88+
/// Start a RadixUploader for periodic snapshot uploads to NATS object store
89+
pub async fn start_radix_uploader(
90+
component: Component,
91+
snapshot_tx: mpsc::Sender<DumpRequest>,
92+
upload_interval: Duration,
93+
cancellation_token: CancellationToken,
94+
) -> Result<RadixUploader> {
95+
// Create NATS client
96+
let nats_server =
97+
std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
98+
let client_options = dynamo_runtime::transports::nats::Client::builder()
99+
.server(&nats_server)
100+
.build()?;
101+
let nats_client = client_options.connect().await?;
102+
103+
// Create bucket name from component name
104+
let bucket_name =
105+
format!("{}-router-snapshot", component.name()).replace(['/', '\\', '.', '_'], "-");
106+
107+
let uploader = RadixUploader::new(
108+
nats_client,
109+
snapshot_tx,
110+
upload_interval,
111+
bucket_name.clone(),
112+
cancellation_token,
113+
);
114+
115+
tracing::info!(
116+
"RadixUploader initialized with bucket: {}, interval: {:?}",
117+
bucket_name,
118+
upload_interval
119+
);
120+
Ok(uploader)
121+
}

0 commit comments

Comments
 (0)