@@ -19,7 +19,9 @@ use std::collections::HashMap;
1919use super :: { NvCreateChatCompletionResponse , NvCreateChatCompletionStreamResponse } ;
2020use crate :: protocols:: {
2121 codec:: { Message , SseCodecError } ,
22- convert_sse_stream, Annotated ,
22+ convert_sse_stream,
23+ openai:: ParsingOptions ,
24+ Annotated ,
2325} ;
2426
2527use dynamo_parsers:: tool_calling:: try_tool_call_parse_aggregate;
@@ -99,6 +101,7 @@ impl DeltaAggregator {
99101 /// * `Err(String)` if an error occurs during processing.
100102 pub async fn apply (
101103 stream : impl Stream < Item = Annotated < NvCreateChatCompletionStreamResponse > > ,
104+ parsing_options : ParsingOptions ,
102105 ) -> Result < NvCreateChatCompletionResponse , String > {
103106 let aggregator = stream
104107 . fold ( DeltaAggregator :: new ( ) , |mut aggregator, delta| async move {
@@ -175,7 +178,10 @@ impl DeltaAggregator {
175178 // After aggregation, inspect each choice's text for tool call syntax
176179 for choice in aggregator. choices . values_mut ( ) {
177180 if choice. tool_calls . is_none ( ) {
178- if let Ok ( tool_calls) = try_tool_call_parse_aggregate ( & choice. text , None ) {
181+ if let Ok ( tool_calls) = try_tool_call_parse_aggregate (
182+ & choice. text ,
183+ parsing_options. tool_call_parser . as_deref ( ) ,
184+ ) {
179185 if tool_calls. is_empty ( ) {
180186 continue ;
181187 }
@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
262268 /// * `Err(String)` if an error occurs.
263269 async fn from_annotated_stream (
264270 stream : impl Stream < Item = Annotated < NvCreateChatCompletionStreamResponse > > ,
271+ parsing_options : ParsingOptions ,
265272 ) -> Result < NvCreateChatCompletionResponse , String > ;
266273
267274 /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
274281 /// * `Err(String)` if an error occurs.
275282 async fn from_sse_stream (
276283 stream : DataStream < Result < Message , SseCodecError > > ,
284+ parsing_options : ParsingOptions ,
277285 ) -> Result < NvCreateChatCompletionResponse , String > ;
278286}
279287
280288impl ChatCompletionAggregator for dynamo_async_openai:: types:: CreateChatCompletionResponse {
281289 async fn from_annotated_stream (
282290 stream : impl Stream < Item = Annotated < NvCreateChatCompletionStreamResponse > > ,
291+ parsing_options : ParsingOptions ,
283292 ) -> Result < NvCreateChatCompletionResponse , String > {
284- DeltaAggregator :: apply ( stream) . await
293+ DeltaAggregator :: apply ( stream, parsing_options ) . await
285294 }
286295
287296 async fn from_sse_stream (
288297 stream : DataStream < Result < Message , SseCodecError > > ,
298+ parsing_options : ParsingOptions ,
289299 ) -> Result < NvCreateChatCompletionResponse , String > {
290300 let stream = convert_sse_stream :: < NvCreateChatCompletionStreamResponse > ( stream) ;
291- NvCreateChatCompletionResponse :: from_annotated_stream ( stream) . await
301+ NvCreateChatCompletionResponse :: from_annotated_stream ( stream, parsing_options ) . await
292302 }
293303}
294304
@@ -347,7 +357,7 @@ mod tests {
347357 Box :: pin ( stream:: empty ( ) ) ;
348358
349359 // Call DeltaAggregator::apply
350- let result = DeltaAggregator :: apply ( stream) . await ;
360+ let result = DeltaAggregator :: apply ( stream, ParsingOptions :: default ( ) ) . await ;
351361
352362 // Check the result
353363 assert ! ( result. is_ok( ) ) ;
@@ -377,7 +387,7 @@ mod tests {
377387 let stream = Box :: pin ( stream:: iter ( vec ! [ annotated_delta] ) ) ;
378388
379389 // Call DeltaAggregator::apply
380- let result = DeltaAggregator :: apply ( stream) . await ;
390+ let result = DeltaAggregator :: apply ( stream, ParsingOptions :: default ( ) ) . await ;
381391
382392 // Check the result
383393 assert ! ( result. is_ok( ) ) ;
@@ -421,7 +431,7 @@ mod tests {
421431 let stream = Box :: pin ( stream:: iter ( annotated_deltas) ) ;
422432
423433 // Call DeltaAggregator::apply
424- let result = DeltaAggregator :: apply ( stream) . await ;
434+ let result = DeltaAggregator :: apply ( stream, ParsingOptions :: default ( ) ) . await ;
425435
426436 // Check the result
427437 assert ! ( result. is_ok( ) ) ;
@@ -492,7 +502,7 @@ mod tests {
492502 let stream = Box :: pin ( stream:: iter ( vec ! [ annotated_delta] ) ) ;
493503
494504 // Call DeltaAggregator::apply
495- let result = DeltaAggregator :: apply ( stream) . await ;
505+ let result = DeltaAggregator :: apply ( stream, ParsingOptions :: default ( ) ) . await ;
496506
497507 // Check the result
498508 assert ! ( result. is_ok( ) ) ;
@@ -550,7 +560,7 @@ mod tests {
550560 let stream = Box :: pin ( stream:: iter ( vec ! [ annotated_delta] ) ) ;
551561
552562 // Call DeltaAggregator::apply
553- let result = DeltaAggregator :: apply ( stream) . await ;
563+ let result = DeltaAggregator :: apply ( stream, ParsingOptions :: default ( ) ) . await ;
554564
555565 // Check the result
556566 assert ! ( result. is_ok( ) ) ;
0 commit comments