
Commit cbe854f

feat: [vLLM] implement cli args for tool and reasoning parsers (#2619)
1 parent b658ba6 commit cbe854f

File tree

13 files changed: +183 -58 lines changed

components/backends/vllm/src/dynamo/vllm/args.py

Lines changed: 19 additions & 1 deletion

@@ -58,6 +58,10 @@ class Config:
     # Connector list from CLI
     connector_list: Optional[list] = None
 
+    # tool and reasoning parser info
+    tool_call_parser: Optional[str] = None
+    reasoning_parser: Optional[str] = None
+
 
 def parse_args() -> Config:
     parser = FlexibleArgumentParser(
@@ -102,6 +106,19 @@ def parse_args() -> Config:
         help="List of connectors to use in order (e.g., --connector nixl lmcache). "
         "Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
     )
+    # To avoid name conflicts with different backends, adopt the prefix "dyn-" for Dynamo-specific args
+    parser.add_argument(
+        "--dyn-tool-call-parser",
+        type=str,
+        default=None,
+        help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
+    )
+    parser.add_argument(
+        "--dyn-reasoning-parser",
+        type=str,
+        default=None,
+        help="Reasoning parser name for the model.",
+    )
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -151,7 +168,8 @@ def parse_args() -> Config:
     config.port_range = DynamoPortRange(
         min=args.dynamo_port_min, max=args.dynamo_port_max
     )
-
+    config.tool_call_parser = args.dyn_tool_call_parser
+    config.reasoning_parser = args.dyn_reasoning_parser
     # Check for conflicting flags
     has_kv_transfer_config = (
         hasattr(engine_args, "kv_transfer_config")
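As a usage sketch: the two new flags land on `Config` and are later copied onto the worker's runtime config (see `main.py` below). This standalone snippet mimics the parsing behavior with plain `argparse` rather than vLLM's `FlexibleArgumentParser`, so the wiring here is illustrative, not the actual entrypoint:

```python
import argparse
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    tool_call_parser: Optional[str] = None
    reasoning_parser: Optional[str] = None


parser = argparse.ArgumentParser()
# The "dyn-" prefix keeps Dynamo flags from colliding with vLLM's own CLI args.
parser.add_argument("--dyn-tool-call-parser", type=str, default=None)
parser.add_argument("--dyn-reasoning-parser", type=str, default=None)

args = parser.parse_args(["--dyn-tool-call-parser", "hermes"])
config = Config(
    tool_call_parser=args.dyn_tool_call_parser,  # argparse maps dashes to underscores
    reasoning_parser=args.dyn_reasoning_parser,
)
assert config.tool_call_parser == "hermes"
assert config.reasoning_parser is None
```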

components/backends/vllm/src/dynamo/vllm/main.py

Lines changed: 2 additions & 0 deletions

@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
     runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
     runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
     runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
+    runtime_config.tool_call_parser = config.tool_call_parser
+    runtime_config.reasoning_parser = config.reasoning_parser
 
     await register_llm(
         ModelType.Backend,

lib/bindings/python/rust/llm/local_model.rs

Lines changed: 20 additions & 0 deletions

@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
         self.inner.max_num_batched_tokens = Some(max_num_batched_tokens);
     }
 
+    #[setter]
+    fn set_tool_call_parser(&mut self, tool_call_parser: Option<String>) {
+        self.inner.tool_call_parser = tool_call_parser;
+    }
+
+    #[setter]
+    fn set_reasoning_parser(&mut self, reasoning_parser: Option<String>) {
+        self.inner.reasoning_parser = reasoning_parser;
+    }
+
     fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
         let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
         self.inner
@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
         self.inner.max_num_batched_tokens
     }
 
+    #[getter]
+    fn tool_call_parser(&self) -> Option<String> {
+        self.inner.tool_call_parser.clone()
+    }
+
+    #[getter]
+    fn reasoning_parser(&self) -> Option<String> {
+        self.inner.reasoning_parser.clone()
+    }
+
     #[getter]
     fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
         let dict = PyDict::new(py);
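These `#[setter]`/`#[getter]` pairs surface the new fields to Python as plain attributes on the bound `ModelRuntimeConfig` class. A minimal round-trip sketch; the import path `dynamo.llm` is an assumption for illustration and may differ from the actual package layout:

```python
# Hypothetical import path -- adjust to wherever the bindings are exposed.
from dynamo.llm import ModelRuntimeConfig

runtime_config = ModelRuntimeConfig()
runtime_config.tool_call_parser = "hermes"  # dispatches to set_tool_call_parser
runtime_config.reasoning_parser = None      # dispatches to set_reasoning_parser
assert runtime_config.tool_call_parser == "hermes"  # reads via the #[getter]
```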

lib/llm/src/discovery/model_manager.rs

Lines changed: 12 additions & 0 deletions

@@ -246,6 +246,18 @@ impl ModelManager {
             .insert(model_name.to_string(), new_kv_chooser.clone());
         Ok(new_kv_chooser)
     }
+
+    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
+        match self.entries.lock() {
+            Ok(entries) => entries
+                .values()
+                .find(|entry| entry.name == model)
+                .and_then(|entry| entry.runtime_config.as_ref())
+                .and_then(|config| config.tool_call_parser.clone())
+                .map(|parser| parser.to_string()),
+            Err(_) => None,
+        }
+    }
 }
 
 pub struct ModelEngines<E> {

lib/llm/src/http/service/openai.rs

Lines changed: 43 additions & 27 deletions

@@ -37,6 +37,7 @@ use crate::protocols::openai::{
     completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
     embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
     responses::{NvCreateResponse, NvResponse},
+    ParsingOptions,
 };
 use crate::request_template::RequestTemplate;
 use crate::types::Annotated;
@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String
     uuid.to_string()
 }
 
+fn get_parsing_options(state: &Arc<service_v2::State>, model: &str) -> ParsingOptions {
+    let tool_call_parser = state.manager().get_model_tool_call_parser(model);
+    let reasoning_parser = None; // TODO: Implement reasoning parser
+
+    ParsingOptions::new(tool_call_parser, reasoning_parser)
+}
+
 /// OpenAI Completions Request Handler
 ///
 /// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
@@ -267,6 +275,8 @@ async fn completions(
         .get_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;
 
+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard =
         state
             .metrics_clone()
@@ -325,7 +335,7 @@ async fn completions(
         process_metrics_only(response, &mut response_collector);
     });
 
-    let response = NvCreateCompletionResponse::from_annotated_stream(stream)
+    let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
         .await
         .map_err(|e| {
             tracing::error!(
@@ -494,6 +504,8 @@ async fn chat_completions(
         .get_chat_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;
 
+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard =
         state
             .metrics_clone()
@@ -553,19 +565,20 @@ async fn chat_completions(
         process_metrics_only(response, &mut response_collector);
     });
 
-    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
-        .await
-        .map_err(|e| {
-            tracing::error!(
-                request_id,
-                "Failed to fold chat completions stream for: {:?}",
-                e
-            );
-            ErrorMessage::internal_server_error(&format!(
-                "Failed to fold chat completions stream: {}",
-                e
-            ))
-        })?;
+    let response =
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    request_id,
+                    "Failed to fold chat completions stream for: {:?}",
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold chat completions stream: {}",
+                    e
+                ))
+            })?;
 
     inflight_guard.mark_ok();
     Ok(Json(response).into_response())
@@ -726,6 +739,8 @@ async fn responses(
         .get_chat_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;
 
+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard =
         state
             .metrics_clone()
@@ -742,19 +757,20 @@ async fn responses(
         .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
 
     // TODO: handle streaming, currently just unary
-    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
-        .await
-        .map_err(|e| {
-            tracing::error!(
-                request_id,
-                "Failed to fold chat completions stream for: {:?}",
-                e
-            );
-            ErrorMessage::internal_server_error(&format!(
-                "Failed to fold chat completions stream: {}",
-                e
-            ))
-        })?;
+    let response =
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    request_id,
+                    "Failed to fold chat completions stream for: {:?}",
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold chat completions stream: {}",
+                    e
+                ))
+            })?;
 
     // Convert NvCreateChatCompletionResponse --> NvResponse
     let response: NvResponse = response.try_into().map_err(|e| {
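In short, every handler now resolves the model's registered tool-call parser up front and threads it into stream aggregation; the reasoning parser stays `None` pending the TODO above. A Python sketch of that resolution logic, mirroring `get_parsing_options` and `ModelManager::get_model_tool_call_parser` (the dict-based entries here are simplified stand-ins for the Rust types):

```python
from typing import Any, Optional


def get_parsing_options(entries: list[dict[str, Any]], model: str) -> dict[str, Optional[str]]:
    # Find the entry registered under `model` and read its tool_call_parser;
    # missing models or missing runtime configs fall back to None.
    tool_call_parser = next(
        (
            (e.get("runtime_config") or {}).get("tool_call_parser")
            for e in entries
            if e.get("name") == model
        ),
        None,
    )
    return {"tool_call_parser": tool_call_parser, "reasoning_parser": None}  # TODO upstream


entries = [{"name": "qwen", "runtime_config": {"tool_call_parser": "hermes"}}]
assert get_parsing_options(entries, "qwen")["tool_call_parser"] == "hermes"
assert get_parsing_options(entries, "other")["tool_call_parser"] is None
```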

lib/llm/src/local_model.rs

Lines changed: 2 additions & 0 deletions

@@ -202,6 +202,7 @@ impl LocalModelBuilder {
         );
         card.migration_limit = self.migration_limit;
         card.user_data = self.user_data.take();
+
         return Ok(LocalModel {
             card,
             full_path: PathBuf::new(),
@@ -392,6 +393,7 @@ impl LocalModel {
         let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
         let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
         let key = self.card.slug().to_string();
+
         card_store
             .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
             .await?;

lib/llm/src/local_model/runtime_config.rs

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {
 
     pub max_num_batched_tokens: Option<u64>,
 
+    pub tool_call_parser: Option<String>,
+
+    pub reasoning_parser: Option<String>,
+
     /// Mapping of engine-specific runtime configs
     #[serde(default, skip_serializing_if = "HashMap::is_empty")]
     pub runtime_data: HashMap<String, serde_json::Value>,

lib/llm/src/preprocessor.rs

Lines changed: 0 additions & 1 deletion

@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
         let mdcsum = mdc.mdcsum();
         let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
         let PromptFormatter::OAI(formatter) = formatter;
-
         let tokenizer = match &mdc.tokenizer {
             Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
             Some(TokenizerKind::GGUF(tokenizer)) => {

lib/llm/src/protocols/openai.rs

Lines changed: 16 additions & 0 deletions

@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
     /// Gets the current prompt token count (Input Sequence Length).
     fn get_isl(&self) -> Option<u32>;
 }
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default)]
+pub struct ParsingOptions {
+    pub tool_call_parser: Option<String>,
+
+    pub reasoning_parser: Option<String>,
+}
+
+impl ParsingOptions {
+    pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
+        Self {
+            tool_call_parser,
+            reasoning_parser,
+        }
+    }
+}
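`ParsingOptions` derives `Serialize`/`Deserialize`, so it travels as two optional strings. A Python sketch of the equivalent wire shape; the exact JSON framing is an assumption based on serde's default derive behavior:

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ParsingOptions:
    tool_call_parser: Optional[str] = None
    reasoning_parser: Optional[str] = None


opts = ParsingOptions(tool_call_parser="hermes")
# serde's default derive serializes Option::None as JSON null.
print(json.dumps(asdict(opts)))
# -> {"tool_call_parser": "hermes", "reasoning_parser": null}
```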

lib/llm/src/protocols/openai/chat_completions/aggregator.rs

Lines changed: 19 additions & 9 deletions

@@ -19,7 +19,9 @@ use std::collections::HashMap;
 use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
 use crate::protocols::{
     codec::{Message, SseCodecError},
-    convert_sse_stream, Annotated,
+    convert_sse_stream,
+    openai::ParsingOptions,
+    Annotated,
 };
 
 use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
@@ -99,6 +101,7 @@ impl DeltaAggregator {
     /// * `Err(String)` if an error occurs during processing.
     pub async fn apply(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -175,7 +178,10 @@ impl DeltaAggregator {
         // After aggregation, inspect each choice's text for tool call syntax
         for choice in aggregator.choices.values_mut() {
             if choice.tool_calls.is_none() {
-                if let Ok(tool_calls) = try_tool_call_parse_aggregate(&choice.text, None) {
+                if let Ok(tool_calls) = try_tool_call_parse_aggregate(
+                    &choice.text,
+                    parsing_options.tool_call_parser.as_deref(),
+                ) {
                     if tool_calls.is_empty() {
                         continue;
                     }
@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
     /// * `Err(String)` if an error occurs.
     async fn from_annotated_stream(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String>;
 
     /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
     /// * `Err(String)` if an error occurs.
     async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String>;
 }
 
 impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
     async fn from_annotated_stream(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String> {
-        DeltaAggregator::apply(stream).await
+        DeltaAggregator::apply(stream, parsing_options).await
     }
 
     async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
-        NvCreateChatCompletionResponse::from_annotated_stream(stream).await
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options).await
     }
 }
 
@@ -347,7 +357,7 @@ mod tests {
             Box::pin(stream::empty());
 
         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
         // Check the result
         assert!(result.is_ok());
@@ -377,7 +387,7 @@ mod tests {
         let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
         // Check the result
         assert!(result.is_ok());
@@ -421,7 +431,7 @@ mod tests {
         let stream = Box::pin(stream::iter(annotated_deltas));
 
         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
         // Check the result
         assert!(result.is_ok());
@@ -492,7 +502,7 @@ mod tests {
         let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
         // Check the result
         assert!(result.is_ok());
@@ -550,7 +560,7 @@ mod tests {
         let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
         // Check the result
         assert!(result.is_ok());
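To summarize the aggregator change: deltas are still folded into one text per choice, but the final tool-call parse now uses the parser name carried in `ParsingOptions` instead of a hard-coded `None`. A rough Python sketch of that control flow; the parser registry and delta representation here are simplified stand-ins, not the `dynamo_parsers` API:

```python
from typing import Callable, Optional

# Simplified stand-ins: a "delta" is just a text fragment, and a parser maps
# the aggregated text to a list of tool-call dicts.
ToolCallParser = Callable[[str], list[dict]]


def aggregate_choice(
    deltas: list[str],
    tool_call_parser: Optional[str],
    parsers: dict[str, ToolCallParser],
) -> tuple[str, list[dict]]:
    text = "".join(deltas)  # fold the stream into the full choice text
    parser = parsers.get(tool_call_parser) if tool_call_parser else None
    tool_calls = parser(text) if parser else []  # no parser name -> no parsing
    return text, tool_calls
```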
