
Commit 8cda57a

ayushag-nv authored and KrishnanPrash committed
feat: [vLLM] implement cli args for tool and reasoning parsers (#2619)
1 parent 36ab6ca commit 8cda57a

13 files changed: 190 additions, 57 deletions

components/backends/vllm/src/dynamo/vllm/args.py

Lines changed: 19 additions & 0 deletions
@@ -59,6 +59,10 @@ class Config:
     # Connector list from CLI
     connector_list: Optional[list] = None

+    # tool and reasoning parser info
+    tool_call_parser: Optional[str] = None
+    reasoning_parser: Optional[str] = None
+

 def parse_args() -> Config:
     parser = FlexibleArgumentParser(
@@ -109,6 +113,19 @@ def parse_args() -> Config:
         help="List of connectors to use in order (e.g., --connector nixl lmcache). "
         "Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
     )
+    # To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo specific args
+    parser.add_argument(
+        "--dyn-tool-call-parser",
+        type=str,
+        default=None,
+        help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
+    )
+    parser.add_argument(
+        "--dyn-reasoning-parser",
+        type=str,
+        default=None,
+        help="Reasoning parser name for the model.",
+    )

     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -160,6 +177,8 @@ def parse_args() -> Config:
     )
     config.custom_jinja_template = args.custom_jinja_template

+    config.tool_call_parser = args.dyn_tool_call_parser
+    config.reasoning_parser = args.dyn_reasoning_parser
     # Check for conflicting flags
     has_kv_transfer_config = (
         hasattr(engine_args, "kv_transfer_config")
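
For readers following the flag plumbing above, here is a minimal, self-contained Python sketch of the same flow. It deliberately uses plain argparse and a stand-in Config dataclass rather than the real dynamo.vllm.args module and vLLM's FlexibleArgumentParser, so the names below are illustrative only:

import argparse
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    tool_call_parser: Optional[str] = None
    reasoning_parser: Optional[str] = None


parser = argparse.ArgumentParser()
# Prefixed with "dyn-" to avoid clashing with flags that vLLM itself defines.
parser.add_argument("--dyn-tool-call-parser", type=str, default=None)
parser.add_argument("--dyn-reasoning-parser", type=str, default=None)

# Simulate a worker launched with a tool-call parser but no reasoning parser.
args = parser.parse_args(["--dyn-tool-call-parser", "hermes"])

config = Config(
    tool_call_parser=args.dyn_tool_call_parser,  # "hermes"
    reasoning_parser=args.dyn_reasoning_parser,  # None
)
print(config)

On a real deployment the same two flags are simply appended to the vLLM worker launch command, e.g. --dyn-tool-call-parser hermes.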

components/backends/vllm/src/dynamo/vllm/main.py

Lines changed: 2 additions & 0 deletions
@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
     runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
     runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
     runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
+    runtime_config.tool_call_parser = config.tool_call_parser
+    runtime_config.reasoning_parser = config.reasoning_parser

     await register_llm(
         ModelType.Backend,

lib/bindings/python/rust/llm/local_model.rs

Lines changed: 20 additions & 0 deletions
@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
         self.inner.max_num_batched_tokens = Some(max_num_batched_tokens);
     }

+    #[setter]
+    fn set_tool_call_parser(&mut self, tool_call_parser: Option<String>) {
+        self.inner.tool_call_parser = tool_call_parser;
+    }
+
+    #[setter]
+    fn set_reasoning_parser(&mut self, reasoning_parser: Option<String>) {
+        self.inner.reasoning_parser = reasoning_parser;
+    }
+
     fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
         let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
         self.inner
@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
         self.inner.max_num_batched_tokens
     }

+    #[getter]
+    fn tool_call_parser(&self) -> Option<String> {
+        self.inner.tool_call_parser.clone()
+    }
+
+    #[getter]
+    fn reasoning_parser(&self) -> Option<String> {
+        self.inner.reasoning_parser.clone()
+    }
+
     #[getter]
     fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
         let dict = PyDict::new(py);
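
Since these are PyO3 setters and getters, the new fields become plain attributes on the Python side. The snippet below is a hypothetical usage sketch: the import path and the no-argument constructor are assumptions about how the bindings are exposed, not something shown in this diff:

from dynamo.llm import ModelRuntimeConfig  # assumed import path for the bindings

runtime_config = ModelRuntimeConfig()       # assumes a default constructor exists
runtime_config.tool_call_parser = "hermes"  # routed through set_tool_call_parser
runtime_config.reasoning_parser = None      # routed through set_reasoning_parser

print(runtime_config.tool_call_parser)      # "hermes", via the matching getter

This mirrors what main.py above does with the values parsed from the CLI before calling register_llm.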

lib/llm/src/discovery/model_manager.rs

Lines changed: 12 additions & 0 deletions
@@ -246,6 +246,18 @@ impl ModelManager {
             .insert(model_name.to_string(), new_kv_chooser.clone());
         Ok(new_kv_chooser)
     }
+
+    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
+        match self.entries.lock() {
+            Ok(entries) => entries
+                .values()
+                .find(|entry| entry.name == model)
+                .and_then(|entry| entry.runtime_config.as_ref())
+                .and_then(|config| config.tool_call_parser.clone())
+                .map(|parser| parser.to_string()),
+            Err(_) => None,
+        }
+    }
 }

 pub struct ModelEngines<E> {

lib/llm/src/http/service/openai.rs

Lines changed: 43 additions & 27 deletions
@@ -37,6 +37,7 @@ use crate::protocols::openai::{
     completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
     embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
     responses::{NvCreateResponse, NvResponse},
+    ParsingOptions,
 };
 use crate::request_template::RequestTemplate;
 use crate::types::Annotated;
@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String
     uuid.to_string()
 }

+fn get_parsing_options(state: &Arc<service_v2::State>, model: &str) -> ParsingOptions {
+    let tool_call_parser = state.manager().get_model_tool_call_parser(model);
+    let reasoning_parser = None; // TODO: Implement reasoning parser
+
+    ParsingOptions::new(tool_call_parser, reasoning_parser)
+}
+
 /// OpenAI Completions Request Handler
 ///
 /// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
@@ -267,6 +275,8 @@ async fn completions(
         .get_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;

+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard =
         state
             .metrics_clone()
@@ -325,7 +335,7 @@ async fn completions(
         process_metrics_only(response, &mut response_collector);
     });

-    let response = NvCreateCompletionResponse::from_annotated_stream(stream)
+    let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
         .await
         .map_err(|e| {
             tracing::error!(
@@ -494,6 +504,8 @@ async fn chat_completions(
         .get_chat_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;

+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard =
         state
             .metrics_clone()
@@ -553,19 +565,20 @@ async fn chat_completions(
         process_metrics_only(response, &mut response_collector);
     });

-    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
-        .await
-        .map_err(|e| {
-            tracing::error!(
-                request_id,
-                "Failed to fold chat completions stream for: {:?}",
-                e
-            );
-            ErrorMessage::internal_server_error(&format!(
-                "Failed to fold chat completions stream: {}",
-                e
-            ))
-        })?;
+    let response =
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    request_id,
+                    "Failed to fold chat completions stream for: {:?}",
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold chat completions stream: {}",
+                    e
+                ))
+            })?;

     inflight_guard.mark_ok();
     Ok(Json(response).into_response())
@@ -726,6 +739,8 @@ async fn responses(
         .get_chat_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;

+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard =
         state
             .metrics_clone()
@@ -742,19 +757,20 @@ async fn responses(
         .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;

     // TODO: handle streaming, currently just unary
-    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
-        .await
-        .map_err(|e| {
-            tracing::error!(
-                request_id,
-                "Failed to fold chat completions stream for: {:?}",
-                e
-            );
-            ErrorMessage::internal_server_error(&format!(
-                "Failed to fold chat completions stream: {}",
-                e
-            ))
-        })?;
+    let response =
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    request_id,
+                    "Failed to fold chat completions stream for: {:?}",
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold chat completions stream: {}",
+                    e
+                ))
+            })?;

     // Convert NvCreateChatCompletionResponse --> NvResponse
     let response: NvResponse = response.try_into().map_err(|e| {
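
The net effect for API clients: once a worker is started with --dyn-tool-call-parser, the frontend looks the parser up per model via get_parsing_options, and the folded /v1/chat/completions response can carry structured tool_calls instead of raw text. A hedged client-side illustration follows; the URL, port, model name, and tool schema are placeholders for an actual deployment:

import json

import requests

payload = {
    "model": "my-model",  # placeholder model name
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}

# Placeholder endpoint; adjust host/port to wherever the HTTP frontend listens.
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
message = resp.json()["choices"][0]["message"]
print(json.dumps(message.get("tool_calls"), indent=2))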

lib/llm/src/local_model.rs

Lines changed: 9 additions & 0 deletions
@@ -209,6 +209,7 @@ impl LocalModelBuilder {
         );
         card.migration_limit = self.migration_limit;
         card.user_data = self.user_data.take();
+
         return Ok(LocalModel {
             card,
             full_path: PathBuf::new(),
@@ -402,6 +403,7 @@ impl LocalModel {
         // Publish the Model Deployment Card to etcd
         let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
         let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
+<<<<<<< HEAD
         let slug_key = self.card.slug();
         let key = slug_key.to_string();

@@ -436,6 +438,13 @@
                 .await?;
             }
         }
+=======
+        let key = self.card.slug().to_string();
+
+        card_store
+            .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
+            .await?;
+>>>>>>> cbe854fc (feat: [vLLM] implement cli args for tool and reasoning parsers (#2619))

         // Publish our ModelEntry to etcd. This allows ingress to find the model card.
         // (Why don't we put the model card directly under this key?)

lib/llm/src/local_model/runtime_config.rs

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {

     pub max_num_batched_tokens: Option<u64>,

+    pub tool_call_parser: Option<String>,
+
+    pub reasoning_parser: Option<String>,
+
     /// Mapping of engine-specific runtime configs
     #[serde(default, skip_serializing_if = "HashMap::is_empty")]
     pub runtime_data: HashMap<String, serde_json::Value>,

lib/llm/src/preprocessor.rs

Lines changed: 0 additions & 1 deletion
@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
         let mdcsum = mdc.mdcsum();
         let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
         let PromptFormatter::OAI(formatter) = formatter;
-
         let tokenizer = match &mdc.tokenizer {
             Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
             Some(TokenizerKind::GGUF(tokenizer)) => {

lib/llm/src/protocols/openai.rs

Lines changed: 16 additions & 0 deletions
@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
     /// Gets the current prompt token count (Input Sequence Length).
     fn get_isl(&self) -> Option<u32>;
 }
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default)]
+pub struct ParsingOptions {
+    pub tool_call_parser: Option<String>,
+
+    pub reasoning_parser: Option<String>,
+}
+
+impl ParsingOptions {
+    pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
+        Self {
+            tool_call_parser,
+            reasoning_parser,
+        }
+    }
+}
