Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 173 additions & 19 deletions codex-rs/cloud-requirements/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,33 @@ use codex_core::AuthManager;
use codex_core::auth::CodexAuth;
use codex_core::config_loader::CloudRequirementsLoader;
use codex_core::config_loader::ConfigRequirementsToml;
use codex_core::util::backoff;
use codex_protocol::account::PlanType;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tokio::time::sleep;
use tokio::time::timeout;

const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(15);
const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5;

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FetchCloudRequirementsStatus {
BackendClientInit,
Request,
Parse,
}

#[async_trait]
trait RequirementsFetcher: Send + Sync {
/// Returns requirements as a TOML string.
/// Returns `Ok(None)` when there are no cloud requirements for the account.
///
/// TODO(gt): For now, returns an Option. But when we want to make this fail-closed, return a
/// Result.
async fn fetch_requirements(&self, auth: &CodexAuth) -> Option<String>;
/// Returning `Err` indicates cloud requirements could not be fetched.
async fn fetch_requirements(
&self,
auth: &CodexAuth,
) -> Result<Option<String>, FetchCloudRequirementsStatus>;
}

struct BackendRequirementsFetcher {
Expand All @@ -43,28 +55,33 @@ impl BackendRequirementsFetcher {

#[async_trait]
impl RequirementsFetcher for BackendRequirementsFetcher {
async fn fetch_requirements(&self, auth: &CodexAuth) -> Option<String> {
async fn fetch_requirements(
&self,
auth: &CodexAuth,
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
let client = BackendClient::from_auth(self.base_url.clone(), auth)
.inspect_err(|err| {
tracing::warn!(
error = %err,
"Failed to construct backend client for cloud requirements"
);
})
.ok()?;
.map_err(|_| FetchCloudRequirementsStatus::BackendClientInit)?;

let response = client
.get_config_requirements_file()
.await
.inspect_err(|err| tracing::warn!(error = %err, "Failed to fetch cloud requirements"))
.ok()?;
.map_err(|_| FetchCloudRequirementsStatus::Request)?;

let Some(contents) = response.contents else {
tracing::warn!("Cloud requirements response missing contents");
return None;
tracing::info!(
"Cloud requirements response missing contents; treating as no requirements"
);
return Ok(None);
};

Some(contents)
Ok(Some(contents))
}
}

Expand Down Expand Up @@ -128,11 +145,41 @@ impl CloudRequirementsService {
return None;
}

let contents = self.fetcher.fetch_requirements(&auth).await?;
parse_cloud_requirements(&contents)
.inspect_err(|err| tracing::warn!(error = %err, "Failed to parse cloud requirements"))
.ok()
.flatten()
self.fetch_with_retries(&auth).await
}

async fn fetch_with_retries(&self, auth: &CodexAuth) -> Option<ConfigRequirementsToml> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have shared code for this kind of things that support exp backoff etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am depending on backoff here. I can introduce a retry crate but would argue unnecessary for a for-loop

for attempt in 1..=CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
let fetch_result = self
.fetcher
.fetch_requirements(auth)
.await
.and_then(|contents| {
contents.map_or(Ok(None), |contents| {
parse_cloud_requirements(&contents).map_err(|err| {
tracing::warn!(error = %err, "Failed to parse cloud requirements");
FetchCloudRequirementsStatus::Parse
})
})
});

match fetch_result {
Ok(requirements) => return requirements,
Err(status) => {
if attempt < CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
tracing::warn!(
status = ?status,
attempt,
max_attempts = CLOUD_REQUIREMENTS_MAX_ATTEMPTS,
"Failed to fetch cloud requirements; retrying"
);
sleep(backoff(attempt as u64)).await;
}
}
}
}

None
}
}

Expand Down Expand Up @@ -178,8 +225,11 @@ mod tests {
use codex_protocol::protocol::AskForApproval;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::VecDeque;
use std::future::pending;
use std::path::Path;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use tempfile::tempdir;

fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> std::io::Result<()> {
Expand Down Expand Up @@ -246,18 +296,51 @@ mod tests {

#[async_trait::async_trait]
impl RequirementsFetcher for StaticFetcher {
async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option<String> {
self.contents.clone()
async fn fetch_requirements(
&self,
_auth: &CodexAuth,
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
Ok(self.contents.clone())
}
}

struct PendingFetcher;

#[async_trait::async_trait]
impl RequirementsFetcher for PendingFetcher {
async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option<String> {
async fn fetch_requirements(
&self,
_auth: &CodexAuth,
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
pending::<()>().await;
None
Ok(None)
}
}

struct SequenceFetcher {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks like a very non idiomatic way to a retry policy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just in the test, to force retries?

responses:
tokio::sync::Mutex<VecDeque<Result<Option<String>, FetchCloudRequirementsStatus>>>,
request_count: AtomicUsize,
}

impl SequenceFetcher {
fn new(responses: Vec<Result<Option<String>, FetchCloudRequirementsStatus>>) -> Self {
Self {
responses: tokio::sync::Mutex::new(VecDeque::from(responses)),
request_count: AtomicUsize::new(0),
}
}
}

#[async_trait::async_trait]
impl RequirementsFetcher for SequenceFetcher {
async fn fetch_requirements(
&self,
_auth: &CodexAuth,
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
self.request_count.fetch_add(1, Ordering::SeqCst);
let mut responses = self.responses.lock().await;
responses.pop_front().unwrap_or(Ok(None))
}
}

Expand Down Expand Up @@ -359,4 +442,75 @@ mod tests {
let result = handle.await.expect("cloud requirements task");
assert!(result.is_none());
}

#[tokio::test(start_paused = true)]
async fn fetch_cloud_requirements_retries_until_success() {
let fetcher = Arc::new(SequenceFetcher::new(vec![
Err(FetchCloudRequirementsStatus::Request),
Ok(Some("allowed_approval_policies = [\"never\"]".to_string())),
]));
let service = CloudRequirementsService::new(
auth_manager_with_plan("business"),
fetcher.clone(),
CLOUD_REQUIREMENTS_TIMEOUT,
);

let handle = tokio::spawn(async move { service.fetch().await });
tokio::task::yield_now().await;
tokio::time::advance(Duration::from_secs(1)).await;

assert_eq!(
handle.await.expect("cloud requirements task"),
Some(ConfigRequirementsToml {
allowed_approval_policies: Some(vec![AskForApproval::Never]),
allowed_sandbox_modes: None,
mcp_servers: None,
rules: None,
enforce_residency: None,
})
);
assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 2);
}

#[tokio::test]
async fn fetch_cloud_requirements_none_is_success_without_retry() {
let fetcher = Arc::new(SequenceFetcher::new(vec![
Ok(None),
Err(FetchCloudRequirementsStatus::Request),
]));
let service = CloudRequirementsService::new(
auth_manager_with_plan("enterprise"),
fetcher.clone(),
CLOUD_REQUIREMENTS_TIMEOUT,
);

assert!(service.fetch().await.is_none());
assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 1);
}

#[tokio::test(start_paused = true)]
async fn fetch_cloud_requirements_stops_after_max_retries() {
let fetcher = Arc::new(SequenceFetcher::new(vec![
Err(
FetchCloudRequirementsStatus::Request
);
CLOUD_REQUIREMENTS_MAX_ATTEMPTS
]));
let service = CloudRequirementsService::new(
auth_manager_with_plan("enterprise"),
fetcher.clone(),
CLOUD_REQUIREMENTS_TIMEOUT,
);

let handle = tokio::spawn(async move { service.fetch().await });
tokio::task::yield_now().await;
tokio::time::advance(Duration::from_secs(5)).await;
tokio::task::yield_now().await;

assert!(handle.await.expect("cloud requirements task").is_none());
assert_eq!(
fetcher.request_count.load(Ordering::SeqCst),
CLOUD_REQUIREMENTS_MAX_ATTEMPTS
);
}
}
2 changes: 1 addition & 1 deletion codex-rs/core/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ macro_rules! feedback_tags {
};
}

pub(crate) fn backoff(attempt: u64) -> Duration {
pub fn backoff(attempt: u64) -> Duration {
let exp = BACKOFF_FACTOR.powi(attempt.saturating_sub(1) as i32);
let base = (INITIAL_DELAY_MS as f64 * exp) as u64;
let jitter = rand::rng().random_range(0.9..1.1);
Expand Down
Loading