diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index 0060f6a6f764..ff2772e0f40c 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -1,7 +1,7 @@ use reqwest::StatusCode; use thiserror::Error; -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] pub enum ProviderError { #[error("Authentication error: {0}")] Authentication(String), diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 667323f46edd..e5dec1b72a47 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -7,7 +7,7 @@ use regex::Regex; use reqwest::{Response, StatusCode}; use rmcp::model::{AnnotateAble, ImageContent, RawImageContent}; use serde::{Deserialize, Serialize}; -use serde_json::{from_value, json, Map, Value}; +use serde_json::{json, Map, Value}; use std::io::Read; use std::path::Path; @@ -71,10 +71,13 @@ pub fn map_http_error_to_provider_error( let error = match status { StatusCode::OK => unreachable!("Should not call this function with OK status"), StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - ProviderError::Authentication(format!( + let message = format!( "Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload - )) + Status: {}{}", + status, + payload.as_ref().map(|p| format!(". Response: {:?}", p)).unwrap_or_default() + ); + ProviderError::Authentication(message) } StatusCode::BAD_REQUEST => { let mut error_msg = "Unknown error".to_string(); @@ -84,26 +87,27 @@ pub fn map_http_error_to_provider_error( ProviderError::ContextLengthExceeded(payload_str) } else { if let Some(error) = payload.get("error") { - error_msg = error.get("message") + error_msg = error + .get("message") .and_then(|m| m.as_str()) .unwrap_or("Unknown error") .to_string(); } - ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)) + ProviderError::RequestFailed(format!( + "Request failed with status: {}. Message: {}", + status, error_msg + )) } } else { - ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)) + ProviderError::RequestFailed(format!( + "Request failed with status: {}. Message: {}", + status, error_msg + )) } } - StatusCode::TOO_MANY_REQUESTS => { - ProviderError::RateLimitExceeded(format!("{:?}", payload)) - } - _ if status.is_server_error() => { - ProviderError::ServerError(format!("{:?}", payload)) - } - _ => { - ProviderError::RequestFailed(format!("Request failed with status: {}", status)) - } + StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded(format!("{:?}", payload)), + _ if status.is_server_error() => ProviderError::ServerError(format!("{:?}", payload)), + _ => ProviderError::RequestFailed(format!("Request failed with status: {}", status)), }; if !status.is_success() { @@ -118,45 +122,51 @@ pub fn map_http_error_to_provider_error( error } -/// Handle response from OpenAI compatible endpoints -/// Error codes: https://platform.openai.com/docs/guides/error-codes -/// Context window exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543 +/// Handles HTTP responses from OpenAI-compatible endpoints. +/// +/// Returns the response if status is OK; otherwise, reads the body and maps to a `ProviderError`, +/// with special handling for context length exceeded and other OpenAI-formatted errors. +/// +/// ### References +/// - Error Codes: https://platform.openai.com/docs/guides/error-codes +/// - Context Window Exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543 +/// +/// ### Arguments +/// - `response`: The HTTP response to process. +/// +/// ### Returns +/// - `Ok(Response)`: The original response on success. +/// - `Err(ProviderError)`: Describes the failure reason.``` pub async fn handle_status_openai_compat(response: Response) -> Result { let status = response.status(); + if status == StatusCode::OK { + return Ok(response); + } - match status { - StatusCode::OK => Ok(response), - _ => { - let body = response.json::().await; - match body { - Err(e) => Err(ProviderError::RequestFailed(e.to_string())), - Ok(body) => { - let error = if matches!(status, StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND) - { - if let Ok(err_resp) = from_value::(body.clone()) { - let err = err_resp.error; - if err.is_context_length_exceeded() { - ProviderError::ContextLengthExceeded( - err.message.unwrap_or("Unknown error".to_string()), - ) - } else { - ProviderError::RequestFailed(format!( - "{} (status {})", - err, - status.as_u16() - )) - } - } else { - map_http_error_to_provider_error(status, Some(body)) - } - } else { - map_http_error_to_provider_error(status, Some(body)) - }; - Err(error) - } + let body_str = response + .text() + .await + .map_err(|_| map_http_error_to_provider_error(status, None))?; + + if matches!(status, StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND) { + if let Ok(err_resp) = serde_json::from_str::(&body_str) { + let err = err_resp.error; + if err.is_context_length_exceeded() { + return Err(ProviderError::ContextLengthExceeded( + err.message.unwrap_or("Unknown error".to_string()), + )); + } else { + return Err(ProviderError::RequestFailed(format!( + "{} (status {})", + err, + status.as_u16() + ))); } } } + + let payload = serde_json::from_str::(&body_str).ok(); + Err(map_http_error_to_provider_error(status, payload)) } pub async fn handle_response_openai_compat(response: Response) -> Result { @@ -486,6 +496,7 @@ pub fn json_escape_control_chars_in_string(s: &str) -> String { mod tests { use super::*; use serde_json::json; + use wiremock::{matchers, Mock, MockServer, ResponseTemplate}; #[test] fn test_detect_image_path() { @@ -792,4 +803,200 @@ mod tests { "Hello\\u0001World" ); } + #[tokio::test] + async fn test_handle_status_openai_compat() { + let test_cases = vec![ + // (status_code, body, expected_result) + // Success case - 200 OK returns response as-is + ( + 200, + Some(json!({ + "choices": [{ + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Hi there! How can I help you today?", + "role": "assistant" + } + }], + "created": 1755133833, + "id": "chatcmpl-test", + "model": "gpt-5-nano", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 8, + "total_tokens": 18 + } + })), + Ok(()), + ), + // 400 Bad Request with OpenAI-formatted error (directly handled) + ( + 400, + Some(json!({ + "error": { + "code": "unsupported_parameter", + "message": "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead.", + "param": "max_tokens", + "type": "invalid_request_error" + } + })), + Err(ProviderError::RequestFailed( + "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead. (code: unsupported_parameter, type: invalid_request_error) (status 400)".to_string(), + )), + ), + // 400 with context_length_exceeded in OpenAI format (directly handled) + ( + 400, + Some(json!({ + "error": { + "code": "context_length_exceeded", + "message": "This model's maximum context length is 4096 tokens.", + "type": "invalid_request_error" + } + })), + Err(ProviderError::ContextLengthExceeded( + "This model's maximum context length is 4096 tokens.".to_string(), + )), + ), + // 404 Not Found with OpenAI-formatted error (directly handled like 400) + ( + 404, + Some(json!({ + "error": { + "code": "model_not_found", + "message": "The model 'gpt-5' does not exist", + "type": "invalid_request_error" + } + })), + Err(ProviderError::RequestFailed( + "The model 'gpt-5' does not exist (code: model_not_found, type: invalid_request_error) (status 404)".to_string(), + )), + ), + // Non-JSON body error (tests parse failure path) + ( + 413, + Some(Value::String("Payload Too Large".to_string())), + Err(ProviderError::RequestFailed( + "Request failed with status: 413 Payload Too Large".to_string(), + )), + ), + ]; + + for (status_code, body, expected_result) in test_cases { + let mock_server = MockServer::start().await; + + let mut response_template = ResponseTemplate::new(status_code); + + // Set body based on test case + if let Some(body_value) = body { + if body_value.is_string() { + // For non-JSON bodies (like "Payload Too Large") + response_template = + response_template.set_body_string(body_value.as_str().unwrap().to_string()); + } else { + // For JSON bodies + response_template = response_template.set_body_json(&body_value); + } + } + + Mock::given(matchers::method("GET")) + .and(matchers::path("/test")) + .respond_with(response_template) + .mount(&mock_server) + .await; + + // Make request to mock server + let client = reqwest::Client::new(); + let response = client + .get(&format!("{}/test", &mock_server.uri())) + .send() + .await + .unwrap(); + + // Test handle_status_openai_compat + let result = handle_status_openai_compat(response).await.map(|_| ()); + + assert_eq!(result, expected_result, "for status {}", status_code); + } + } + + #[test] + fn test_map_http_error_to_provider_error() { + let test_cases = vec![ + // UNAUTHORIZED/FORBIDDEN - with payload + ( + StatusCode::UNAUTHORIZED, + Some(json!({"error": "auth failed"})), + ProviderError::Authentication( + "Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 401 Unauthorized. Response: Object {\"error\": String(\"auth failed\")}".to_string(), + ), + ), + // UNAUTHORIZED/FORBIDDEN - without payload + ( + StatusCode::FORBIDDEN, + None, + ProviderError::Authentication( + "Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 403 Forbidden".to_string(), + ), + ), + // BAD_REQUEST - with context_length_exceeded detection + ( + StatusCode::BAD_REQUEST, + Some(json!({"error": {"message": "context_length_exceeded"}})), + ProviderError::ContextLengthExceeded( + "{\"error\":{\"message\":\"context_length_exceeded\"}}".to_string(), + ), + ), + // BAD_REQUEST - with error.message extraction + ( + StatusCode::BAD_REQUEST, + Some(json!({"error": {"message": "Custom error"}})), + ProviderError::RequestFailed( + "Request failed with status: 400 Bad Request. Message: Custom error".to_string(), + ), + ), + // BAD_REQUEST - without payload + ( + StatusCode::BAD_REQUEST, + None, + ProviderError::RequestFailed( + "Request failed with status: 400 Bad Request. Message: Unknown error".to_string(), + ), + ), + // TOO_MANY_REQUESTS + ( + StatusCode::TOO_MANY_REQUESTS, + Some(json!({"retry_after": 60})), + ProviderError::RateLimitExceeded( + "Some(Object {\"retry_after\": Number(60)})".to_string(), + ), + ), + // is_server_error() without payload + ( + StatusCode::INTERNAL_SERVER_ERROR, + None, + ProviderError::ServerError("None".to_string()), + ), + // is_server_error() with payload + ( + StatusCode::BAD_GATEWAY, + Some(json!({"error": "upstream error"})), + ProviderError::ServerError("Some(Object {\"error\": String(\"upstream error\")})".to_string()), + ), + // Default - any other status code + ( + StatusCode::IM_A_TEAPOT, + Some(json!({"ignored": "payload"})), + ProviderError::RequestFailed( + "Request failed with status: 418 I'm a teapot".to_string(), + ), + ), + ]; + + for (status, payload, expected_error) in test_cases { + let result = map_http_error_to_provider_error(status, payload); + assert_eq!(result, expected_error); + } + } }