From b6352fe210c07fb28a37f53e72581330b031a07b Mon Sep 17 00:00:00 2001 From: Brandon Kvarda Date: Mon, 2 Jun 2025 21:08:54 -0700 Subject: [PATCH 1/4] Add exponential backoff for databricks provider --- crates/goose/src/providers/databricks.rs | 284 ++++++++++++++++++----- 1 file changed, 222 insertions(+), 62 deletions(-) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 7a04407a829c..11495b79c5b6 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -17,6 +17,7 @@ use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::time::Duration; +use tokio::time::sleep; const DEFAULT_CLIENT_ID: &str = "databricks-cli"; const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; @@ -24,6 +25,17 @@ const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; // https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"]; +/// Default timeout for API requests in seconds +const DEFAULT_TIMEOUT_SECS: u64 = 600; +/// Default initial interval for retry (in milliseconds) +const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000; +/// Default maximum number of retries +const DEFAULT_MAX_RETRIES: usize = 6; +/// Default retry backoff multiplier +const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; +/// Default maximum interval for retry (in milliseconds) +const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000; + pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet"; // Databricks can passthrough to a wide range of models, we only provide the default pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ @@ -36,6 +48,53 @@ pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ pub const DATABRICKS_DOC_URL: &str = "https://docs.databricks.com/en/generative-ai/external-models/index.html"; +/// Retry configuration for handling rate limit errors +#[derive(Debug, Clone)] +struct RetryConfig { + /// Maximum number of retry attempts + max_retries: usize, + /// Initial interval between retries in milliseconds + initial_interval_ms: u64, + /// Multiplier for backoff (exponential) + backoff_multiplier: f64, + /// Maximum interval between retries in milliseconds + max_interval_ms: u64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: DEFAULT_MAX_RETRIES, + initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, + backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, + max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS, + } + } +} + +impl RetryConfig { + /// Calculate the delay for a specific retry attempt (with jitter) + fn delay_for_attempt(&self, attempt: usize) -> Duration { + if attempt == 0 { + return Duration::from_millis(0); + } + + // Calculate exponential backoff + let exponent = (attempt - 1) as u32; + let base_delay_ms = (self.initial_interval_ms as f64 + * self.backoff_multiplier.powi(exponent as i32)) as u64; + + // Apply max limit + let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms); + + // Add jitter (+/-20% randomness) to avoid thundering herd problem + let jitter_factor = 0.8 + (rand::random::() * 0.4); // Between 0.8 and 1.2 + let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64; + + Duration::from_millis(jittered_delay_ms) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum DatabricksAuth { Token(String), @@ -70,6 +129,8 @@ pub struct DatabricksProvider { auth: DatabricksAuth, model: ModelConfig, image_format: ImageFormat, + #[serde(skip)] + retry_config: RetryConfig, } impl Default for DatabricksProvider { @@ -100,9 +161,12 @@ impl DatabricksProvider { let host = host?; let client = Client::builder() - .timeout(Duration::from_secs(600)) + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) .build()?; + // Load optional retry configuration from environment + let retry_config = Self::load_retry_config(&config); + // If we find a databricks token we prefer that if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") { return Ok(Self { @@ -111,6 +175,7 @@ impl DatabricksProvider { auth: DatabricksAuth::token(api_key), model, image_format: ImageFormat::OpenAi, + retry_config, }); } @@ -121,9 +186,44 @@ impl DatabricksProvider { host, model, image_format: ImageFormat::OpenAi, + retry_config, }) } + /// Loads retry configuration from environment variables or uses defaults. + fn load_retry_config(config: &crate::config::Config) -> RetryConfig { + let max_retries = config + .get_param("DATABRICKS_MAX_RETRIES") + .ok() + .and_then(|v: String| v.parse::().ok()) + .unwrap_or(DEFAULT_MAX_RETRIES); + + let initial_interval_ms = config + .get_param("DATABRICKS_INITIAL_RETRY_INTERVAL_MS") + .ok() + .and_then(|v: String| v.parse::().ok()) + .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS); + + let backoff_multiplier = config + .get_param("DATABRICKS_BACKOFF_MULTIPLIER") + .ok() + .and_then(|v: String| v.parse::().ok()) + .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER); + + let max_interval_ms = config + .get_param("DATABRICKS_MAX_RETRY_INTERVAL_MS") + .ok() + .and_then(|v: String| v.parse::().ok()) + .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS); + + RetryConfig { + max_retries, + initial_interval_ms, + backoff_multiplier, + max_interval_ms, + } + } + /// Create a new DatabricksProvider with the specified host and token /// /// # Arguments @@ -145,6 +245,7 @@ impl DatabricksProvider { auth: DatabricksAuth::token(api_key), model, image_format: ImageFormat::OpenAi, + retry_config: RetryConfig::default(), }) } @@ -182,70 +283,129 @@ impl DatabricksProvider { ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) })?; - let auth_header = self.ensure_auth_header().await?; - let response = self - .client - .post(url) - .header("Authorization", auth_header) - .json(&payload) - .send() - .await?; - - let status = response.status(); - let payload: Option = response.json().await.ok(); - - match status { - StatusCode::OK => payload.ok_or_else(|| ProviderError::RequestFailed("Response body is not valid JSON".to_string())), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload))) - } - StatusCode::BAD_REQUEST => { - // Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific - // We try to extract the error message from the payload and check for phrases that indicate context length exceeded - let payload_str = serde_json::to_string(&payload).unwrap_or_default().to_lowercase(); - let check_phrases = [ - "too long", - "context length", - "context_length_exceeded", - "reduce the length", - "token count", - "exceeds", - ]; - if check_phrases.iter().any(|c| payload_str.contains(c)) { - return Err(ProviderError::ContextLengthExceeded(payload_str)); - } - - let mut error_msg = "Unknown error".to_string(); - if let Some(payload) = &payload { - // try to convert message to string, if that fails use external_model_message - error_msg = payload - .get("message") - .and_then(|m| m.as_str()) - .or_else(|| { - payload.get("external_model_message") - .and_then(|ext| ext.get("message")) - .and_then(|m| m.as_str()) - }) - .unwrap_or("Unknown error").to_string(); - } + // Initialize retry counter + let mut attempts = 0; + let mut last_error = None; - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) + loop { + // Check if we've exceeded max retries + if attempts > 0 && attempts > self.retry_config.max_retries { + let error_msg = format!( + "Exceeded maximum retry attempts ({}) for rate limiting (429)", + self.retry_config.max_retries ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) + tracing::error!("{}", error_msg); + return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } - _ => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) + + let auth_header = self.ensure_auth_header().await?; + let response = self + .client + .post(url.clone()) + .header("Authorization", auth_header) + .json(&payload) + .send() + .await?; + + let status = response.status(); + let payload: Option = response.json().await.ok(); + + match status { + StatusCode::OK => { + return payload.ok_or_else(|| { + ProviderError::RequestFailed("Response body is not valid JSON".to_string()) + }); + } + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + return Err(ProviderError::Authentication(format!( + "Authentication failed. Please ensure your API keys are valid and have the required permissions. \ + Status: {}. Response: {:?}", + status, payload + ))); + } + StatusCode::BAD_REQUEST => { + // Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific + // We try to extract the error message from the payload and check for phrases that indicate context length exceeded + let payload_str = serde_json::to_string(&payload) + .unwrap_or_default() + .to_lowercase(); + let check_phrases = [ + "too long", + "context length", + "context_length_exceeded", + "reduce the length", + "token count", + "exceeds", + ]; + if check_phrases.iter().any(|c| payload_str.contains(c)) { + return Err(ProviderError::ContextLengthExceeded(payload_str)); + } + + let mut error_msg = "Unknown error".to_string(); + if let Some(payload) = &payload { + // try to convert message to string, if that fails use external_model_message + error_msg = payload + .get("message") + .and_then(|m| m.as_str()) + .or_else(|| { + payload.get("external_model_message") + .and_then(|ext| ext.get("message")) + .and_then(|m| m.as_str()) + }) + .unwrap_or("Unknown error").to_string(); + } + + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) + ); + return Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))); + } + StatusCode::TOO_MANY_REQUESTS => { + attempts += 1; + let error_msg = format!("Rate limit exceeded (attempt {}/{}): {:?}", + attempts, + self.retry_config.max_retries, + payload + ); + tracing::warn!("{}. Retrying after backoff...", error_msg); + + // Store the error in case we need to return it after max retries + last_error = Some(ProviderError::RateLimitExceeded(error_msg)); + + // Calculate and apply the backoff delay + let delay = self.retry_config.delay_for_attempt(attempts); + tracing::info!("Backing off for {:?} before retry", delay); + sleep(delay).await; + + // Continue to the next retry attempt + continue; + } + StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { + attempts += 1; + let error_msg = format!("Server error (attempt {}/{}): {:?}", + attempts, + self.retry_config.max_retries, + payload + ); + tracing::warn!("{}. Retrying after backoff...", error_msg); + + // Store the error in case we need to return it after max retries + last_error = Some(ProviderError::ServerError(error_msg)); + + // Calculate and apply the backoff delay + let delay = self.retry_config.delay_for_attempt(attempts); + tracing::info!("Backing off for {:?} before retry", delay); + sleep(delay).await; + + // Continue to the next retry attempt + continue; + } + _ => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) + ); + return Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))); + } } } } From 76920293188fcacfe63da26949fd2739a401c7ad Mon Sep 17 00:00:00 2001 From: Brandon Kvarda Date: Mon, 2 Jun 2025 21:20:34 -0700 Subject: [PATCH 2/4] Add a few more context length check phrases --- crates/goose/src/providers/databricks.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 11495b79c5b6..d414e9181419 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -336,6 +336,10 @@ impl DatabricksProvider { "reduce the length", "token count", "exceeds", + "input length", + "max_tokens", + "decrease input length", + "context limit", ]; if check_phrases.iter().any(|c| payload_str.contains(c)) { return Err(ProviderError::ContextLengthExceeded(payload_str)); From 61c1c35aa57cfeaea5678b3d1b17a292544ba56b Mon Sep 17 00:00:00 2001 From: Brandon Kvarda Date: Tue, 3 Jun 2025 08:52:19 -0700 Subject: [PATCH 3/4] Fix lint/fmt issues --- crates/goose/src/providers/databricks.rs | 48 +++++++++++++++--------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index d414e9181419..690cedea2dae 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -165,7 +165,7 @@ impl DatabricksProvider { .build()?; // Load optional retry configuration from environment - let retry_config = Self::load_retry_config(&config); + let retry_config = Self::load_retry_config(config); // If we find a databricks token we prefer that if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") { @@ -352,27 +352,35 @@ impl DatabricksProvider { .get("message") .and_then(|m| m.as_str()) .or_else(|| { - payload.get("external_model_message") + payload + .get("external_model_message") .and_then(|ext| ext.get("message")) .and_then(|m| m.as_str()) }) - .unwrap_or("Unknown error").to_string(); + .unwrap_or("Unknown error") + .to_string(); } tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) + "{}", + format!( + "Provider request failed with status: {}. Payload: {:?}", + status, payload + ) ); - return Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))); + return Err(ProviderError::RequestFailed(format!( + "Request failed with status: {}. Message: {}", + status, error_msg + ))); } StatusCode::TOO_MANY_REQUESTS => { attempts += 1; - let error_msg = format!("Rate limit exceeded (attempt {}/{}): {:?}", - attempts, - self.retry_config.max_retries, - payload + let error_msg = format!( + "Rate limit exceeded (attempt {}/{}): {:?}", + attempts, self.retry_config.max_retries, payload ); tracing::warn!("{}. Retrying after backoff...", error_msg); - + // Store the error in case we need to return it after max retries last_error = Some(ProviderError::RateLimitExceeded(error_msg)); @@ -386,13 +394,12 @@ impl DatabricksProvider { } StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { attempts += 1; - let error_msg = format!("Server error (attempt {}/{}): {:?}", - attempts, - self.retry_config.max_retries, - payload + let error_msg = format!( + "Server error (attempt {}/{}): {:?}", + attempts, self.retry_config.max_retries, payload ); tracing::warn!("{}. Retrying after backoff...", error_msg); - + // Store the error in case we need to return it after max retries last_error = Some(ProviderError::ServerError(error_msg)); @@ -406,9 +413,16 @@ impl DatabricksProvider { } _ => { tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) + "{}", + format!( + "Provider request failed with status: {}. Payload: {:?}", + status, payload + ) ); - return Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))); + return Err(ProviderError::RequestFailed(format!( + "Request failed with status: {}", + status + ))); } } } From 05165613f16614821c1aed2a5f6f64f69506da9d Mon Sep 17 00:00:00 2001 From: Brandon Kvarda Date: Tue, 3 Jun 2025 09:19:34 -0700 Subject: [PATCH 4/4] Update other check_phrases to match combined list --- crates/goose-llm/src/providers/databricks.rs | 3 +++ crates/goose/src/providers/databricks.rs | 2 +- crates/goose/src/providers/snowflake.rs | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/crates/goose-llm/src/providers/databricks.rs b/crates/goose-llm/src/providers/databricks.rs index 7b91d6a26afa..3dd31493c1cd 100644 --- a/crates/goose-llm/src/providers/databricks.rs +++ b/crates/goose-llm/src/providers/databricks.rs @@ -139,7 +139,10 @@ impl DatabricksProvider { "token count", "exceeds", "exceed context limit", + "input length", "max_tokens", + "decrease input length", + "context limit", ]; if check_phrases.iter().any(|c| payload_str.contains(c)) { return Err(ProviderError::ContextLengthExceeded(payload_str)); diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index e7f28bd81580..bccae36460a3 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -286,7 +286,7 @@ impl DatabricksProvider { // Initialize retry counter let mut attempts = 0; let mut last_error = None; - + loop { // Check if we've exceeded max retries if attempts > 0 && attempts > self.retry_config.max_retries { diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index f1c3ad10a603..32c1f2c60041 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -334,7 +334,10 @@ impl SnowflakeProvider { "token count", "exceeds", "exceed context limit", + "input length", "max_tokens", + "decrease input length", + "context limit", ]; if check_phrases.iter().any(|c| payload_str.contains(c)) { return Err(ProviderError::ContextLengthExceeded("Request exceeds maximum context length. Please reduce the number of messages or content size.".to_string()));