Skip to content

Commit cfc1fb3

Browse files
committed
publish kv events over js
Signed-off-by: PeaBrane <yanrpei@gmail.com>
1 parent 35a9c34 commit cfc1fb3

File tree

5 files changed

+156
-54
lines changed

5 files changed

+156
-54
lines changed

lib/bindings/python/rust/llm/kv.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,12 +870,14 @@ impl KvPushRouter {
870870
// Get component from endpoint
871871
let component = endpoint.inner.component();
872872

873-
// Create KvRouter
873+
// Create KvRouter with a unique consumer UUID
874+
let consumer_uuid = uuid::Uuid::new_v4().to_string();
874875
let kv_router = llm_rs::kv_router::KvRouter::new(
875876
component.clone(),
876877
block_size as u32,
877878
None, // default selector
878879
Some(kv_router_config.inner()),
880+
consumer_uuid,
879881
)
880882
.await
881883
.map_err(to_pyerr)?;

lib/llm/src/discovery/model_manager.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,11 @@ impl ModelManager {
218218
.drt()
219219
.etcd_client()
220220
.ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
221+
let router_uuid = uuid::Uuid::new_v4();
221222
let router_key = format!(
222223
"kv_routers/{}/{}",
223224
Slug::from_string(model_name),
224-
uuid::Uuid::new_v4()
225+
router_uuid
225226
);
226227
etcd_client
227228
.kv_create(
@@ -237,6 +238,7 @@ impl ModelManager {
237238
kv_cache_block_size,
238239
Some(selector),
239240
kv_router_config,
241+
router_uuid.to_string(),
240242
)
241243
.await?;
242244
let new_kv_chooser = Arc::new(chooser);

lib/llm/src/kv_router.rs

Lines changed: 82 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use dynamo_runtime::{
1515
},
1616
prelude::*,
1717
protocols::annotated::Annotated,
18+
transports::nats::NatsQueue,
1819
};
1920
use futures::stream::{self, StreamExt};
2021
use serde::{Deserialize, Serialize};
@@ -47,7 +48,9 @@ use crate::{
4748
protocols::common::llm_backend::LLMEngineOutput,
4849
};
4950

50-
use dynamo_runtime::traits::events::EventSubscriber;
51+
use dynamo_runtime::traits::events::EventPublisher;
52+
use tokio::sync::mpsc;
53+
use tokio_util::sync::CancellationToken;
5154

5255
// [gluo TODO] shouldn't need to be public
5356
// this should be discovered from the component
@@ -56,7 +59,7 @@ use dynamo_runtime::traits::events::EventSubscriber;
5659
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
5760

5861
// for metric publishing (push-based)
59-
pub const KV_EVENT_SUBJECT: &str = "kv_events";
62+
pub const KV_EVENT_SUBJECT: &str = "kv-events";
6063
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
6164
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";
6265

@@ -165,11 +168,80 @@ pub struct KvRouter {
165168
}
166169

167170
impl KvRouter {
171+
/// Start a background task to consume events from NatsQueue and forward them to the indexer
172+
async fn start_event_consumer(
173+
component: Component,
174+
consumer_uuid: String,
175+
kv_events_tx: mpsc::Sender<RouterEvent>,
176+
cancellation_token: CancellationToken,
177+
) -> Result<()> {
178+
let stream_name = format!("{}_{}", component.subject(), KV_EVENT_SUBJECT)
179+
.replace(['/', '\\', '.', '_'], "-");
180+
let nats_server =
181+
std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
182+
let mut nats_queue = NatsQueue::new_with_consumer(
183+
stream_name,
184+
nats_server,
185+
std::time::Duration::from_secs(300), // Very long timeout (5 minutes)
186+
consumer_uuid,
187+
);
188+
189+
nats_queue.connect().await?;
190+
191+
tokio::spawn(async move {
192+
loop {
193+
tokio::select! {
194+
_ = cancellation_token.cancelled() => {
195+
tracing::info!("KvRouter event consumer received cancellation signal");
196+
break;
197+
}
198+
result = nats_queue.dequeue_task(Some(std::time::Duration::from_secs(300))) => {
199+
match result {
200+
Ok(Some(bytes)) => {
201+
let event: RouterEvent = match serde_json::from_slice(&bytes) {
202+
Ok(event) => event,
203+
Err(e) => {
204+
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
205+
continue;
206+
}
207+
};
208+
209+
// Forward the RouterEvent to the indexer
210+
if let Err(e) = kv_events_tx.send(event).await {
211+
tracing::warn!(
212+
"failed to send kv event to indexer; shutting down: {:?}",
213+
e
214+
);
215+
break;
216+
}
217+
},
218+
Ok(None) => {
219+
tracing::trace!("KvRouter dequeue timeout, continuing");
220+
},
221+
Err(e) => {
222+
tracing::error!("Failed to dequeue task: {:?}", e);
223+
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
224+
}
225+
}
226+
}
227+
}
228+
}
229+
230+
// Clean up the queue and remove the durable consumer
231+
if let Err(e) = nats_queue.shutdown().await {
232+
tracing::warn!("Failed to shutdown NatsQueue: {}", e);
233+
}
234+
});
235+
236+
Ok(())
237+
}
238+
168239
pub async fn new(
169240
component: Component,
170241
block_size: u32,
171242
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
172243
kv_router_config: Option<KvRouterConfig>,
244+
consumer_uuid: String,
173245
) -> Result<Self> {
174246
let kv_router_config = kv_router_config.unwrap_or_default();
175247

@@ -230,31 +302,15 @@ impl KvRouter {
230302
)
231303
.await?;
232304

233-
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
234-
// error checking below will be different.
305+
// Start event consumer if using KvIndexer
235306
if let Indexer::KvIndexer(ref kv_indexer) = indexer {
236-
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
237-
let kv_events_tx = kv_indexer.event_sender();
238-
239-
tokio::spawn(async move {
240-
while let Some(event) = kv_events_rx.next().await {
241-
let event: RouterEvent = match serde_json::from_slice(&event.payload) {
242-
Ok(event) => event,
243-
Err(e) => {
244-
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
245-
// Choosing warn and continue to process other events from other workers
246-
// A bad event likely signals a problem with a worker, but potentially other workers are still healthy
247-
continue;
248-
}
249-
};
250-
if let Err(e) = kv_events_tx.send(event).await {
251-
tracing::warn!(
252-
"failed to send kv event to indexer; shutting down: {:?}",
253-
e
254-
);
255-
}
256-
}
257-
});
307+
Self::start_event_consumer(
308+
component.clone(),
309+
consumer_uuid,
310+
kv_indexer.event_sender(),
311+
cancellation_token.clone(),
312+
)
313+
.await?;
258314
}
259315

260316
tracing::info!("KV Routing initialized");

lib/llm/src/kv_router/publisher.rs

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use dynamo_runtime::{
2929
network::Ingress,
3030
},
3131
protocols::annotated::Annotated,
32+
transports::nats::NatsQueue,
3233
};
3334
use futures::stream;
3435
use std::sync::Arc;
@@ -132,16 +133,26 @@ impl KvEventPublisher {
132133
)?);
133134
}
134135

135-
component
136-
.drt()
137-
.runtime()
138-
.secondary()
139-
.spawn(start_event_processor(
140-
component,
141-
worker_id,
142-
cancellation_token.clone(),
143-
rx,
144-
));
136+
let stream_name = format!("{}.{}", component.subject(), KV_EVENT_SUBJECT)
137+
.replace(['/', '\\', '.', '_'], "-");
138+
let nats_server =
139+
std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
140+
// Create NatsQueue without consumer since we're only publishing
141+
let mut nats_queue = NatsQueue::new_without_consumer(
142+
stream_name,
143+
nats_server,
144+
std::time::Duration::from_secs(60), // Default timeout
145+
);
146+
147+
// Connect the NatsQueue before passing it to the event processor
148+
let cancellation_token_clone = cancellation_token.clone();
149+
component.drt().runtime().secondary().spawn(async move {
150+
if let Err(e) = nats_queue.connect().await {
151+
tracing::error!("Failed to connect NatsQueue: {}", e);
152+
return;
153+
}
154+
start_event_processor(nats_queue, worker_id, cancellation_token_clone, rx).await
155+
});
145156

146157
Ok(Self {
147158
kv_block_size,
@@ -197,7 +208,7 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
197208
// Encapsulate in a router event and publish.
198209
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
199210
let router_event = RouterEvent::new(worker_id, event);
200-
if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await {
211+
if let Err(e) = publisher.publish("queue", &router_event).await {
201212
tracing::error!("Failed to publish event: {}", e);
202213
}
203214
}
@@ -845,7 +856,7 @@ mod tests_startup_helpers {
845856
let published = published.lock().unwrap();
846857
assert_eq!(published.len(), 1);
847858
let (subject, _) = &published[0];
848-
assert_eq!(subject, &KV_EVENT_SUBJECT.to_string());
859+
assert_eq!(subject, "queue");
849860
}
850861

851862
//--------------------------------------------------------------------

lib/runtime/src/transports/nats.rs

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -450,11 +450,31 @@ pub struct NatsQueue {
450450
}
451451

452452
impl NatsQueue {
453-
/// Create a new NatsQueue with the given configuration
453+
/// Create a new NatsQueue with the default "worker-group" consumer
454454
pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
455455
// Sanitize stream name to remove path separators (like in Python version)
456+
// rupei: are we sure NATs stream name accepts '_'?
456457
let sanitized_stream_name = stream_name.replace(['/', '\\'], "_");
458+
let subject = format!("{}.*", sanitized_stream_name);
459+
460+
Self {
461+
stream_name: sanitized_stream_name,
462+
nats_server,
463+
dequeue_timeout,
464+
client: None,
465+
subject,
466+
subscriber: None,
467+
consumer_name: Some("worker-group".to_string()),
468+
}
469+
}
457470

471+
/// Create a new NatsQueue without a consumer (publisher-only mode)
472+
pub fn new_without_consumer(
473+
stream_name: String,
474+
nats_server: String,
475+
dequeue_timeout: time::Duration,
476+
) -> Self {
477+
let sanitized_stream_name = stream_name.replace(['/', '\\'], "_");
458478
let subject = format!("{}.*", sanitized_stream_name);
459479

460480
Self {
@@ -511,20 +531,18 @@ impl NatsQueue {
511531
client.jetstream().create_stream(stream_config).await?;
512532
}
513533

514-
// Create persistent subscriber
515-
let consumer_config = jetstream::consumer::pull::Config {
516-
durable_name: Some(
517-
self.consumer_name
518-
.clone()
519-
.unwrap_or_else(|| "worker-group".to_string()),
520-
),
521-
..Default::default()
522-
};
534+
// Create persistent subscriber only if consumer_name is set
535+
if let Some(ref consumer_name) = self.consumer_name {
536+
let consumer_config = jetstream::consumer::pull::Config {
537+
durable_name: Some(consumer_name.clone()),
538+
..Default::default()
539+
};
523540

524-
let stream = client.jetstream().get_stream(&self.stream_name).await?;
525-
let subscriber = stream.create_consumer(consumer_config).await?;
541+
let stream = client.jetstream().get_stream(&self.stream_name).await?;
542+
let subscriber = stream.create_consumer(consumer_config).await?;
543+
self.subscriber = Some(subscriber);
544+
}
526545

527-
self.subscriber = Some(subscriber);
528546
self.client = Some(client);
529547
}
530548

@@ -767,6 +785,14 @@ impl EventPublisher for NatsQueue {
767785
event_name: impl AsRef<str> + Send + Sync,
768786
bytes: Vec<u8>,
769787
) -> Result<()> {
788+
// Check if event_name is "queue", otherwise warn
789+
if event_name.as_ref() != "queue" {
790+
tracing::warn!(
791+
"Expected event_name to be 'queue', but got '{}'",
792+
event_name.as_ref()
793+
);
794+
}
795+
770796
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
771797

772798
// Note: enqueue_task requires &mut self, but EventPublisher requires &self
@@ -1019,9 +1045,14 @@ mod tests {
10191045
consumer2_name,
10201046
);
10211047

1022-
// Connect both queues (first one creates the stream, second one reuses it)
1048+
// Create a third queue without consumer (publisher-only)
1049+
let mut queue3 =
1050+
NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);
1051+
1052+
// Connect all queues (first one creates the stream, others reuse it)
10231053
queue1.connect().await.expect("Failed to connect queue1");
10241054
queue2.connect().await.expect("Failed to connect queue2");
1055+
queue3.connect().await.expect("Failed to connect queue3");
10251056

10261057
// Send 4 messages using the EventPublisher trait
10271058
let message_strings = vec![

0 commit comments

Comments
 (0)