diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 51a31c06b957..44a4f958b8dd 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -249,14 +249,10 @@ impl Provider for AzureProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index e074d9393af8..99768c1972ae 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -13,13 +13,13 @@ use tokio_util::io::StreamReader; use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::embedding::EmbeddingCapable; use super::errors::ProviderError; -use super::formats::databricks::{create_request, get_usage, response_to_message}; +use super::formats::databricks::{create_request, response_to_message}; use super::oauth; use super::utils::{get_model, ImageFormat}; use crate::config::ConfigError; use crate::message::Message; use crate::model::ModelConfig; -use crate::providers::formats::databricks::response_to_streaming_message; +use crate::providers::formats::openai::{get_usage, response_to_streaming_message}; use mcp_core::tool::Tool; use serde_json::json; use tokio::time::sleep; @@ -455,13 +455,10 @@ impl Provider for DatabricksProvider { // Parse response let message = response_to_message(response.clone())?; - let usage = match response.get("usage").map(get_usage) { - Some(usage) => usage, - None => { - tracing::debug!("Failed to get usage data"); - Usage::default() - } - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index 8c462b624e9b..e47bd3e82ea8 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -1,13 +1,10 @@ use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use crate::providers::base::{ProviderUsage, Usage}; use crate::providers::utils::{ convert_image, detect_image_path, is_valid_function_name, load_image_file, sanitize_function_name, ImageFormat, }; use anyhow::{anyhow, Error}; -use async_stream::try_stream; -use futures::Stream; use mcp_core::ToolError; use mcp_core::{Content, Role, Tool, ToolCall}; use serde::{Deserialize, Serialize}; @@ -404,140 +401,6 @@ struct StreamingChunk { model: String, } -fn strip_data_prefix(line: &str) -> Option<&str> { - line.strip_prefix("data: ").map(|s| s.trim()) -} - -pub fn response_to_streaming_message( - mut stream: S, -) -> impl Stream, Option)>> + 'static -where - S: Stream> + Unpin + Send + 'static, -{ - try_stream! { - use futures::StreamExt; - - 'outer: while let Some(response) = stream.next().await { - if response.as_ref().is_ok_and(|s| s == "data: [DONE]") { - break 'outer; - } - let response_str = response?; - let line = strip_data_prefix(&response_str); - - if line.is_none() || line.is_some_and(|l| l.is_empty()) { - continue - } - - let chunk: StreamingChunk = serde_json::from_str(line - .ok_or_else(|| anyhow!("unexpected stream format"))?) - .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; - let model = chunk.model.clone(); - - let usage = chunk.usage.as_ref().map(|u| { - ProviderUsage { - usage: get_usage(u), - model, - } - }); - - if chunk.choices.is_empty() { - yield (None, usage) - } else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls { - let tool_call = &tool_calls[0]; - let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?; - let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?; - let mut arguments = tool_call.function.arguments.clone(); - - while let Some(response_chunk) = stream.next().await { - if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") { - break 'outer; - } - let response_str = response_chunk?; - if let Some(line) = strip_data_prefix(&response_str) { - let tool_chunk: StreamingChunk = serde_json::from_str(line) - .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; - let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref() - .and_then(|calls| calls.first()) - .map(|call| call.function.arguments.as_str()); - if let Some(more_args) = more_args { - arguments.push_str(more_args); - } else { - break; - } - } - } - - let parsed = if arguments.is_empty() { - Ok(json!({})) - } else { - serde_json::from_str::(&arguments) - }; - - let content = match parsed { - Ok(params) => MessageContent::tool_request( - id, - Ok(ToolCall::new(function_name, params)), - ), - Err(e) => { - let error = ToolError::InvalidParameters(format!( - "Could not interpret tool use parameters for id {}: {}", - id, e - )); - MessageContent::tool_request(id, Err(error)) - } - }; - - yield ( - Some(Message { - id: chunk.id, - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: vec![content], - }), - usage, - ) - } else if let Some(text) = &chunk.choices[0].delta.content { - yield ( - Some(Message { - id: chunk.id, - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: vec![MessageContent::text(text)], - }), - if chunk.choices[0].finish_reason.is_some() { - usage - } else { - None - }, - ) - } - } - } -} - -pub fn get_usage(usage: &Value) -> Usage { - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Usage::new(input_tokens, output_tokens, total_tokens) -} - /// Validates and fixes tool schemas to ensure they have proper parameter structure. /// If parameters exist, ensures they have properties and required fields, or removes parameters entirely. pub fn validate_tool_schemas(tools: &mut [Value]) { diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index ce929253405c..7660afe12748 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -1,16 +1,55 @@ use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use crate::providers::base::Usage; -use crate::providers::errors::ProviderError; +use crate::providers::base::{ProviderUsage, Usage}; use crate::providers::utils::{ convert_image, detect_image_path, is_valid_function_name, load_image_file, sanitize_function_name, ImageFormat, }; use anyhow::{anyhow, Error}; +use async_stream::try_stream; +use futures::Stream; use mcp_core::ToolError; use mcp_core::{Content, Role, Tool, ToolCall}; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +#[derive(Serialize, Deserialize, Debug)] +struct DeltaToolCallFunction { + name: Option, + arguments: String, // chunk of encoded JSON, +} + +#[derive(Serialize, Deserialize, Debug)] +struct DeltaToolCall { + id: Option, + function: DeltaToolCallFunction, + index: Option, + r#type: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +struct Delta { + content: Option, + role: Option, + tool_calls: Option>, +} + +#[derive(Serialize, Deserialize, Debug)] +struct StreamingChoice { + delta: Delta, + index: Option, + finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +struct StreamingChunk { + choices: Vec, + created: Option, + id: Option, + usage: Option, + model: String, +} + /// Convert internal Message format to OpenAI's API message specification /// some openai compatible endpoints use the anthropic image spec at the content level /// even though the message structure is otherwise following openai, the enum switches this @@ -281,11 +320,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result { )) } -pub fn get_usage(data: &Value) -> Result { - let usage = data - .get("usage") - .ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?; - +pub fn get_usage(usage: &Value) -> Usage { let input_tokens = usage .get("prompt_tokens") .and_then(|v| v.as_i64()) @@ -305,7 +340,7 @@ pub fn get_usage(data: &Value) -> Result { _ => None, }); - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + Usage::new(input_tokens, output_tokens, total_tokens) } /// Validates and fixes tool schemas to ensure they have proper parameter structure. @@ -354,6 +389,117 @@ fn ensure_valid_json_schema(schema: &mut Value) { } } +fn strip_data_prefix(line: &str) -> Option<&str> { + line.strip_prefix("data: ").map(|s| s.trim()) +} + +pub fn response_to_streaming_message( + mut stream: S, +) -> impl Stream, Option)>> + 'static +where + S: Stream> + Unpin + Send + 'static, +{ + try_stream! { + use futures::StreamExt; + + 'outer: while let Some(response) = stream.next().await { + if response.as_ref().is_ok_and(|s| s == "data: [DONE]") { + break 'outer; + } + let response_str = response?; + let line = strip_data_prefix(&response_str); + + if line.is_none() || line.is_some_and(|l| l.is_empty()) { + continue + } + + let chunk: StreamingChunk = serde_json::from_str(line + .ok_or_else(|| anyhow!("unexpected stream format"))?) + .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + let model = chunk.model.clone(); + + let usage = chunk.usage.as_ref().map(|u| { + ProviderUsage { + usage: get_usage(u), + model, + } + }); + + if chunk.choices.is_empty() { + yield (None, usage) + } else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls { + let tool_call = &tool_calls[0]; + let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?; + let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?; + let mut arguments = tool_call.function.arguments.clone(); + + while let Some(response_chunk) = stream.next().await { + if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") { + break 'outer; + } + let response_str = response_chunk?; + if let Some(line) = strip_data_prefix(&response_str) { + let tool_chunk: StreamingChunk = serde_json::from_str(line) + .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref() + .and_then(|calls| calls.first()) + .map(|call| call.function.arguments.as_str()); + if let Some(more_args) = more_args { + arguments.push_str(more_args); + } else { + break; + } + } + } + + let parsed = if arguments.is_empty() { + Ok(json!({})) + } else { + serde_json::from_str::(&arguments) + }; + + let content = match parsed { + Ok(params) => MessageContent::tool_request( + id, + Ok(ToolCall::new(function_name, params)), + ), + Err(e) => { + let error = ToolError::InvalidParameters(format!( + "Could not interpret tool use parameters for id {}: {}", + id, e + )); + MessageContent::tool_request(id, Err(error)) + } + }; + + yield ( + Some(Message { + id: chunk.id, + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![content], + }), + usage, + ) + } else if let Some(text) = &chunk.choices[0].delta.content { + yield ( + Some(Message { + id: chunk.id, + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::text(text)], + }), + if chunk.choices[0].finish_reason.is_some() { + usage + } else { + None + }, + ) + } + } + } +} + pub fn create_request( model_config: &ModelConfig, system: &str, diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 97bd3ad589e0..ef7a9fbecc43 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -415,14 +415,10 @@ impl Provider for GithubCopilotProvider { // Parse response let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 3716df0e6dc3..9c8c5af9fecf 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -139,14 +139,10 @@ impl Provider for GroqProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 4bbf1c392dae..bd18d593adcc 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -141,14 +141,10 @@ impl Provider for OllamaProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 9884d147bffc..e4215ea85157 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,9 +1,16 @@ use anyhow::Result; +use async_stream::try_stream; use async_trait::async_trait; -use reqwest::Client; +use futures::TryStreamExt; +use reqwest::{Client, Response}; use serde_json::Value; use std::collections::HashMap; +use std::io; use std::time::Duration; +use tokio::pin; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, LinesCodec}; +use tokio_util::io::StreamReader; use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; @@ -12,6 +19,9 @@ use super::formats::openai::{create_request, get_usage, response_to_message}; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::message::Message; use crate::model::ModelConfig; +use crate::providers::base::MessageStream; +use crate::providers::formats::openai::response_to_streaming_message; +use crate::providers::utils::handle_status_openai_compat; use mcp_core::tool::Tool; pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o"; @@ -103,7 +113,7 @@ impl OpenAiProvider { request } - async fn post(&self, payload: Value) -> Result { + async fn post(&self, payload: Value) -> Result { let base_url = url::Url::parse(&self.host) .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; let url = base_url.join(&self.base_path).map_err(|e| { @@ -117,9 +127,7 @@ impl OpenAiProvider { let request = self.add_headers(request); - let response = request.json(&payload).send().await?; - - handle_response_openai_compat(response).await + Ok(request.json(&payload).send().await?) } } @@ -170,18 +178,14 @@ impl Provider for OpenAiProvider { let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; // Make request - let response = self.post(payload.clone()).await?; + let response = handle_response_openai_compat(self.post(payload.clone()).await?).await?; // Parse response let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) @@ -236,6 +240,40 @@ impl Provider for OpenAiProvider { .await .map_err(|e| ProviderError::ExecutionError(e.to_string())) } + + fn supports_streaming(&self) -> bool { + true + } + + async fn stream( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let mut payload = + create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; + payload["stream"] = serde_json::Value::Bool(true); + + let response = handle_status_openai_compat(self.post(payload.clone()).await?).await?; + + let stream = response.bytes_stream().map_err(io::Error::other); + + let model_config = self.model.clone(); + // Wrap in a line decoder and yield lines inside the stream + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = response_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = message_stream.next().await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + yield (message, usage); + } + })) + } } fn parse_custom_headers(s: String) -> HashMap { diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 0352012fc462..d629c37b014f 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -266,14 +266,10 @@ impl Provider for OpenRouterProvider { // Parse response let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 7bcc172b6458..9c4ec8c6cb38 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -47,48 +47,57 @@ pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value /// Handle response from OpenAI compatible endpoints /// Error codes: https://platform.openai.com/docs/guides/error-codes /// Context window exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543 -pub async fn handle_response_openai_compat(response: Response) -> Result { +pub async fn handle_status_openai_compat(response: Response) -> Result { let status = response.status(); - // Try to parse the response body as JSON (if applicable) - let payload = match response.json::().await { - Ok(json) => json, - Err(e) => return Err(ProviderError::RequestFailed(e.to_string())), - }; match status { - StatusCode::OK => Ok(payload), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload))) - } - StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - if let Ok(err_resp) = from_value::(payload) { - let err = err_resp.error; - if err.is_context_length_exceeded() { - return Err(ProviderError::ContextLengthExceeded(err.message.unwrap_or("Unknown error".to_string()))); + StatusCode::OK => Ok(response), + _ => { + let body = response.json::().await; + match (body, status) { + (Err(e), _) => Err(ProviderError::RequestFailed(e.to_string())), + (Ok(body), StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN) => { + Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ + Status: {}. Response: {:?}", status, body))) + } + (Ok(body), StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND) => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, body) + ); + if let Ok(err_resp) = from_value::(body) { + let err = err_resp.error; + if err.is_context_length_exceeded() { + return Err(ProviderError::ContextLengthExceeded(err.message.unwrap_or("Unknown error".to_string()))); + } + return Err(ProviderError::RequestFailed(format!("{} (status {})", err, status.as_u16()))); + } + Err(ProviderError::RequestFailed(format!("Unknown error (status {})", status))) + } + (Ok(body), StatusCode::TOO_MANY_REQUESTS) => { + Err(ProviderError::RateLimitExceeded(format!("{:?}", body))) + } + (Ok(body), StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE) => { + Err(ProviderError::ServerError(format!("{:?}", body))) + } + (Ok(body), _) => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, body) + ); + Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) } - return Err(ProviderError::RequestFailed(format!("{} (status {})", err, status.as_u16()))); } - Err(ProviderError::RequestFailed(format!("Unknown error (status {})", status))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } - _ => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) } } } +pub async fn handle_response_openai_compat(response: Response) -> Result { + let response = handle_status_openai_compat(response).await?; + + response.json::().await.map_err(|e| { + ProviderError::RequestFailed(format!("Response body is not valid JSON: {}", e)) + }) +} + /// Check if the model is a Google model based on the "model" field in the payload. /// /// ### Arguments diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 7e91a23f8b9e..cdaebdc0b1b8 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -166,14 +166,10 @@ impl Provider for XaiProvider { let response = self.post(payload.clone()).await?; let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); - Usage::default() - } - Err(e) => return Err(e), - }; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage)))