Skip to content
3 changes: 3 additions & 0 deletions codex-rs/cloud-tasks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ unicode-width = { workspace = true }
owo-colors = { workspace = true, features = ["supports-colors"] }
supports-color = { workspace = true }

[dependencies.async-trait]
workspace = true

[dev-dependencies]
async-trait = { workspace = true }
pretty_assertions = { workspace = true }
8 changes: 4 additions & 4 deletions codex-rs/cloud-tasks/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,17 @@ pub struct ExecCommand {
#[arg(long = "env", value_name = "ENV_ID")]
pub environment: String,

/// Git branch to run in Codex Cloud.
#[arg(long = "branch", value_name = "BRANCH", default_value = "main")]
pub branch: String,

/// Number of assistant attempts (best-of-N).
#[arg(
long = "attempts",
default_value_t = 1usize,
value_parser = parse_attempts
)]
pub attempts: usize,

/// Git branch to run in Codex Cloud (defaults to current branch).
#[arg(long = "branch", value_name = "BRANCH")]
pub branch: Option<String>,
}

fn parse_attempts(input: &str) -> Result<usize, String> {
Expand Down
143 changes: 131 additions & 12 deletions codex-rs/cloud-tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,54 @@ async fn init_backend(user_agent_suffix: &str) -> anyhow::Result<BackendContext>
})
}

#[async_trait::async_trait]
trait GitInfoProvider {
async fn default_branch_name(&self, path: &std::path::Path) -> Option<String>;

async fn current_branch_name(&self, path: &std::path::Path) -> Option<String>;
}

struct RealGitInfo;

#[async_trait::async_trait]
impl GitInfoProvider for RealGitInfo {
async fn default_branch_name(&self, path: &std::path::Path) -> Option<String> {
codex_core::git_info::default_branch_name(path).await
}

async fn current_branch_name(&self, path: &std::path::Path) -> Option<String> {
codex_core::git_info::current_branch_name(path).await
}
}

async fn resolve_git_ref(branch_override: Option<&String>) -> String {
resolve_git_ref_with_git_info(branch_override, &RealGitInfo).await
}

async fn resolve_git_ref_with_git_info(
branch_override: Option<&String>,
git_info: &impl GitInfoProvider,
) -> String {
if let Some(branch) = branch_override {
let branch = branch.trim();
if !branch.is_empty() {
return branch.to_string();
}
}

if let Ok(cwd) = std::env::current_dir() {
if let Some(branch) = git_info.current_branch_name(&cwd).await {
branch
} else if let Some(branch) = git_info.default_branch_name(&cwd).await {
branch
} else {
"main".to_string()
}
} else {
"main".to_string()
}
}

async fn run_exec_command(args: crate::cli::ExecCommand) -> anyhow::Result<()> {
let crate::cli::ExecCommand {
query,
Expand All @@ -114,11 +162,12 @@ async fn run_exec_command(args: crate::cli::ExecCommand) -> anyhow::Result<()> {
let ctx = init_backend("codex_cloud_tasks_exec").await?;
let prompt = resolve_query_input(query)?;
let env_id = resolve_environment_id(&ctx, &environment).await?;
let git_ref = resolve_git_ref(branch.as_ref()).await;
let created = codex_cloud_tasks_client::CloudBackend::create_task(
&*ctx.backend,
&env_id,
&prompt,
&branch,
&git_ref,
false,
attempts,
)
Expand Down Expand Up @@ -1362,17 +1411,7 @@ pub async fn run_main(cli: Cli, _codex_linux_sandbox_exe: Option<PathBuf>) -> an
let backend = Arc::clone(&backend);
let best_of_n = page.best_of_n;
tokio::spawn(async move {
let git_ref = if let Ok(cwd) = std::env::current_dir() {
if let Some(branch) = codex_core::git_info::default_branch_name(&cwd).await {
branch
} else if let Some(branch) = codex_core::git_info::current_branch_name(&cwd).await {
branch
} else {
"main".to_string()
}
} else {
"main".to_string()
};
let git_ref = resolve_git_ref(None).await;

let result = codex_cloud_tasks_client::CloudBackend::create_task(&*backend, &env, &text, &git_ref, false, best_of_n).await;
let evt = match result {
Expand Down Expand Up @@ -1991,6 +2030,7 @@ fn pretty_lines_from_error(raw: &str) -> Vec<String> {
#[cfg(test)]
mod tests {
use super::*;
use crate::resolve_git_ref_with_git_info;
use codex_cloud_tasks_client::DiffSummary;
use codex_cloud_tasks_client::MockClient;
use codex_cloud_tasks_client::TaskId;
Expand All @@ -2005,6 +2045,85 @@ mod tests {
use ratatui::buffer::Buffer;
use ratatui::layout::Rect;

struct StubGitInfo {
default_branch: Option<String>,
current_branch: Option<String>,
}

impl StubGitInfo {
fn new(default_branch: Option<String>, current_branch: Option<String>) -> Self {
Self {
default_branch,
current_branch,
}
}
}

#[async_trait::async_trait]
impl super::GitInfoProvider for StubGitInfo {
async fn default_branch_name(&self, _path: &std::path::Path) -> Option<String> {
self.default_branch.clone()
}

async fn current_branch_name(&self, _path: &std::path::Path) -> Option<String> {
self.current_branch.clone()
}
}

#[tokio::test]
async fn branch_override_is_used_when_provided() {
let git_ref = resolve_git_ref_with_git_info(
Some(&"feature/override".to_string()),
&StubGitInfo::new(None, None),
)
.await;

assert_eq!(git_ref, "feature/override");
}

#[tokio::test]
async fn trims_override_whitespace() {
let git_ref = resolve_git_ref_with_git_info(
Some(&" feature/spaces ".to_string()),
&StubGitInfo::new(None, None),
)
.await;

assert_eq!(git_ref, "feature/spaces");
}

#[tokio::test]
async fn prefers_current_branch_when_available() {
let git_ref = resolve_git_ref_with_git_info(
None,
&StubGitInfo::new(
Some("default-main".to_string()),
Some("feature/current".to_string()),
),
)
.await;

assert_eq!(git_ref, "feature/current");
}

#[tokio::test]
async fn falls_back_to_current_branch_when_default_is_missing() {
let git_ref = resolve_git_ref_with_git_info(
None,
&StubGitInfo::new(None, Some("develop".to_string())),
)
.await;

assert_eq!(git_ref, "develop");
}

#[tokio::test]
async fn falls_back_to_main_when_no_git_info_is_available() {
let git_ref = resolve_git_ref_with_git_info(None, &StubGitInfo::new(None, None)).await;

assert_eq!(git_ref, "main");
}

#[test]
fn format_task_status_lines_with_diff_and_label() {
let now = Utc::now();
Expand Down
Loading