diff --git a/codex-rs/cloud-requirements/src/lib.rs b/codex-rs/cloud-requirements/src/lib.rs index 95f5180a83f..9ca432dc7fa 100644 --- a/codex-rs/cloud-requirements/src/lib.rs +++ b/codex-rs/cloud-requirements/src/lib.rs @@ -14,21 +14,33 @@ use codex_core::AuthManager; use codex_core::auth::CodexAuth; use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigRequirementsToml; +use codex_core::util::backoff; use codex_protocol::account::PlanType; use std::sync::Arc; use std::time::Duration; use std::time::Instant; +use tokio::time::sleep; use tokio::time::timeout; const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(15); +const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum FetchCloudRequirementsStatus { + BackendClientInit, + Request, + Parse, +} #[async_trait] trait RequirementsFetcher: Send + Sync { - /// Returns requirements as a TOML string. + /// Returns `Ok(None)` when there are no cloud requirements for the account. /// - /// TODO(gt): For now, returns an Option. But when we want to make this fail-closed, return a - /// Result. - async fn fetch_requirements(&self, auth: &CodexAuth) -> Option; + /// Returning `Err` indicates cloud requirements could not be fetched. + async fn fetch_requirements( + &self, + auth: &CodexAuth, + ) -> Result, FetchCloudRequirementsStatus>; } struct BackendRequirementsFetcher { @@ -43,7 +55,10 @@ impl BackendRequirementsFetcher { #[async_trait] impl RequirementsFetcher for BackendRequirementsFetcher { - async fn fetch_requirements(&self, auth: &CodexAuth) -> Option { + async fn fetch_requirements( + &self, + auth: &CodexAuth, + ) -> Result, FetchCloudRequirementsStatus> { let client = BackendClient::from_auth(self.base_url.clone(), auth) .inspect_err(|err| { tracing::warn!( @@ -51,20 +66,22 @@ impl RequirementsFetcher for BackendRequirementsFetcher { "Failed to construct backend client for cloud requirements" ); }) - .ok()?; + .map_err(|_| FetchCloudRequirementsStatus::BackendClientInit)?; let response = client .get_config_requirements_file() .await .inspect_err(|err| tracing::warn!(error = %err, "Failed to fetch cloud requirements")) - .ok()?; + .map_err(|_| FetchCloudRequirementsStatus::Request)?; let Some(contents) = response.contents else { - tracing::warn!("Cloud requirements response missing contents"); - return None; + tracing::info!( + "Cloud requirements response missing contents; treating as no requirements" + ); + return Ok(None); }; - Some(contents) + Ok(Some(contents)) } } @@ -128,11 +145,41 @@ impl CloudRequirementsService { return None; } - let contents = self.fetcher.fetch_requirements(&auth).await?; - parse_cloud_requirements(&contents) - .inspect_err(|err| tracing::warn!(error = %err, "Failed to parse cloud requirements")) - .ok() - .flatten() + self.fetch_with_retries(&auth).await + } + + async fn fetch_with_retries(&self, auth: &CodexAuth) -> Option { + for attempt in 1..=CLOUD_REQUIREMENTS_MAX_ATTEMPTS { + let fetch_result = self + .fetcher + .fetch_requirements(auth) + .await + .and_then(|contents| { + contents.map_or(Ok(None), |contents| { + parse_cloud_requirements(&contents).map_err(|err| { + tracing::warn!(error = %err, "Failed to parse cloud requirements"); + FetchCloudRequirementsStatus::Parse + }) + }) + }); + + match fetch_result { + Ok(requirements) => return requirements, + Err(status) => { + if attempt < CLOUD_REQUIREMENTS_MAX_ATTEMPTS { + tracing::warn!( + status = ?status, + attempt, + max_attempts = CLOUD_REQUIREMENTS_MAX_ATTEMPTS, + "Failed to fetch cloud requirements; retrying" + ); + sleep(backoff(attempt as u64)).await; + } + } + } + } + + None } } @@ -178,8 +225,11 @@ mod tests { use codex_protocol::protocol::AskForApproval; use pretty_assertions::assert_eq; use serde_json::json; + use std::collections::VecDeque; use std::future::pending; use std::path::Path; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; use tempfile::tempdir; fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> std::io::Result<()> { @@ -246,8 +296,11 @@ mod tests { #[async_trait::async_trait] impl RequirementsFetcher for StaticFetcher { - async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option { - self.contents.clone() + async fn fetch_requirements( + &self, + _auth: &CodexAuth, + ) -> Result, FetchCloudRequirementsStatus> { + Ok(self.contents.clone()) } } @@ -255,9 +308,39 @@ mod tests { #[async_trait::async_trait] impl RequirementsFetcher for PendingFetcher { - async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option { + async fn fetch_requirements( + &self, + _auth: &CodexAuth, + ) -> Result, FetchCloudRequirementsStatus> { pending::<()>().await; - None + Ok(None) + } + } + + struct SequenceFetcher { + responses: + tokio::sync::Mutex, FetchCloudRequirementsStatus>>>, + request_count: AtomicUsize, + } + + impl SequenceFetcher { + fn new(responses: Vec, FetchCloudRequirementsStatus>>) -> Self { + Self { + responses: tokio::sync::Mutex::new(VecDeque::from(responses)), + request_count: AtomicUsize::new(0), + } + } + } + + #[async_trait::async_trait] + impl RequirementsFetcher for SequenceFetcher { + async fn fetch_requirements( + &self, + _auth: &CodexAuth, + ) -> Result, FetchCloudRequirementsStatus> { + self.request_count.fetch_add(1, Ordering::SeqCst); + let mut responses = self.responses.lock().await; + responses.pop_front().unwrap_or(Ok(None)) } } @@ -359,4 +442,75 @@ mod tests { let result = handle.await.expect("cloud requirements task"); assert!(result.is_none()); } + + #[tokio::test(start_paused = true)] + async fn fetch_cloud_requirements_retries_until_success() { + let fetcher = Arc::new(SequenceFetcher::new(vec![ + Err(FetchCloudRequirementsStatus::Request), + Ok(Some("allowed_approval_policies = [\"never\"]".to_string())), + ])); + let service = CloudRequirementsService::new( + auth_manager_with_plan("business"), + fetcher.clone(), + CLOUD_REQUIREMENTS_TIMEOUT, + ); + + let handle = tokio::spawn(async move { service.fetch().await }); + tokio::task::yield_now().await; + tokio::time::advance(Duration::from_secs(1)).await; + + assert_eq!( + handle.await.expect("cloud requirements task"), + Some(ConfigRequirementsToml { + allowed_approval_policies: Some(vec![AskForApproval::Never]), + allowed_sandbox_modes: None, + mcp_servers: None, + rules: None, + enforce_residency: None, + }) + ); + assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn fetch_cloud_requirements_none_is_success_without_retry() { + let fetcher = Arc::new(SequenceFetcher::new(vec![ + Ok(None), + Err(FetchCloudRequirementsStatus::Request), + ])); + let service = CloudRequirementsService::new( + auth_manager_with_plan("enterprise"), + fetcher.clone(), + CLOUD_REQUIREMENTS_TIMEOUT, + ); + + assert!(service.fetch().await.is_none()); + assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test(start_paused = true)] + async fn fetch_cloud_requirements_stops_after_max_retries() { + let fetcher = Arc::new(SequenceFetcher::new(vec![ + Err( + FetchCloudRequirementsStatus::Request + ); + CLOUD_REQUIREMENTS_MAX_ATTEMPTS + ])); + let service = CloudRequirementsService::new( + auth_manager_with_plan("enterprise"), + fetcher.clone(), + CLOUD_REQUIREMENTS_TIMEOUT, + ); + + let handle = tokio::spawn(async move { service.fetch().await }); + tokio::task::yield_now().await; + tokio::time::advance(Duration::from_secs(5)).await; + tokio::task::yield_now().await; + + assert!(handle.await.expect("cloud requirements task").is_none()); + assert_eq!( + fetcher.request_count.load(Ordering::SeqCst), + CLOUD_REQUIREMENTS_MAX_ATTEMPTS + ); + } } diff --git a/codex-rs/core/src/util.rs b/codex-rs/core/src/util.rs index 1a538da558e..59fecb0a9ee 100644 --- a/codex-rs/core/src/util.rs +++ b/codex-rs/core/src/util.rs @@ -37,7 +37,7 @@ macro_rules! feedback_tags { }; } -pub(crate) fn backoff(attempt: u64) -> Duration { +pub fn backoff(attempt: u64) -> Duration { let exp = BACKOFF_FACTOR.powi(attempt.saturating_sub(1) as i32); let base = (INITIAL_DELAY_MS as f64 * exp) as u64; let jitter = rand::rng().random_range(0.9..1.1);