diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 726087aefe75..e4673415dd3d 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -20,7 +20,7 @@ use goose::config::{ use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::provider_test::test_provider_configuration; -use goose::providers::{create, providers}; +use goose::providers::{create, providers, retry_operation, RetryConfig}; use goose::session::{SessionManager, SessionType}; use serde_json::Value; use std::collections::HashMap; @@ -570,13 +570,15 @@ pub async fn configure_provider_dialog() -> anyhow::Result { } } - // Attempt to fetch supported models for this provider let spin = spinner(); spin.start("Attempting to fetch supported models..."); let models_res = { let temp_model_config = ModelConfig::new(&provider_meta.default_model)?; let temp_provider = create(provider_name, temp_model_config).await?; - temp_provider.fetch_recommended_models().await + retry_operation(&RetryConfig::default(), || async { + temp_provider.fetch_recommended_models().await + }) + .await }; spin.stop(style("Model fetch complete").green()); diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 83f94aba37de..409bf4d0d97a 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -15,7 +15,9 @@ use goose::providers::auto_detect::detect_provider_from_api_key; use goose::providers::base::{ProviderMetadata, ProviderType}; use goose::providers::canonical::maybe_get_canonical_model; use goose::providers::create_with_default_model; +use goose::providers::errors::ProviderError; use goose::providers::providers as get_providers; +use goose::providers::{retry_operation, RetryConfig}; use goose::{ agents::execute_commands, agents::ExtensionConfig, config::permission::PermissionLevel, slash_commands, @@ -399,13 +401,15 @@ pub async fn get_provider_models( .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let models_result = provider.fetch_recommended_models().await; + let models_result = retry_operation(&RetryConfig::default(), || async { + provider.fetch_recommended_models().await + }) + .await; match models_result { Ok(Some(models)) => Ok(Json(models)), Ok(None) => Ok(Json(Vec::new())), Err(provider_error) => { - use goose::providers::errors::ProviderError; let status_code = match provider_error { // Permanent misconfigurations - client should fix configuration ProviderError::Authentication(_) => StatusCode::BAD_REQUEST, diff --git a/crates/goose/src/providers/auto_detect.rs b/crates/goose/src/providers/auto_detect.rs index 71a60f6b151e..0513fd928be7 100644 --- a/crates/goose/src/providers/auto_detect.rs +++ b/crates/goose/src/providers/auto_detect.rs @@ -1,4 +1,5 @@ use crate::model::ModelConfig; +use crate::providers::retry::{retry_operation, RetryConfig}; pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec)> { let provider_tests = vec![ @@ -24,10 +25,16 @@ pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec< ) .await { - Ok(provider) => match provider.fetch_supported_models().await { - Ok(Some(models)) => Some((provider_name.to_string(), models)), - _ => None, - }, + Ok(provider) => { + match retry_operation(&RetryConfig::default(), || async { + provider.fetch_supported_models().await + }) + .await + { + Ok(Some(models)) => Some((provider_name.to_string(), models)), + _ => None, + } + } Err(_) => None, }; diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 4f83f683c2df..7a9397c92259 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -42,3 +42,4 @@ pub mod xai; pub use factory::{ create, create_with_default_model, create_with_named_model, providers, refresh_custom_providers, }; +pub use retry::{retry_operation, RetryConfig}; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index f2cd11a51541..f6fe3042bc03 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -320,30 +320,19 @@ impl Provider for OpenAiProvider { async fn fetch_supported_models(&self) -> Result>, ProviderError> { let models_path = self.base_path.replace("v1/chat/completions", "v1/models"); - let response = self - .with_retry(|| async { - let response = self.api_client.response_get(&models_path).await?; - let json = handle_response_openai_compat(response).await?; - if let Some(err_obj) = json.get("error") { - let msg = err_obj - .get("message") - .and_then(|v| v.as_str()) - .unwrap_or("unknown error"); - return Err(ProviderError::Authentication(msg.to_string())); - } - Ok(json) - }) - .await - .inspect_err(|e| { - tracing::warn!("Failed to fetch supported models from OpenAI: {:?}", e); - })?; - - let data = response - .get("data") - .and_then(|v| v.as_array()) - .ok_or_else(|| { - ProviderError::UsageError("Missing data field in JSON response".into()) - })?; + let response = self.api_client.response_get(&models_path).await?; + let json = handle_response_openai_compat(response).await?; + if let Some(err_obj) = json.get("error") { + let msg = err_obj + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + return Err(ProviderError::Authentication(msg.to_string())); + } + + let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| { + ProviderError::UsageError("Missing data field in JSON response".into()) + })?; let mut models: Vec = data .iter() .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string)) diff --git a/crates/goose/src/providers/retry.rs b/crates/goose/src/providers/retry.rs index 4c8c4975aec7..1adf8f9a95c0 100644 --- a/crates/goose/src/providers/retry.rs +++ b/crates/goose/src/providers/retry.rs @@ -48,6 +48,10 @@ impl RetryConfig { } } + pub fn max_retries(&self) -> usize { + self.max_retries + } + pub fn delay_for_attempt(&self, attempt: usize) -> Duration { if attempt == 0 { return Duration::from_millis(0); @@ -67,6 +71,56 @@ impl RetryConfig { } } +pub fn should_retry(error: &ProviderError) -> bool { + matches!( + error, + ProviderError::RateLimitExceeded { .. } + | ProviderError::ServerError(_) + | ProviderError::RequestFailed(_) + ) +} + +pub async fn retry_operation( + config: &RetryConfig, + operation: F, +) -> Result +where + F: Fn() -> Fut + Send, + Fut: Future> + Send, + T: Send, +{ + let mut attempts = 0; + + loop { + match operation().await { + Ok(result) => return Ok(result), + Err(error) => { + if should_retry(&error) && attempts < config.max_retries { + attempts += 1; + tracing::warn!( + "Request failed, retrying ({}/{}): {:?}", + attempts, + config.max_retries, + error + ); + + let delay = match &error { + ProviderError::RateLimitExceeded { + retry_delay: Some(d), + .. + } => *d, + _ => config.delay_for_attempt(attempts), + }; + + sleep(delay).await; + continue; + } + return Err(error); + } + } + } +} + /// Trait for retry functionality to keep Provider dyn-compatible #[async_trait] pub trait ProviderRetry { @@ -87,12 +141,7 @@ pub trait ProviderRetry { return match operation().await { Ok(result) => Ok(result), Err(error) => { - let should_retry = matches!( - error, - ProviderError::RateLimitExceeded { .. } | ProviderError::ServerError(_) - ); - - if should_retry && attempts < config.max_retries { + if should_retry(&error) && attempts < config.max_retries { attempts += 1; tracing::warn!( "Request failed, retrying ({}/{}): {:?}", @@ -130,7 +179,6 @@ pub trait ProviderRetry { } } -// Let specific providers define their retry config if desired impl ProviderRetry for P { fn retry_config(&self) -> RetryConfig { Provider::retry_config(self)