8 changes: 8 additions & 0 deletions components/frontend/src/dynamo/frontend/main.py
@@ -176,6 +176,12 @@ def parse_args():
         default=None,
         help="Prefix for Dynamo frontend metrics. If unset, uses DYN_METRICS_PREFIX env var or 'dynamo_frontend'.",
     )
+    parser.add_argument(
+        "--tool-call-parser",
    [Inline review comment] @rmccorm4 (Contributor), Aug 21, 2025:
    Since the frontend can serve many models - should this be a frontend argument? Or a backend/worker argument passed through ModelDeploymentCard/RuntimeConfig etc. that the frontend can load on demand when a worker is discovered via register_llm?
    ex: python -m dynamo.vllm instead of python -m dynamo.frontend
        python -m dynamo.vllm --tool-call-parser ...

    [Inline review comment] (Contributor), reply:
    Oops, sorry - I see you said the same thing on Slack. Ignore me.

+        type=str,
+        default=None,
+        help="Tool parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
+    )

     flags = parser.parse_args()

@@ -233,6 +239,8 @@ async def async_main():
kwargs["tls_cert_path"] = flags.tls_cert_path
if flags.tls_key_path:
kwargs["tls_key_path"] = flags.tls_key_path
if flags.tool_call_parser:
kwargs["tool_call_parser"] = flags.tool_call_parser

if is_static:
# out=dyn://<static_endpoint>
Expand Down
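
Usage note (editor's sketch, not part of the diff): with the flag wired through, the frontend would be launched as, e.g., python -m dynamo.frontend --tool-call-parser hermes, and any of the other documented parser names ('nemotron_deci', 'llama3_json', 'mistral', 'phi4') work the same way. If the flag is omitted, kwargs carries no tool_call_parser key and behavior is unchanged.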
8 changes: 6 additions & 2 deletions lib/bindings/python/rust/llm/entrypoint.rs
@@ -107,13 +107,14 @@ pub(crate) struct EntrypointArgs {
     tls_cert_path: Option<PathBuf>,
     tls_key_path: Option<PathBuf>,
     extra_engine_args: Option<PathBuf>,
+    tool_call_parser: Option<String>,
 }

 #[pymethods]
 impl EntrypointArgs {
     #[allow(clippy::too_many_arguments)]
     #[new]
-    #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None))]
+    #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, tool_call_parser=None))]
     pub fn new(
         engine_type: EngineType,
         model_path: Option<PathBuf>,
@@ -129,6 +130,7 @@ impl EntrypointArgs {
         tls_cert_path: Option<PathBuf>,
         tls_key_path: Option<PathBuf>,
         extra_engine_args: Option<PathBuf>,
+        tool_call_parser: Option<String>,
     ) -> PyResult<Self> {
         let endpoint_id_obj: Option<EndpointId> = match endpoint_id {
             Some(eid) => Some(eid.parse().map_err(|_| {
@@ -160,6 +162,7 @@ impl EntrypointArgs {
             tls_cert_path,
             tls_key_path,
             extra_engine_args,
+            tool_call_parser,
         })
     }
 }
@@ -192,7 +195,8 @@ pub fn make_engine<'p>(
         .tls_cert_path(args.tls_cert_path.clone())
         .tls_key_path(args.tls_key_path.clone())
         .is_mocker(matches!(args.engine_type, EngineType::Mocker))
-        .extra_engine_args(args.extra_engine_args.clone());
+        .extra_engine_args(args.extra_engine_args.clone())
+        .tool_call_parser(args.tool_call_parser.clone());
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
         let local_model = builder.build().await.map_err(to_pyerr)?;
         let inner = select_engine(distributed_runtime, args, local_model)
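
Binding note: the new tool_call_parser keyword is appended at the end of the pyo3 signature with a default of None, so existing Python callers, positional or keyword, are unaffected; the value then rides the same builder chain as extra_engine_args into LocalModelBuilder.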
4 changes: 4 additions & 0 deletions lib/llm/src/entrypoint/input/http.rs
@@ -48,6 +48,10 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result
     if let Some(http_host) = local_model.http_host() {
         http_service_builder = http_service_builder.host(http_host);
     }
+
+    if let Some(tool_call_parser) = local_model.tool_call_parser() {
+        http_service_builder = http_service_builder.with_tool_call_parser(Some(tool_call_parser));
+    }
     http_service_builder =
         http_service_builder.with_request_template(engine_config.local_model().request_template());

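
Since LocalModel::tool_call_parser() returns an Option<String> and the setter accepts one, the guard above could arguably collapse to a single call - a possible simplification, not what the PR does:

    http_service_builder =
        http_service_builder.with_tool_call_parser(local_model.tool_call_parser());

Given that the builder field defaults to None, passing None straight through appears equivalent to never calling the setter at all.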
36 changes: 21 additions & 15 deletions lib/llm/src/http/service/openai.rs
@@ -295,6 +295,8 @@ async fn completions(
     // apply any annotations to the front of the stream
     let stream = stream::iter(annotations).chain(stream);

+    let tool_call_parser = state.tool_call_parser();
+
     if streaming {
         let stream = stream.map(move |response| {
             process_event_converter(EventConverter::from(response), &mut response_collector)
@@ -314,7 +316,7 @@ async fn completions(
            process_metrics_only(response, &mut response_collector);
        });

-        let response = NvCreateCompletionResponse::from_annotated_stream(stream)
+        let response = NvCreateCompletionResponse::from_annotated_stream(stream, tool_call_parser)
             .await
             .map_err(|e| {
                 tracing::error!(
@@ -478,6 +480,8 @@ async fn chat_completions(
     // todo - determine the proper error code for when a request model is not present
     tracing::trace!("Getting chat completions engine for model: {}", model);

+    let tool_call_parser = state.tool_call_parser();
+
     let engine = state
         .manager()
         .get_chat_completions_engine(model)
@@ -542,19 +546,20 @@ async fn chat_completions(
            process_metrics_only(response, &mut response_collector);
        });

-        let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
-            .await
-            .map_err(|e| {
-                tracing::error!(
-                    request_id,
-                    "Failed to fold chat completions stream for: {:?}",
-                    e
-                );
-                ErrorMessage::internal_server_error(&format!(
-                    "Failed to fold chat completions stream: {}",
-                    e
-                ))
-            })?;
+        let response =
+            NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser)
+                .await
+                .map_err(|e| {
+                    tracing::error!(
+                        request_id,
+                        "Failed to fold chat completions stream for: {:?}",
+                        e
+                    );
+                    ErrorMessage::internal_server_error(&format!(
+                        "Failed to fold chat completions stream: {}",
+                        e
+                    ))
+                })?;

        inflight_guard.mark_ok();
        Ok(Json(response).into_response())
@@ -731,7 +736,8 @@ async fn responses(
        .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;

     // TODO: handle streaming, currently just unary
-    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
+    let tool_call_parser = state.tool_call_parser();
+    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser)
        .await
        .map_err(|e| {
            tracing::error!(
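
An aside on the handler pattern above (an observation, not a required change): state.tool_call_parser() clones the parser name on every request before it moves into the fold. Storing the name as an Arc<str> in State would turn that clone into a pointer copy, though the cost is almost certainly negligible next to token generation.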
29 changes: 26 additions & 3 deletions lib/llm/src/http/service/service_v2.rs
@@ -32,6 +32,7 @@ pub struct State {
     manager: Arc<ModelManager>,
     etcd_client: Option<etcd::Client>,
     flags: StateFlags,
+    tool_call_parser: Option<String>,
 }

#[derive(Default, Debug)]
@@ -71,7 +72,7 @@ impl StateFlags {
 }

 impl State {
-    pub fn new(manager: Arc<ModelManager>) -> Self {
+    pub fn new(manager: Arc<ModelManager>, tool_call_parser: Option<String>) -> Self {
         Self {
             manager,
             metrics: Arc::new(Metrics::default()),
@@ -82,10 +83,15 @@ impl StateFlags {
                 embeddings_endpoints_enabled: AtomicBool::new(false),
                 responses_endpoints_enabled: AtomicBool::new(false),
             },
+            tool_call_parser: Some(tool_call_parser.unwrap_or_else(|| String::from(""))),
         }
     }

-    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
+    pub fn new_with_etcd(
+        manager: Arc<ModelManager>,
+        etcd_client: Option<etcd::Client>,
+        tool_call_parser: Option<String>,
+    ) -> Self {
         Self {
             manager,
             metrics: Arc::new(Metrics::default()),
@@ -96,6 +102,7 @@ impl StateFlags {
                 embeddings_endpoints_enabled: AtomicBool::new(false),
                 responses_endpoints_enabled: AtomicBool::new(false),
             },
+            tool_call_parser: Some(tool_call_parser.unwrap_or_else(|| String::from(""))),
         }
     }
     /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
@@ -119,6 +126,10 @@ impl State {
     pub fn sse_keep_alive(&self) -> Option<Duration> {
         None
     }
+
+    pub fn tool_call_parser(&self) -> Option<String> {
+        self.tool_call_parser.clone()
+    }
 }

#[derive(Clone)]
@@ -172,6 +183,9 @@ pub struct HttpServiceConfig {

     #[builder(default = "None")]
     etcd_client: Option<etcd::Client>,
+
+    #[builder(default = "None")]
+    tool_call_parser: Option<String>,
 }

 impl HttpService {
@@ -294,7 +308,11 @@ impl HttpServiceConfigBuilder {
         let config: HttpServiceConfig = self.build_internal()?;

         let model_manager = Arc::new(ModelManager::new());
-        let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
+        let state = Arc::new(State::new_with_etcd(
+            model_manager,
+            config.etcd_client,
+            config.tool_call_parser,
+        ));

         state
             .flags
@@ -357,6 +375,11 @@ impl HttpServiceConfigBuilder {
         self
     }

+    pub fn with_tool_call_parser(mut self, tool_call_parser: Option<String>) -> Self {
+        self.tool_call_parser = Some(tool_call_parser);
+        self
+    }
+
     fn get_endpoints_router(
         state: Arc<State>,
         request_template: &Option<RequestTemplate>,
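
A quick sketch of the new State surface (editor's illustration, not in the PR), assuming it sits inside lib/llm where State and ModelManager are in scope:

    // Sketch only: exercises State::new and the tool_call_parser() getter
    // exactly as declared above.
    #[cfg(test)]
    mod tool_call_parser_state_sketch {
        use super::*;
        use std::sync::Arc;

        #[test]
        fn parser_name_round_trips() {
            let state = State::new(Arc::new(ModelManager::new()), Some("hermes".to_string()));
            assert_eq!(state.tool_call_parser().as_deref(), Some("hermes"));

            // When unset, `new` stores Some("") rather than None, so callers
            // observe an empty parser name instead of absence.
            let state = State::new(Arc::new(ModelManager::new()), None);
            assert_eq!(state.tool_call_parser().as_deref(), Some(""));
        }
    }

Storing the Option as-is (tool_call_parser, without the unwrap_or_else wrapper) would preserve the None case and let try_tool_call_parse_aggregate fall back to its default-parser path; whether that distinction matters depends on how the parser registry treats an empty name.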
14 changes: 14 additions & 0 deletions lib/llm/src/local_model.rs
@@ -59,6 +59,7 @@ pub struct LocalModelBuilder {
     extra_engine_args: Option<PathBuf>,
     runtime_config: ModelRuntimeConfig,
     user_data: Option<serde_json::Value>,
+    tool_call_parser: Option<String>,
 }

 impl Default for LocalModelBuilder {
@@ -81,6 +82,7 @@ impl Default for LocalModelBuilder {
             extra_engine_args: Default::default(),
             runtime_config: Default::default(),
             user_data: Default::default(),
+            tool_call_parser: Default::default(),
         }
     }
 }
@@ -172,6 +174,11 @@ impl LocalModelBuilder {
         self
     }

+    pub fn tool_call_parser(&mut self, tool_call_parser: Option<String>) -> &mut Self {
+        self.tool_call_parser = tool_call_parser;
+        self
+    }

     /// Make an LLM ready for use:
     /// - Download it from Hugging Face (and NGC in future) if necessary
     /// - Resolve the path
@@ -213,6 +220,7 @@ impl LocalModelBuilder {
             tls_key_path: self.tls_key_path.take(),
             router_config: self.router_config.take().unwrap_or_default(),
             runtime_config: self.runtime_config.clone(),
+            tool_call_parser: self.tool_call_parser.take(),
         });
     }

@@ -289,6 +297,7 @@ impl LocalModelBuilder {
             tls_key_path: self.tls_key_path.take(),
             router_config: self.router_config.take().unwrap_or_default(),
             runtime_config: self.runtime_config.clone(),
+            tool_call_parser: self.tool_call_parser.take(),
         })
     }
 }
@@ -305,6 +314,7 @@ pub struct LocalModel {
     tls_key_path: Option<PathBuf>,
     router_config: RouterConfig,
     runtime_config: ModelRuntimeConfig,
+    tool_call_parser: Option<String>,
 }

 impl LocalModel {
@@ -372,6 +382,10 @@ impl LocalModel {
         self.card
     }

+    pub fn tool_call_parser(&self) -> Option<String> {
+        self.tool_call_parser.clone()
+    }
+
     /// Attach this model to the endpoint. This registers it on the network,
     /// allowing ingress to discover it.
     pub async fn attach(
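
A minimal builder-flow sketch (editor's illustration; the async context and the anyhow::Result return type are assumptions, and build() may download the model per the doc comment above):

    // Sketch: a parser name set on the builder surfaces on the built
    // LocalModel via the getter that http.rs consumes.
    async fn build_with_parser() -> anyhow::Result<()> {
        let mut builder = LocalModelBuilder::default();
        builder.tool_call_parser(Some("llama3_json".to_string()));

        let model = builder.build().await?;
        assert_eq!(model.tool_call_parser().as_deref(), Some("llama3_json"));
        Ok(())
    }

Note that build() takes the field (self.tool_call_parser.take()), so the name moves into the LocalModel rather than being cloned.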
23 changes: 15 additions & 8 deletions lib/llm/src/protocols/openai/chat_completions/aggregator.rs
@@ -96,6 +96,7 @@ impl DeltaAggregator {
     /// * `Err(String)` if an error occurs during processing.
     pub async fn apply(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        tool_call_parser: Option<String>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -164,7 +165,9 @@ impl DeltaAggregator {
         // After aggregation, inspect each choice's text for tool call syntax
         for choice in aggregator.choices.values_mut() {
             if choice.tool_calls.is_none() {
-                if let Ok(tool_calls) = try_tool_call_parse_aggregate(&choice.text, None) {
+                if let Ok(tool_calls) =
+                    try_tool_call_parse_aggregate(&choice.text, tool_call_parser.as_deref())
+                {
                     if tool_calls.is_empty() {
                         continue;
                     }
@@ -251,6 +254,7 @@ pub trait ChatCompletionAggregator {
     /// * `Err(String)` if an error occurs.
     async fn from_annotated_stream(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        tool_call_parser: Option<String>,
     ) -> Result<NvCreateChatCompletionResponse, String>;

     /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
@@ -263,21 +267,24 @@ pub trait ChatCompletionAggregator {
     /// * `Err(String)` if an error occurs.
     async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
+        tool_call_parser: Option<String>,
     ) -> Result<NvCreateChatCompletionResponse, String>;
 }

 impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
     async fn from_annotated_stream(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        tool_call_parser: Option<String>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
-        DeltaAggregator::apply(stream).await
+        DeltaAggregator::apply(stream, tool_call_parser).await
     }

     async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
+        tool_call_parser: Option<String>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
-        NvCreateChatCompletionResponse::from_annotated_stream(stream).await
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser).await
     }
 }

@@ -336,7 +343,7 @@ mod tests {
             Box::pin(stream::empty());

         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, None).await;

         // Check the result
         assert!(result.is_ok());
@@ -366,7 +373,7 @@ mod tests {
         let stream = Box::pin(stream::iter(vec![annotated_delta]));

         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, None).await;

         // Check the result
         assert!(result.is_ok());
@@ -410,7 +417,7 @@ mod tests {
         let stream = Box::pin(stream::iter(annotated_deltas));

         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, None).await;

         // Check the result
         assert!(result.is_ok());
@@ -481,7 +488,7 @@ mod tests {
         let stream = Box::pin(stream::iter(vec![annotated_delta]));

         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, None).await;

         // Check the result
         assert!(result.is_ok());
@@ -539,7 +546,7 @@ mod tests {
         let stream = Box::pin(stream::iter(vec![annotated_delta]));

         // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, Some("default".to_string())).await;

         // Check the result
         assert!(result.is_ok());
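
For orientation, a caller-side sketch (editor's illustration, mirroring the tests above; "hermes" stands in for any supported parser name, and annotated_delta would be constructed as in the surrounding test setup):

    // Sketch: fold an annotated stream with an explicit tool-call parser.
    let stream = Box::pin(stream::iter(vec![annotated_delta]));
    let response = DeltaAggregator::apply(stream, Some("hermes".to_string()))
        .await
        .expect("aggregation should succeed");

    // Passing None keeps the pre-PR behavior: try_tool_call_parse_aggregate
    // runs with its default parser selection.

The same parameter reaches the SSE path through from_sse_stream, which now threads it into from_annotated_stream.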