diff --git a/codex-rs/cloud-tasks/Cargo.toml b/codex-rs/cloud-tasks/Cargo.toml index 188538bec68..cc79b3e7907 100644 --- a/codex-rs/cloud-tasks/Cargo.toml +++ b/codex-rs/cloud-tasks/Cargo.toml @@ -37,6 +37,9 @@ unicode-width = { workspace = true } owo-colors = { workspace = true, features = ["supports-colors"] } supports-color = { workspace = true } +[dependencies.async-trait] +workspace = true + [dev-dependencies] async-trait = { workspace = true } pretty_assertions = { workspace = true } diff --git a/codex-rs/cloud-tasks/src/cli.rs b/codex-rs/cloud-tasks/src/cli.rs index a7612153b4e..6b36509639a 100644 --- a/codex-rs/cloud-tasks/src/cli.rs +++ b/codex-rs/cloud-tasks/src/cli.rs @@ -34,10 +34,6 @@ pub struct ExecCommand { #[arg(long = "env", value_name = "ENV_ID")] pub environment: String, - /// Git branch to run in Codex Cloud. - #[arg(long = "branch", value_name = "BRANCH", default_value = "main")] - pub branch: String, - /// Number of assistant attempts (best-of-N). #[arg( long = "attempts", @@ -45,6 +41,10 @@ pub struct ExecCommand { value_parser = parse_attempts )] pub attempts: usize, + + /// Git branch to run in Codex Cloud (defaults to current branch). + #[arg(long = "branch", value_name = "BRANCH")] + pub branch: Option, } fn parse_attempts(input: &str) -> Result { diff --git a/codex-rs/cloud-tasks/src/lib.rs b/codex-rs/cloud-tasks/src/lib.rs index f73e07f3afb..105f6cfb2e7 100644 --- a/codex-rs/cloud-tasks/src/lib.rs +++ b/codex-rs/cloud-tasks/src/lib.rs @@ -104,6 +104,54 @@ async fn init_backend(user_agent_suffix: &str) -> anyhow::Result }) } +#[async_trait::async_trait] +trait GitInfoProvider { + async fn default_branch_name(&self, path: &std::path::Path) -> Option; + + async fn current_branch_name(&self, path: &std::path::Path) -> Option; +} + +struct RealGitInfo; + +#[async_trait::async_trait] +impl GitInfoProvider for RealGitInfo { + async fn default_branch_name(&self, path: &std::path::Path) -> Option { + codex_core::git_info::default_branch_name(path).await + } + + async fn current_branch_name(&self, path: &std::path::Path) -> Option { + codex_core::git_info::current_branch_name(path).await + } +} + +async fn resolve_git_ref(branch_override: Option<&String>) -> String { + resolve_git_ref_with_git_info(branch_override, &RealGitInfo).await +} + +async fn resolve_git_ref_with_git_info( + branch_override: Option<&String>, + git_info: &impl GitInfoProvider, +) -> String { + if let Some(branch) = branch_override { + let branch = branch.trim(); + if !branch.is_empty() { + return branch.to_string(); + } + } + + if let Ok(cwd) = std::env::current_dir() { + if let Some(branch) = git_info.current_branch_name(&cwd).await { + branch + } else if let Some(branch) = git_info.default_branch_name(&cwd).await { + branch + } else { + "main".to_string() + } + } else { + "main".to_string() + } +} + async fn run_exec_command(args: crate::cli::ExecCommand) -> anyhow::Result<()> { let crate::cli::ExecCommand { query, @@ -114,11 +162,12 @@ async fn run_exec_command(args: crate::cli::ExecCommand) -> anyhow::Result<()> { let ctx = init_backend("codex_cloud_tasks_exec").await?; let prompt = resolve_query_input(query)?; let env_id = resolve_environment_id(&ctx, &environment).await?; + let git_ref = resolve_git_ref(branch.as_ref()).await; let created = codex_cloud_tasks_client::CloudBackend::create_task( &*ctx.backend, &env_id, &prompt, - &branch, + &git_ref, false, attempts, ) @@ -1362,17 +1411,7 @@ pub async fn run_main(cli: Cli, _codex_linux_sandbox_exe: Option) -> an let backend = Arc::clone(&backend); let best_of_n = page.best_of_n; tokio::spawn(async move { - let git_ref = if let Ok(cwd) = std::env::current_dir() { - if let Some(branch) = codex_core::git_info::default_branch_name(&cwd).await { - branch - } else if let Some(branch) = codex_core::git_info::current_branch_name(&cwd).await { - branch - } else { - "main".to_string() - } - } else { - "main".to_string() - }; + let git_ref = resolve_git_ref(None).await; let result = codex_cloud_tasks_client::CloudBackend::create_task(&*backend, &env, &text, &git_ref, false, best_of_n).await; let evt = match result { @@ -1991,6 +2030,7 @@ fn pretty_lines_from_error(raw: &str) -> Vec { #[cfg(test)] mod tests { use super::*; + use crate::resolve_git_ref_with_git_info; use codex_cloud_tasks_client::DiffSummary; use codex_cloud_tasks_client::MockClient; use codex_cloud_tasks_client::TaskId; @@ -2005,6 +2045,85 @@ mod tests { use ratatui::buffer::Buffer; use ratatui::layout::Rect; + struct StubGitInfo { + default_branch: Option, + current_branch: Option, + } + + impl StubGitInfo { + fn new(default_branch: Option, current_branch: Option) -> Self { + Self { + default_branch, + current_branch, + } + } + } + + #[async_trait::async_trait] + impl super::GitInfoProvider for StubGitInfo { + async fn default_branch_name(&self, _path: &std::path::Path) -> Option { + self.default_branch.clone() + } + + async fn current_branch_name(&self, _path: &std::path::Path) -> Option { + self.current_branch.clone() + } + } + + #[tokio::test] + async fn branch_override_is_used_when_provided() { + let git_ref = resolve_git_ref_with_git_info( + Some(&"feature/override".to_string()), + &StubGitInfo::new(None, None), + ) + .await; + + assert_eq!(git_ref, "feature/override"); + } + + #[tokio::test] + async fn trims_override_whitespace() { + let git_ref = resolve_git_ref_with_git_info( + Some(&" feature/spaces ".to_string()), + &StubGitInfo::new(None, None), + ) + .await; + + assert_eq!(git_ref, "feature/spaces"); + } + + #[tokio::test] + async fn prefers_current_branch_when_available() { + let git_ref = resolve_git_ref_with_git_info( + None, + &StubGitInfo::new( + Some("default-main".to_string()), + Some("feature/current".to_string()), + ), + ) + .await; + + assert_eq!(git_ref, "feature/current"); + } + + #[tokio::test] + async fn falls_back_to_current_branch_when_default_is_missing() { + let git_ref = resolve_git_ref_with_git_info( + None, + &StubGitInfo::new(None, Some("develop".to_string())), + ) + .await; + + assert_eq!(git_ref, "develop"); + } + + #[tokio::test] + async fn falls_back_to_main_when_no_git_info_is_available() { + let git_ref = resolve_git_ref_with_git_info(None, &StubGitInfo::new(None, None)).await; + + assert_eq!(git_ref, "main"); + } + #[test] fn format_task_status_lines_with_diff_and_label() { let now = Utc::now();