diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index fc977ca0b09a..e488272c1779 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -7,6 +7,7 @@ use crate::providers::utils::{ }; use anyhow::{anyhow, Error}; use async_stream::try_stream; +use chrono; use futures::Stream; use rmcp::model::{ object, AnnotateAble, CallToolRequestParam, Content, ErrorCode, ErrorData, RawContent, @@ -281,7 +282,18 @@ pub fn format_tools(tools: &[Tool]) -> anyhow::Result> { /// Convert OpenAI's API response to internal Message format pub fn response_to_message(response: &Value) -> anyhow::Result { - let original = &response["choices"][0]["message"]; + let Some(original) = response + .get("choices") + .and_then(|c| c.get(0)) + .and_then(|m| m.get("message")) + else { + return Ok(Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + Vec::new(), + )); + }; + let mut content = Vec::new(); if let Some(text) = original.get("content") { @@ -465,12 +477,14 @@ where if chunk.choices.is_empty() { yield (None, usage) - } else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls { + } else if chunk.choices[0].delta.tool_calls.as_ref().is_some_and(|tc| !tc.is_empty()) { let mut tool_call_data: std::collections::HashMap = std::collections::HashMap::new(); - for tool_call in tool_calls { - if let (Some(index), Some(id), Some(name)) = (tool_call.index, &tool_call.id, &tool_call.function.name) { - tool_call_data.insert(index, (id.clone(), name.clone(), tool_call.function.arguments.clone())); + if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls { + for tool_call in tool_calls { + if let (Some(index), Some(id), Some(name)) = (tool_call.index, &tool_call.id, &tool_call.function.name) { + tool_call_data.insert(index, (id.clone(), name.clone(), tool_call.function.arguments.clone())); + } } } @@ -489,21 +503,25 @@ where let tool_chunk: StreamingChunk = serde_json::from_str(line) .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; - if let Some(delta_tool_calls) = &tool_chunk.choices[0].delta.tool_calls { - for delta_call in delta_tool_calls { - if let Some(index) = delta_call.index { - if let Some((_, _, ref mut args)) = tool_call_data.get_mut(&index) { - args.push_str(&delta_call.function.arguments); - } else if let (Some(id), Some(name)) = (&delta_call.id, &delta_call.function.name) { - tool_call_data.insert(index, (id.clone(), name.clone(), delta_call.function.arguments.clone())); + if !tool_chunk.choices.is_empty() { + if let Some(delta_tool_calls) = &tool_chunk.choices[0].delta.tool_calls { + for delta_call in delta_tool_calls { + if let Some(index) = delta_call.index { + if let Some((_, _, ref mut args)) = tool_call_data.get_mut(&index) { + args.push_str(&delta_call.function.arguments); + } else if let (Some(id), Some(name)) = (&delta_call.id, &delta_call.function.name) { + tool_call_data.insert(index, (id.clone(), name.clone(), delta_call.function.arguments.clone())); + } } } + } else { + done = true; } - } else { - done = true; - } - if tool_chunk.choices[0].finish_reason == Some("tool_calls".to_string()) { + if tool_chunk.choices[0].finish_reason == Some("tool_calls".to_string()) { + done = true; + } + } else { done = true; } } @@ -563,7 +581,8 @@ where Some(msg), usage, ) - } else if let Some(text) = &chunk.choices[0].delta.content { + } else if chunk.choices[0].delta.content.is_some() { + let text = chunk.choices[0].delta.content.as_ref().unwrap(); let mut msg = Message::new( Role::Assistant, chrono::Utc::now().timestamp(),