
Commit 14d6c5b

Add migration layer into OpenAI frontend

1 parent dfb096a commit 14d6c5b

3 files changed: +213 -0 lines changed

lib/llm/src/discovery/watcher.rs

Lines changed: 6 additions & 0 deletions

@@ -19,6 +19,7 @@ use dynamo_runtime::{
 use crate::{
     backend::Backend,
     kv_router::{KvPushRouter, KvRouterConfig},
+    migration::Migration,
     model_type::ModelType,
     preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
     protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
@@ -197,12 +198,14 @@ impl ModelWatcher {
         // function. Needs checking carefully, possibly we need to store it in state.
         let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
 
+        // Chat Completions
         let frontend = SegmentSource::<
             SingleIn<NvCreateChatCompletionRequest>,
             ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         >::new();
         let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
         let backend = Backend::from_mdc(card.clone()).await?.into_operator();
+        let migration = Migration::from_mdc(card.clone()).await?.into_operator();
         let router =
             PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                 client.clone(),
@@ -231,13 +234,16 @@ impl ModelWatcher {
         let chat_engine = frontend
             .link(preprocessor.forward_edge())?
             .link(backend.forward_edge())?
+            .link(migration.forward_edge())?
             .link(service_backend)?
+            .link(migration.backward_edge())?
             .link(backend.backward_edge())?
             .link(preprocessor.backward_edge())?
             .link(frontend)?;
         self.manager
             .add_chat_completions_model(&model_entry.name, chat_engine)?;
 
+        // Completions
         let frontend = SegmentSource::<
             SingleIn<NvCreateCompletionRequest>,
             ManyOut<Annotated<NvCreateCompletionResponse>>,
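
For orientation, the engine built here is bidirectional: each link call adds either a forward edge (request path) or a backward edge (response path). The traversal order implied by the links above is:

    Request:  frontend -> preprocessor -> backend -> migration -> service_backend
    Response: service_backend -> migration -> backend -> preprocessor -> frontend

Because migration is the innermost operator around service_backend, it is the first stage to observe a dropped response stream, so it can reissue the request before backend or preprocessor ever see the failure.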

lib/llm/src/lib.rs

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ pub mod hub;
 // pub mod key_value_store;
 pub mod kv_router;
 pub mod local_model;
+pub mod migration;
 pub mod mocker;
 pub mod model_card;
 pub mod model_type;

lib/llm/src/migration.rs

Lines changed: 206 additions & 0 deletions

@@ -0,0 +1,206 @@
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use anyhow::{Error, Result};
use futures::{stream, stream::StreamExt};

use async_nats::client::{
    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
};
use tokenizers::Tokenizer as HfTokenizer;

use crate::{
    model_card::model::{ModelDeploymentCard, TokenizerKind},
    protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
    tokenizers::{HuggingFaceTokenizer, Tokenizer},
};

use dynamo_runtime::{
    pipeline::{
        async_trait, AsyncEngineContext, AsyncEngineContextProvider, ManyOut, Operator,
        ResponseStream, ServerStreamingEngine, SingleIn,
    },
    protocols::{annotated::Annotated, maybe_error::MaybeError},
};

/// Pipeline operator that transparently recreates the downstream response
/// stream when it disconnects mid-generation.
#[allow(dead_code)]
pub struct Migration {
    pub tokenizer: Option<Tokenizer>,
}

impl Migration {
    pub async fn from_tokenizer(tokenizer: HfTokenizer) -> Result<Arc<Self>> {
        let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
        let tokenizer = Tokenizer::from(Arc::new(tokenizer));

        Ok(Arc::new(Self {
            tokenizer: Some(tokenizer),
        }))
    }

    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let tokenizer = match &mdc.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(file)) => {
                HfTokenizer::from_file(file).map_err(Error::msg)?
            }
            Some(TokenizerKind::GGUF(t)) => *t.clone(),
            None => {
                return Ok(Arc::new(Self { tokenizer: None }));
            }
        };
        Self::from_tokenizer(tokenizer).await
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for Migration
{
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
        let (preprocessed_request, context) = request.transfer(());
        let engine_ctx = context.context();
        const MAX_RETRIES: u16 = 3;
        let retry_manager =
            RetryManager::build(preprocessed_request, engine_ctx.clone(), next, MAX_RETRIES)
                .await?;
        // Adapt the stateful RetryManager into a response stream: each poll
        // yields the next response, or ends the stream once the manager is done.
        let response_stream = stream::unfold(retry_manager, |mut retry_manager| async move {
            retry_manager
                .next()
                .await
                .map(|response| (response, retry_manager))
        });
        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx))
    }
}

/// Holds the in-flight request and recreates the downstream stream on
/// disconnect, re-sending the prompt plus all tokens generated so far.
#[allow(dead_code)]
struct RetryManager {
    request: PreprocessedRequest,
    engine_ctx: Arc<dyn AsyncEngineContext>,
    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
    retries_left: u16,
}

impl RetryManager {
    pub async fn build(
        preprocessed_request: PreprocessedRequest,
        engine_ctx: Arc<dyn AsyncEngineContext>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
        retries_left: u16,
    ) -> Result<Self> {
        let mut slf = Self {
            request: preprocessed_request,
            engine_ctx,
            next_generate: next,
            next_stream: None,
            retries_left: retries_left + 1, // +1 to account for the initial attempt
        };
        slf.new_stream().await?;
        Ok(slf)
    }

    pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> {
        loop {
            let response_stream = match self.next_stream.as_mut() {
                Some(stream) => stream,
                None => {
                    tracing::error!("next() called while next_stream is None - should not happen");
                    return Some(Annotated::from_err(
                        Error::msg("next_stream is None").into(),
                    ));
                }
            };
            if let Some(response) = response_stream.next().await {
                if let Some(err) = response.err() {
                    const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
                    if format!("{:?}", err) == STREAM_ERR_MSG {
                        tracing::info!("Stream disconnected... recreating stream...");
                        if let Err(err) = self.new_stream().await {
                            tracing::info!("Cannot recreate stream: {:?}", err);
                        } else {
                            continue;
                        }
                    }
                }
                self.track_response(&response);
                return Some(response);
            }
            return None;
        }
    }

    async fn new_stream(&mut self) -> Result<()> {
        let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
        while self.retries_left > 0 {
            self.retries_left -= 1;
            // TODO: Is there anything needed to pass between contexts?
            let request = SingleIn::new(self.request.clone());

            // TODO: Why does generate() not implement Sync?
            let next = self.next_generate.clone();
            let handle = tokio::spawn(async move { next.generate(request).await });
            response_stream = Some(match handle.await {
                Ok(response_stream) => response_stream,
                Err(err) => {
                    tracing::error!("Failed to spawn generate stream: {:?}", err);
                    return Err(Error::msg("Failed to spawn generate stream"));
                }
            });

            // A NATS "no responders" error means no worker is currently
            // listening on the subject; retry instead of surfacing the error.
            if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() {
                if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
                    if matches!(req_err.kind(), NatsNoResponders) {
                        tracing::info!("Creating new stream... retrying...");
                        continue;
                    }
                }
            }
            break;
        }
        match response_stream {
            Some(Ok(next_stream)) => {
                self.next_stream = Some(next_stream);
                Ok(())
            }
            Some(Err(err)) => Err(err), // should propagate streaming error if stream started
            None => Err(Error::msg(
                "Retries exhausted - should propagate streaming error",
            )),
        }
    }

    /// Append generated tokens to the stored request so that a retried
    /// request resumes from where the previous stream left off.
    fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) {
        let llm_engine_output = match response.data.as_ref() {
            Some(output) => output,
            None => return,
        };
        self.request
            .token_ids
            .extend_from_slice(&llm_engine_output.token_ids);
    }
}
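
A note on the stream::unfold pattern used in generate() above: it adapts a stateful object with an async next() method into a Stream by threading ownership of the state through each step. A minimal standalone sketch of the same pattern, using a hypothetical Counter type that is not part of this crate:

use futures::{stream, StreamExt};

struct Counter {
    remaining: u32,
}

impl Counter {
    // Async "pull" method, analogous to RetryManager::next().
    async fn next(&mut self) -> Option<u32> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        Some(self.remaining)
    }
}

#[tokio::main]
async fn main() {
    let counter = Counter { remaining: 3 };
    let items: Vec<u32> = stream::unfold(counter, |mut c| async move {
        // Yield the next item and hand the state back, or end the stream.
        c.next().await.map(|item| (item, c))
    })
    .collect()
    .await;
    assert_eq!(items, vec![2, 1, 0]);
}

Each invocation of the closure receives the state back, awaits one item, and returns Some((item, state)) to continue or None to finish, which is exactly how the RetryManager drives the migrated response stream.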
