Commit c1dd5dc

Add migration layer into OpenAI frontend

1 parent dfb096a commit c1dd5dc

File tree: 3 files changed, +212 −0 lines changed

lib/llm/src/discovery/watcher.rs

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,7 @@ use dynamo_runtime::{
 use crate::{
     backend::Backend,
     kv_router::{KvPushRouter, KvRouterConfig},
+    migration::Migration,
     model_type::ModelType,
     preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
     protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
@@ -197,12 +198,14 @@ impl ModelWatcher {
         // function. Needs checking carefully, possibly we need to store it in state.
         let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

+        // Chat Completions
         let frontend = SegmentSource::<
             SingleIn<NvCreateChatCompletionRequest>,
             ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         >::new();
         let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
         let backend = Backend::from_mdc(card.clone()).await?.into_operator();
+        let migration = Migration::from_mdc(card.clone()).await?.into_operator();
         let router =
             PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                 client.clone(),
@@ -231,13 +234,16 @@ impl ModelWatcher {
         let chat_engine = frontend
             .link(preprocessor.forward_edge())?
             .link(backend.forward_edge())?
+            .link(migration.forward_edge())?
             .link(service_backend)?
+            .link(migration.backward_edge())?
             .link(backend.backward_edge())?
             .link(preprocessor.backward_edge())?
             .link(frontend)?;
         self.manager
             .add_chat_completions_model(&model_entry.name, chat_engine)?;

+        // Completions
         let frontend = SegmentSource::<
             SingleIn<NvCreateCompletionRequest>,
             ManyOut<Annotated<NvCreateCompletionResponse>>,
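
Not part of the diff, but for orientation: the migration operator is linked symmetrically around service_backend, so every preprocessed request passes through it on the way out to the network client, and every engine response passes back through it. Using the names from the hunk above, the resulting chat pipeline is:

    request:  frontend -> preprocessor -> backend -> migration -> service_backend
    response: service_backend -> migration -> backend -> preprocessor -> frontend

Sitting at that boundary is what lets the operator re-issue an in-flight request when the response stream dies mid-generation (see lib/llm/src/migration.rs below).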

lib/llm/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ pub mod hub;
 // pub mod key_value_store;
 pub mod kv_router;
 pub mod local_model;
+pub mod migration;
 pub mod mocker;
 pub mod model_card;
 pub mod model_type;

lib/llm/src/migration.rs

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use anyhow::{Error, Result};
+use futures::{stream, stream::StreamExt};
+
+use async_nats::client::{
+    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
+};
+use tokenizers::Tokenizer as HfTokenizer;
+
+use crate::{
+    model_card::model::{ModelDeploymentCard, TokenizerKind},
+    protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
+    tokenizers::{HuggingFaceTokenizer, Tokenizer},
+};
+
+use dynamo_runtime::{
+    pipeline::{
+        async_trait, AsyncEngineContext, AsyncEngineContextProvider, ManyOut, Operator,
+        ResponseStream, ServerStreamingEngine, SingleIn,
+    },
+    protocols::{annotated::Annotated, maybe_error::MaybeError},
+};
+
+#[allow(dead_code)]
+pub struct Migration {
+    pub tokenizer: Option<Tokenizer>,
+}
+
+impl Migration {
+    pub async fn from_tokenizer(tokenizer: HfTokenizer) -> Result<Arc<Self>> {
+        let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
+        let tokenizer = Tokenizer::from(Arc::new(tokenizer));
+
+        Ok(Arc::new(Self {
+            tokenizer: Some(tokenizer),
+        }))
+    }
+
+    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
+        let tokenizer = match &mdc.tokenizer {
+            Some(TokenizerKind::HfTokenizerJson(file)) => {
+                HfTokenizer::from_file(file).map_err(Error::msg)?
+            }
+            Some(TokenizerKind::GGUF(t)) => *t.clone(),
+            None => {
+                return Ok(Arc::new(Self { tokenizer: None }));
+            }
+        };
+        Self::from_tokenizer(tokenizer).await
+    }
+}
+
+#[async_trait]
+impl
+    Operator<
+        SingleIn<PreprocessedRequest>,
+        ManyOut<Annotated<LLMEngineOutput>>,
+        SingleIn<PreprocessedRequest>,
+        ManyOut<Annotated<LLMEngineOutput>>,
+    > for Migration
+{
+    async fn generate(
+        &self,
+        request: SingleIn<PreprocessedRequest>,
+        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
+    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
+        let (preprocessed_request, context) = request.transfer(());
+        let engine_ctx = context.context();
+        const MAX_RETRIES: u16 = 3;
+        let retry_manager =
+            RetryManager::build(preprocessed_request, engine_ctx.clone(), next, MAX_RETRIES)
+                .await?;
+        let response_stream = stream::unfold(retry_manager, |mut retry_manager| async move {
+            retry_manager
+                .next()
+                .await
+                .map(|response| (response, retry_manager))
+        });
+        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx))
+    }
+}
+
+#[allow(dead_code)]
+struct RetryManager {
+    request: PreprocessedRequest,
+    engine_ctx: Arc<dyn AsyncEngineContext>,
+    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
+    next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
+    retries_left: u16,
+}
+
+impl RetryManager {
+    pub async fn build(
+        preprocessed_request: PreprocessedRequest,
+        engine_ctx: Arc<dyn AsyncEngineContext>,
+        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
+        retries_left: u16,
+    ) -> Result<Self> {
+        let mut slf = Self {
+            request: preprocessed_request,
+            engine_ctx,
+            next_generate: next,
+            next_stream: None,
+            retries_left: retries_left + 1, // +1 to account for the initial attempt
+        };
+        slf.new_stream().await?;
+        Ok(slf)
+    }
+
+    pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> {
+        loop {
+            let response_stream = match self.next_stream.as_mut() {
+                Some(stream) => stream,
+                None => {
+                    tracing::error!("next() called while next_stream is None - should not happen");
+                    return Some(Annotated::from_err(
+                        Error::msg("next_stream is None").into(),
+                    ));
+                }
+            };
+            if let Some(response) = response_stream.next().await {
+                if let Some(err) = response.err() {
+                    const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
+                    if format!("{:?}", err) == STREAM_ERR_MSG {
+                        tracing::info!("Stream disconnected... recreating stream...");
+                        if let Err(err) = self.new_stream().await {
+                            tracing::info!("Cannot recreate stream: {:?}", err);
+                        } else {
+                            continue;
+                        }
+                    }
+                }
+                self.track_response(&response);
+                return Some(response);
+            }
+            return None;
+        }
+    }
+
+    async fn new_stream(&mut self) -> Result<()> {
+        let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
+        while self.retries_left > 0 {
+            self.retries_left -= 1;
+            // TODO: Is there anything needed to pass between context?
+            let request = SingleIn::new(self.request.clone());
+
+            // TODO: Why does generate() not implement Sync?
+            let next = self.next_generate.clone();
+            let handle = tokio::spawn(async move { next.generate(request).await });
+            response_stream = Some(match handle.await {
+                Ok(response_stream) => response_stream,
+                Err(err) => {
+                    tracing::error!("Failed to spawn generate stream: {:?}", err);
+                    return Err(Error::msg("Failed to spawn generate stream"));
+                }
+            });
+
+            if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() {
+                if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
+                    if matches!(req_err.kind(), NatsNoResponders) {
+                        tracing::info!("Creating new stream... retrying...");
+                        continue;
+                    }
+                }
+            }
+            break;
+        }
+        match response_stream {
+            Some(Ok(next_stream)) => {
+                self.next_stream = Some(next_stream);
+                Ok(())
+            }
+            Some(Err(err)) => Err(err), // should propagate streaming error if stream started
+            None => Err(Error::msg(
+                "Retries exhausted - should propagate streaming error",
+            )),
+        }
+    }
+
+    fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) {
+        let llm_engine_output = match response.data.as_ref() {
+            Some(output) => output,
+            None => return,
+        };
+        for token_id in llm_engine_output.token_ids.iter() {
+            self.request.token_ids.push(*token_id);
+        }
+    }
+}
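
A few notes on the retry semantics above, offered as commentary rather than as part of the commit. RetryManager::build stores retries_left + 1, so MAX_RETRIES = 3 allows four attempts in total: the initial one plus three reconnects. Disconnects are recognized in two places: a NATS NoResponders error while creating a stream triggers another creation attempt, and a response whose Debug output exactly matches "Stream ended before generation completed" triggers a reconnect mid-stream. Crucially, track_response appends every emitted token id to the stored request, so a reconnect resumes generation rather than restarting it: the re-issued request carries the original prompt plus everything produced before the stream died. A minimal standalone sketch of that bookkeeping, with a stand-in request type (the real PreprocessedRequest carries more fields):

    // Stand-in for PreprocessedRequest; only the token_ids bookkeeping is modeled.
    struct SketchRequest {
        token_ids: Vec<u32>,
    }

    // Mirrors RetryManager::track_response: extend the request with the tokens
    // generated so far, so a re-issued request resumes instead of restarting.
    fn track(request: &mut SketchRequest, output_tokens: &[u32]) {
        request.token_ids.extend_from_slice(output_tokens);
    }

    fn main() {
        let mut req = SketchRequest { token_ids: vec![101, 2009] }; // prompt tokens
        track(&mut req, &[7592]);         // emitted before the stream died
        track(&mut req, &[2088, 102]);    // emitted after reconnecting
        assert_eq!(req.token_ids, vec![101, 2009, 7592, 2088, 102]);
    }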
