Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/goose/src/providers/errors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use reqwest::StatusCode;
use thiserror::Error;

#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
pub enum ProviderError {
#[error("Authentication error: {0}")]
Authentication(String),
Expand Down
305 changes: 256 additions & 49 deletions crates/goose/src/providers/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use regex::Regex;
use reqwest::{Response, StatusCode};
use rmcp::model::{AnnotateAble, ImageContent, RawImageContent};
use serde::{Deserialize, Serialize};
use serde_json::{from_value, json, Map, Value};
use serde_json::{json, Map, Value};
use std::io::Read;
use std::path::Path;

Expand Down Expand Up @@ -71,10 +71,13 @@ pub fn map_http_error_to_provider_error(
let error = match status {
StatusCode::OK => unreachable!("Should not call this function with OK status"),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
ProviderError::Authentication(format!(
let message = format!(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload
))
Status: {}{}",
status,
payload.as_ref().map(|p| format!(". Response: {:?}", p)).unwrap_or_default()
);
ProviderError::Authentication(message)
}
StatusCode::BAD_REQUEST => {
let mut error_msg = "Unknown error".to_string();
Expand All @@ -84,26 +87,27 @@ pub fn map_http_error_to_provider_error(
ProviderError::ContextLengthExceeded(payload_str)
} else {
if let Some(error) = payload.get("error") {
error_msg = error.get("message")
error_msg = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error")
.to_string();
}
ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))
ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
))
}
} else {
ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))
ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
))
}
}
StatusCode::TOO_MANY_REQUESTS => {
ProviderError::RateLimitExceeded(format!("{:?}", payload))
}
_ if status.is_server_error() => {
ProviderError::ServerError(format!("{:?}", payload))
}
_ => {
ProviderError::RequestFailed(format!("Request failed with status: {}", status))
}
StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded(format!("{:?}", payload)),
_ if status.is_server_error() => ProviderError::ServerError(format!("{:?}", payload)),
_ => ProviderError::RequestFailed(format!("Request failed with status: {}", status)),
};

if !status.is_success() {
Expand All @@ -118,45 +122,51 @@ pub fn map_http_error_to_provider_error(
error
}

/// Handle response from OpenAI compatible endpoints
/// Error codes: https://platform.openai.com/docs/guides/error-codes
/// Context window exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543
/// Handles HTTP responses from OpenAI-compatible endpoints.
///
/// Returns the response if status is OK; otherwise, reads the body and maps to a `ProviderError`,
/// with special handling for context length exceeded and other OpenAI-formatted errors.
///
/// ### References
/// - Error Codes: https://platform.openai.com/docs/guides/error-codes
/// - Context Window Exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543
///
/// ### Arguments
/// - `response`: The HTTP response to process.
///
/// ### Returns
/// - `Ok(Response)`: The original response on success.
/// - `Err(ProviderError)`: Describes the failure reason.```
pub async fn handle_status_openai_compat(response: Response) -> Result<Response, ProviderError> {
let status = response.status();
if status == StatusCode::OK {
return Ok(response);
}

match status {
StatusCode::OK => Ok(response),
_ => {
let body = response.json::<Value>().await;
match body {
Err(e) => Err(ProviderError::RequestFailed(e.to_string())),
Ok(body) => {
let error = if matches!(status, StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND)
{
if let Ok(err_resp) = from_value::<OpenAIErrorResponse>(body.clone()) {
let err = err_resp.error;
if err.is_context_length_exceeded() {
ProviderError::ContextLengthExceeded(
err.message.unwrap_or("Unknown error".to_string()),
)
} else {
ProviderError::RequestFailed(format!(
"{} (status {})",
err,
status.as_u16()
))
}
} else {
map_http_error_to_provider_error(status, Some(body))
}
} else {
map_http_error_to_provider_error(status, Some(body))
};
Err(error)
}
let body_str = response
.text()
.await
.map_err(|_| map_http_error_to_provider_error(status, None))?;

if matches!(status, StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND) {
if let Ok(err_resp) = serde_json::from_str::<OpenAIErrorResponse>(&body_str) {
let err = err_resp.error;
if err.is_context_length_exceeded() {
return Err(ProviderError::ContextLengthExceeded(
err.message.unwrap_or("Unknown error".to_string()),
));
} else {
return Err(ProviderError::RequestFailed(format!(
"{} (status {})",
err,
status.as_u16()
)));
}
}
}

let payload = serde_json::from_str::<Value>(&body_str).ok();
Err(map_http_error_to_provider_error(status, payload))
}

pub async fn handle_response_openai_compat(response: Response) -> Result<Value, ProviderError> {
Expand Down Expand Up @@ -486,6 +496,7 @@ pub fn json_escape_control_chars_in_string(s: &str) -> String {
mod tests {
use super::*;
use serde_json::json;
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};

#[test]
fn test_detect_image_path() {
Expand Down Expand Up @@ -792,4 +803,200 @@ mod tests {
"Hello\\u0001World"
);
}
#[tokio::test]
async fn test_handle_status_openai_compat() {
let test_cases = vec![
// (status_code, body, expected_result)
// Success case - 200 OK returns response as-is
(
200,
Some(json!({
"choices": [{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "Hi there! How can I help you today?",
"role": "assistant"
}
}],
"created": 1755133833,
"id": "chatcmpl-test",
"model": "gpt-5-nano",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 8,
"total_tokens": 18
}
})),
Ok(()),
),
// 400 Bad Request with OpenAI-formatted error (directly handled)
(
400,
Some(json!({
"error": {
"code": "unsupported_parameter",
"message": "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead.",
"param": "max_tokens",
"type": "invalid_request_error"
}
})),
Err(ProviderError::RequestFailed(
"Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead. (code: unsupported_parameter, type: invalid_request_error) (status 400)".to_string(),
)),
),
// 400 with context_length_exceeded in OpenAI format (directly handled)
(
400,
Some(json!({
"error": {
"code": "context_length_exceeded",
"message": "This model's maximum context length is 4096 tokens.",
"type": "invalid_request_error"
}
})),
Err(ProviderError::ContextLengthExceeded(
"This model's maximum context length is 4096 tokens.".to_string(),
)),
),
// 404 Not Found with OpenAI-formatted error (directly handled like 400)
(
404,
Some(json!({
"error": {
"code": "model_not_found",
"message": "The model 'gpt-5' does not exist",
"type": "invalid_request_error"
}
})),
Err(ProviderError::RequestFailed(
"The model 'gpt-5' does not exist (code: model_not_found, type: invalid_request_error) (status 404)".to_string(),
)),
),
// Non-JSON body error (tests parse failure path)
(
413,
Some(Value::String("Payload Too Large".to_string())),
Err(ProviderError::RequestFailed(
"Request failed with status: 413 Payload Too Large".to_string(),
)),
),
];

for (status_code, body, expected_result) in test_cases {
let mock_server = MockServer::start().await;

let mut response_template = ResponseTemplate::new(status_code);

// Set body based on test case
if let Some(body_value) = body {
if body_value.is_string() {
// For non-JSON bodies (like "Payload Too Large")
response_template =
response_template.set_body_string(body_value.as_str().unwrap().to_string());
} else {
// For JSON bodies
response_template = response_template.set_body_json(&body_value);
}
}

Mock::given(matchers::method("GET"))
.and(matchers::path("/test"))
.respond_with(response_template)
.mount(&mock_server)
.await;

// Make request to mock server
let client = reqwest::Client::new();
let response = client
.get(&format!("{}/test", &mock_server.uri()))
.send()
.await
.unwrap();

// Test handle_status_openai_compat
let result = handle_status_openai_compat(response).await.map(|_| ());

assert_eq!(result, expected_result, "for status {}", status_code);
}
}

#[test]
fn test_map_http_error_to_provider_error() {
let test_cases = vec![
// UNAUTHORIZED/FORBIDDEN - with payload
(
StatusCode::UNAUTHORIZED,
Some(json!({"error": "auth failed"})),
ProviderError::Authentication(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 401 Unauthorized. Response: Object {\"error\": String(\"auth failed\")}".to_string(),
),
),
// UNAUTHORIZED/FORBIDDEN - without payload
(
StatusCode::FORBIDDEN,
None,
ProviderError::Authentication(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 403 Forbidden".to_string(),
),
),
// BAD_REQUEST - with context_length_exceeded detection
(
StatusCode::BAD_REQUEST,
Some(json!({"error": {"message": "context_length_exceeded"}})),
ProviderError::ContextLengthExceeded(
"{\"error\":{\"message\":\"context_length_exceeded\"}}".to_string(),
),
),
// BAD_REQUEST - with error.message extraction
(
StatusCode::BAD_REQUEST,
Some(json!({"error": {"message": "Custom error"}})),
ProviderError::RequestFailed(
"Request failed with status: 400 Bad Request. Message: Custom error".to_string(),
),
),
// BAD_REQUEST - without payload
(
StatusCode::BAD_REQUEST,
None,
ProviderError::RequestFailed(
"Request failed with status: 400 Bad Request. Message: Unknown error".to_string(),
),
),
// TOO_MANY_REQUESTS
(
StatusCode::TOO_MANY_REQUESTS,
Some(json!({"retry_after": 60})),
ProviderError::RateLimitExceeded(
"Some(Object {\"retry_after\": Number(60)})".to_string(),
),
),
// is_server_error() without payload
(
StatusCode::INTERNAL_SERVER_ERROR,
None,
ProviderError::ServerError("None".to_string()),
),
// is_server_error() with payload
(
StatusCode::BAD_GATEWAY,
Some(json!({"error": "upstream error"})),
ProviderError::ServerError("Some(Object {\"error\": String(\"upstream error\")})".to_string()),
),
// Default - any other status code
(
StatusCode::IM_A_TEAPOT,
Some(json!({"ignored": "payload"})),
ProviderError::RequestFailed(
"Request failed with status: 418 I'm a teapot".to_string(),
),
),
];

for (status, payload, expected_error) in test_cases {
let result = map_http_error_to_provider_error(status, payload);
assert_eq!(result, expected_error);
}
}
}
Loading