Skip to content
Merged
3 changes: 3 additions & 0 deletions crates/goose-llm/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ impl DatabricksProvider {
"token count",
"exceeds",
"exceed context limit",
"input length",
"max_tokens",
"decrease input length",
"context limit",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));
Expand Down
305 changes: 241 additions & 64 deletions crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,25 @@ use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use tokio::time::sleep;

const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
// "offline_access" scope is used to request an OAuth 2.0 Refresh Token
// https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];

/// Default timeout for API requests in seconds (10 minutes)
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// Default initial interval for retry (in milliseconds) — 5 seconds before the first retry
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Default maximum number of retries for retryable responses (429 / 5xx)
const DEFAULT_MAX_RETRIES: usize = 6;
/// Default retry backoff multiplier (each retry doubles the previous interval)
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default maximum interval for retry (in milliseconds) — 320 s, i.e. the
/// uncapped value of the 7th exponential step (5 s * 2^6)
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;

pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
// Databricks can passthrough to a wide range of models, we only provide the default
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
Expand All @@ -36,6 +48,53 @@ pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
pub const DATABRICKS_DOC_URL: &str =
"https://docs.databricks.com/en/generative-ai/external-models/index.html";

/// Retry configuration for handling rate limit errors
///
/// Controls the exponential-backoff schedule used when the Databricks
/// endpoint returns a retryable status (rate limiting or server errors).
/// Values are loaded from `DATABRICKS_*` environment configuration, falling
/// back to the module-level `DEFAULT_*` constants (see `Default` impl).
#[derive(Debug, Clone)]
struct RetryConfig {
    /// Maximum number of retry attempts before giving up and surfacing the error
    max_retries: usize,
    /// Initial interval between retries in milliseconds (delay before the first retry)
    initial_interval_ms: u64,
    /// Multiplier for backoff (exponential): delay(n) = initial * multiplier^(n-1)
    backoff_multiplier: f64,
    /// Maximum interval between retries in milliseconds (ceiling on the backoff)
    max_interval_ms: u64,
}

impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: DEFAULT_MAX_RETRIES,
initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
}
}
}

impl RetryConfig {
    /// Calculate the delay before retry `attempt` (1-based), using
    /// exponential backoff with jitter.
    ///
    /// Attempt 0 means "no retry has happened yet" and yields a zero delay.
    /// The returned delay never exceeds `max_interval_ms`: jitter is applied
    /// *before* the cap, so the configured maximum is a hard ceiling.
    /// (Previously the cap was applied first, letting jitter push the delay
    /// up to 20% past `max_interval_ms`.)
    fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            return Duration::from_millis(0);
        }

        // Exponential backoff: initial * multiplier^(attempt - 1).
        // f64 -> u64 `as` casts saturate, so overflow to `inf` for very large
        // attempts safely clamps to u64::MAX before the cap below.
        let exponent = (attempt - 1) as i32;
        let base_delay_ms =
            self.initial_interval_ms as f64 * self.backoff_multiplier.powi(exponent);

        // Add jitter (+/-20% randomness) to avoid the thundering-herd problem.
        let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // in [0.8, 1.2)
        let jittered_delay_ms = (base_delay_ms * jitter_factor) as u64;

        // Apply the maximum-interval cap last so it is a true upper bound.
        Duration::from_millis(std::cmp::min(jittered_delay_ms, self.max_interval_ms))
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DatabricksAuth {
Token(String),
Expand Down Expand Up @@ -70,6 +129,8 @@ pub struct DatabricksProvider {
auth: DatabricksAuth,
model: ModelConfig,
image_format: ImageFormat,
#[serde(skip)]
retry_config: RetryConfig,
}

impl Default for DatabricksProvider {
Expand Down Expand Up @@ -100,9 +161,12 @@ impl DatabricksProvider {
let host = host?;

let client = Client::builder()
.timeout(Duration::from_secs(600))
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()?;

// Load optional retry configuration from environment
let retry_config = Self::load_retry_config(config);

// If we find a databricks token we prefer that
if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") {
return Ok(Self {
Expand All @@ -111,6 +175,7 @@ impl DatabricksProvider {
auth: DatabricksAuth::token(api_key),
model,
image_format: ImageFormat::OpenAi,
retry_config,
});
}

Expand All @@ -121,9 +186,44 @@ impl DatabricksProvider {
host,
model,
image_format: ImageFormat::OpenAi,
retry_config,
})
}

/// Loads retry configuration from environment variables or uses defaults.
///
/// Each `DATABRICKS_*` parameter is optional; a missing or unparseable value
/// silently falls back to the corresponding `DEFAULT_*` constant.
fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
    // Fetch one optional string parameter; `None` on lookup failure.
    let raw = |key: &str| -> Option<String> { config.get_param(key).ok() };

    RetryConfig {
        max_retries: raw("DATABRICKS_MAX_RETRIES")
            .and_then(|v| v.parse::<usize>().ok())
            .unwrap_or(DEFAULT_MAX_RETRIES),
        initial_interval_ms: raw("DATABRICKS_INITIAL_RETRY_INTERVAL_MS")
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS),
        backoff_multiplier: raw("DATABRICKS_BACKOFF_MULTIPLIER")
            .and_then(|v| v.parse::<f64>().ok())
            .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER),
        max_interval_ms: raw("DATABRICKS_MAX_RETRY_INTERVAL_MS")
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS),
    }
}

/// Create a new DatabricksProvider with the specified host and token
///
/// # Arguments
Expand All @@ -145,6 +245,7 @@ impl DatabricksProvider {
auth: DatabricksAuth::token(api_key),
model,
image_format: ImageFormat::OpenAi,
retry_config: RetryConfig::default(),
})
}

Expand Down Expand Up @@ -182,72 +283,148 @@ impl DatabricksProvider {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;

let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url)
.header("Authorization", auth_header)
.json(&payload)
.send()
.await?;

let status = response.status();
let payload: Option<Value> = response.json().await.ok();

match status {
StatusCode::OK => payload.ok_or_else(|| ProviderError::RequestFailed("Response body is not valid JSON".to_string())),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload)))
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
let payload_str = serde_json::to_string(&payload).unwrap_or_default().to_lowercase();
let check_phrases = [
"too long",
"context length",
"context_length_exceeded",
"reduce the length",
"token count",
"exceeds",
"exceed context limit",
"max_tokens",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));
}

let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
// try to convert message to string, if that fails use external_model_message
error_msg = payload
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
payload.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
})
.unwrap_or("Unknown error").to_string();
}
// Initialize retry counter
let mut attempts = 0;
let mut last_error = None;

tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
loop {
// Check if we've exceeded max retries
if attempts > 0 && attempts > self.retry_config.max_retries {
let error_msg = format!(
"Exceeded maximum retry attempts ({}) for rate limiting (429)",
self.retry_config.max_retries
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
tracing::error!("{}", error_msg);
return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))

let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url.clone())
.header("Authorization", auth_header)
.json(&payload)
.send()
.await?;

let status = response.status();
let payload: Option<Value> = response.json().await.ok();

match status {
StatusCode::OK => {
return payload.ok_or_else(|| {
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
});
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
return Err(ProviderError::Authentication(format!(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}",
status, payload
)));
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
let payload_str = serde_json::to_string(&payload)
.unwrap_or_default()
.to_lowercase();
let check_phrases = [
"too long",
"context length",
"context_length_exceeded",
"reduce the length",
"token count",
"exceeds",
"exceed context limit",
"input length",
"max_tokens",
"decrease input length",
"context limit",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));
}

let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
// try to convert message to string, if that fails use external_model_message
error_msg = payload
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
payload
.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
})
.unwrap_or("Unknown error")
.to_string();
}

tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
return Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
)));
}
StatusCode::TOO_MANY_REQUESTS => {
attempts += 1;
let error_msg = format!(
"Rate limit exceeded (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);

// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::RateLimitExceeded(error_msg));

// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;

// Continue to the next retry attempt
continue;
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
attempts += 1;
let error_msg = format!(
"Server error (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);

// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::ServerError(error_msg));

// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;

// Continue to the next retry attempt
continue;
}
_ => {
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
return Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}",
status
)));
}
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/providers/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,10 @@ impl SnowflakeProvider {
"token count",
"exceeds",
"exceed context limit",
"input length",
"max_tokens",
"decrease input length",
"context limit",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded("Request exceeds maximum context length. Please reduce the number of messages or content size.".to_string()));
Expand Down
Loading