From c6e965c0d00ec63c315ac67d2a6adb701e550ae4 Mon Sep 17 00:00:00 2001 From: sigoden Date: Fri, 29 Nov 2024 08:46:16 +0800 Subject: [PATCH] fix: openai-compatible stream function calling --- src/client/bedrock.rs | 4 ++-- src/client/claude.rs | 4 ++-- src/client/ernie.rs | 4 ++-- src/client/openai.rs | 25 +++++++++++++++---------- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index 78f9d2be..081a65b6 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -233,7 +233,7 @@ async fn chat_completions_streaming( if !function_name.is_empty() { let arguments: Value = function_arguments.parse().with_context(|| { - format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") })?; handler.tool_call(ToolCall::new( function_name.clone(), @@ -257,7 +257,7 @@ async fn chat_completions_streaming( "contentBlockStop" => { if !function_name.is_empty() { let arguments: Value = function_arguments.parse().with_context(|| { - format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") })?; handler.tool_call(ToolCall::new( function_name.clone(), diff --git a/src/client/claude.rs b/src/client/claude.rs index c7ef88a8..f982e14a 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -95,7 +95,7 @@ pub async fn claude_chat_completions_streaming( if !function_name.is_empty() { let arguments: Value = function_arguments.parse().with_context(|| { - format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") })?; handler.tool_call(ToolCall::new( function_name.clone(), @@ -124,7 +124,7 @@ pub async fn claude_chat_completions_streaming( json!({}) } else { function_arguments.parse().with_context(|| { - format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") })? }; handler.tool_call(ToolCall::new( diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 7e1f8e70..fe98ff5a 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -175,7 +175,7 @@ async fn chat_completions_streaming( function.get("arguments").and_then(|v| v.as_str()), ) { let arguments: Value = arguments.parse().with_context(|| { - format!("Tool call '{name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{name}' have non-JSON arguments '{arguments}'") })?; handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?; } @@ -292,7 +292,7 @@ fn extract_chat_completions_text(data: &Value) -> Result call.get("arguments").and_then(|v| v.as_str()), ) { let arguments: Value = arguments.parse().with_context(|| { - format!("Tool call '{name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{name}' have non-JSON arguments '{arguments}'") })?; tool_calls.push(ToolCall::new(name.to_string(), arguments, None)); } diff --git a/src/client/openai.rs b/src/client/openai.rs index 7f449fb8..253bf21a 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -101,7 +101,7 @@ pub async fn openai_chat_completions_streaming( handler: &mut SseHandler, _model: &Model, ) -> Result<()> { - let mut function_index = 0; + let mut call_id = String::new(); let mut function_name = String::new(); let mut function_arguments = String::new(); let mut function_id = String::new(); @@ -109,7 +109,7 @@ pub async fn openai_chat_completions_streaming( if message.data == "[DONE]" { if !function_name.is_empty() { let arguments: Value = function_arguments.parse().with_context(|| { - format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") })?; handler.tool_call(ToolCall::new( function_name.clone(), @@ -121,18 +121,23 @@ pub async fn openai_chat_completions_streaming( } let data: Value = serde_json::from_str(&message.data)?; debug!("stream-data: {data}"); - if let Some(text) = data["choices"][0]["delta"]["content"].as_str() { + if let Some(text) = data["choices"][0]["delta"]["content"] + .as_str() + .filter(|v| !v.is_empty()) + { handler.text(text)?; } else if let (Some(function), index, id) = ( data["choices"][0]["delta"]["tool_calls"][0]["function"].as_object(), data["choices"][0]["delta"]["tool_calls"][0]["index"].as_u64(), - data["choices"][0]["delta"]["tool_calls"][0]["id"].as_str(), + data["choices"][0]["delta"]["tool_calls"][0]["id"] + .as_str() + .filter(|v| !v.is_empty()), ) { - let index = index.unwrap_or_default(); - if index != function_index { + let maybe_call_id = format!("{}/{}", id.unwrap_or_default(), index.unwrap_or_default()); + if maybe_call_id != call_id && maybe_call_id.len() >= call_id.len() { if !function_name.is_empty() { let arguments: Value = function_arguments.parse().with_context(|| { - format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format") + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") })?; handler.tool_call(ToolCall::new( function_name.clone(), @@ -143,7 +148,7 @@ pub async fn openai_chat_completions_streaming( function_name.clear(); function_arguments.clear(); function_id.clear(); - function_index = index; + call_id = maybe_call_id; } if let Some(name) = function.get("name").and_then(|v| v.as_str()) { if name.starts_with(&function_name) { @@ -240,7 +245,7 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod vec![ json!({ "role": MessageRole::Assistant, - "content": null, + "content": "", "tool_calls": [ { "id": tool_result.call.id, @@ -319,7 +324,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result