diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index ed8207e24ed..433ee163604 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -109,6 +109,10 @@ client_request_definitions! { params: v2::ThreadResumeParams, response: v2::ThreadResumeResponse, }, + ThreadFork => "thread/fork" { + params: v2::ThreadForkParams, + response: v2::ThreadForkResponse, + }, ThreadArchive => "thread/archive" { params: v2::ThreadArchiveParams, response: v2::ThreadArchiveResponse, @@ -221,6 +225,11 @@ client_request_definitions! { params: v1::ResumeConversationParams, response: v1::ResumeConversationResponse, }, + /// Fork a recorded Codex conversation into a new session. + ForkConversation { + params: v1::ForkConversationParams, + response: v1::ForkConversationResponse, + }, ArchiveConversation { params: v1::ArchiveConversationParams, response: v1::ArchiveConversationResponse, diff --git a/codex-rs/app-server-protocol/src/protocol/v1.rs b/codex-rs/app-server-protocol/src/protocol/v1.rs index 981ab28d1b4..ecc9d7c07de 100644 --- a/codex-rs/app-server-protocol/src/protocol/v1.rs +++ b/codex-rs/app-server-protocol/src/protocol/v1.rs @@ -83,6 +83,15 @@ pub struct ResumeConversationResponse { pub rollout_path: PathBuf, } +#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +pub struct ForkConversationResponse { + pub conversation_id: ThreadId, + pub model: String, + pub initial_messages: Option>, + pub rollout_path: PathBuf, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(untagged)] pub enum GetConversationSummaryParams { @@ -148,6 +157,14 @@ pub struct ResumeConversationParams { pub overrides: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +pub struct ForkConversationParams { + pub path: Option, + pub conversation_id: Option, + pub overrides: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] pub struct AddConversationSubscriptionResponse { diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index ee061158065..dc47685c050 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -1064,6 +1064,47 @@ pub struct ThreadResumeResponse { pub reasoning_effort: Option, } +#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +/// There are two ways to fork a thread: +/// 1. By thread_id: load the thread from disk by thread_id and fork it into a new thread. +/// 2. By path: load the thread from disk by path and fork it into a new thread. +/// +/// If using path, the thread_id param will be ignored. +/// +/// Prefer using thread_id whenever possible. +pub struct ThreadForkParams { + pub thread_id: String, + + /// [UNSTABLE] Specify the rollout path to fork from. + /// If specified, the thread_id param will be ignored. + pub path: Option, + + /// Configuration overrides for the forked thread, if any. + pub model: Option, + pub model_provider: Option, + pub cwd: Option, + pub approval_policy: Option, + pub sandbox: Option, + pub config: Option>, + pub base_instructions: Option, + pub developer_instructions: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct ThreadForkResponse { + pub thread: Thread, + pub model: String, + pub model_provider: String, + pub cwd: PathBuf, + pub approval_policy: AskForApproval, + pub sandbox: SandboxPolicy, + pub reasoning_effort: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -1238,7 +1279,7 @@ pub struct Thread { pub source: SessionSource, /// Optional Git metadata captured when the thread was created. pub git_info: Option, - /// Only populated on `thread/resume` and `thread/rollback` responses. + /// Only populated on `thread/resume`, `thread/rollback`, `thread/fork` responses. /// For all other responses and notifications returning a Thread, /// the turns field will be an empty list. pub turns: Vec, @@ -1314,7 +1355,7 @@ impl From for TokenUsageBreakdown { #[ts(export_to = "v2/")] pub struct Turn { pub id: String, - /// Only populated on a `thread/resume` response. + /// Only populated on a `thread/resume` or `thread/fork` response. /// For all other responses and notifications returning a Turn, /// the items field will be an empty list. pub items: Vec, diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 88d96dda56b..a247cb73b99 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -41,7 +41,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat ## Lifecycle Overview - Initialize once: Immediately after launching the codex app-server process, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request before this handshake gets rejected. -- Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you’ll also get a `thread/started` notification. If you’re continuing an existing conversation, call `thread/resume` with its ID instead. +- Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you’ll also get a `thread/started` notification. If you’re continuing an existing conversation, call `thread/resume` with its ID instead. If you want to branch from an existing conversation, call `thread/fork` to create a new thread id with copied history. - Begin a turn: To send user input, call `turn/start` with the target `threadId` and the user's input. Optional fields let you override model, cwd, sandbox policy, etc. This immediately returns the new turn object and triggers a `turn/started` notification. - Stream events: After `turn/start`, keep reading JSON-RPC notifications on stdout. You’ll see `item/started`, `item/completed`, deltas like `item/agentMessage/delta`, tool progress, etc. These represent streaming model output plus any side effects (commands, tool calls, reasoning notes). - Finish the turn: When the model is done (or the turn is interrupted via making the `turn/interrupt` call), the server sends `turn/completed` with the final turn state and token usage. @@ -72,6 +72,7 @@ Example (from OpenAI's official VSCode extension): - `thread/start` — create a new thread; emits `thread/started` and auto-subscribes you to turn/item events for that thread. - `thread/resume` — reopen an existing thread by id so subsequent `turn/start` calls append to it. +- `thread/fork` — fork an existing thread into a new thread id by copying the stored history; emits `thread/started` and auto-subscribes you to turn/item events for the new thread. - `thread/list` — page through stored rollouts; supports cursor-based pagination and optional `modelProviders` filtering. - `thread/archive` — move a thread’s rollout file into the archived directory; returns `{}` on success. - `thread/rollback` — drop the last N turns from the agent’s in-memory context and persist a rollback marker in the rollout so future resumes see the pruned history; returns the updated `thread` (with `turns` populated) on success. @@ -120,6 +121,14 @@ To continue a stored session, call `thread/resume` with the `thread.id` you prev { "id": 11, "result": { "thread": { "id": "thr_123", … } } } ``` +To branch from a stored session, call `thread/fork` with the `thread.id`. This creates a new thread id and emits a `thread/started` notification for it: + +```json +{ "method": "thread/fork", "id": 12, "params": { "threadId": "thr_123" } } +{ "id": 12, "result": { "thread": { "id": "thr_456", … } } } +{ "method": "thread/started", "params": { "thread": { … } } } +``` + ### Example: List threads (with pagination & filters) `thread/list` lets you render a history UI. Pass any combination of: diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 3b0e4a9db16..371cb5d42f7 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -28,6 +28,8 @@ use codex_app_server_protocol::ConversationSummary; use codex_app_server_protocol::ExecOneOffCommandResponse; use codex_app_server_protocol::FeedbackUploadParams; use codex_app_server_protocol::FeedbackUploadResponse; +use codex_app_server_protocol::ForkConversationParams; +use codex_app_server_protocol::ForkConversationResponse; use codex_app_server_protocol::FuzzyFileSearchParams; use codex_app_server_protocol::FuzzyFileSearchResponse; use codex_app_server_protocol::GetAccountParams; @@ -86,6 +88,8 @@ use codex_app_server_protocol::SkillsListResponse; use codex_app_server_protocol::Thread; use codex_app_server_protocol::ThreadArchiveParams; use codex_app_server_protocol::ThreadArchiveResponse; +use codex_app_server_protocol::ThreadForkParams; +use codex_app_server_protocol::ThreadForkResponse; use codex_app_server_protocol::ThreadItem; use codex_app_server_protocol::ThreadListParams; use codex_app_server_protocol::ThreadListResponse; @@ -124,6 +128,7 @@ use codex_core::config::ConfigService; use codex_core::config::edit::ConfigEditsBuilder; use codex_core::config::types::McpServerTransportConfig; use codex_core::default_client::get_codex_user_agent; +use codex_core::error::CodexErr; use codex_core::exec::ExecParams; use codex_core::exec_env::create_env; use codex_core::features::Feature; @@ -367,6 +372,9 @@ impl CodexMessageProcessor { ClientRequest::ThreadResume { request_id, params } => { self.thread_resume(request_id, params).await; } + ClientRequest::ThreadFork { request_id, params } => { + self.thread_fork(request_id, params).await; + } ClientRequest::ThreadArchive { request_id, params } => { self.thread_archive(request_id, params).await; } @@ -433,6 +441,9 @@ impl CodexMessageProcessor { ClientRequest::ResumeConversation { request_id, params } => { self.handle_resume_conversation(request_id, params).await; } + ClientRequest::ForkConversation { request_id, params } => { + self.handle_fork_conversation(request_id, params).await; + } ClientRequest::ArchiveConversation { request_id, params } => { self.archive_conversation(request_id, params).await; } @@ -1793,6 +1804,198 @@ impl CodexMessageProcessor { } } + async fn thread_fork(&mut self, request_id: RequestId, params: ThreadForkParams) { + let ThreadForkParams { + thread_id, + path, + model, + model_provider, + cwd, + approval_policy, + sandbox, + config: cli_overrides, + base_instructions, + developer_instructions, + } = params; + + let overrides_requested = model.is_some() + || model_provider.is_some() + || cwd.is_some() + || approval_policy.is_some() + || sandbox.is_some() + || cli_overrides.is_some() + || base_instructions.is_some() + || developer_instructions.is_some(); + + let config = if overrides_requested { + let overrides = self.build_thread_config_overrides( + model, + model_provider, + cwd, + approval_policy, + sandbox, + base_instructions, + developer_instructions, + ); + + // Persist windows sandbox feature. + let mut cli_overrides = cli_overrides.unwrap_or_default(); + if cfg!(windows) && self.config.features.enabled(Feature::WindowsSandbox) { + cli_overrides.insert( + "features.experimental_windows_sandbox".to_string(), + serde_json::json!(true), + ); + } + + match derive_config_from_params(&self.cli_overrides, Some(cli_overrides), overrides) + .await + { + Ok(config) => config, + Err(err) => { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("error deriving config: {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + } + } else { + self.config.as_ref().clone() + }; + + let rollout_path = if let Some(path) = path { + path + } else { + let existing_thread_id = match ThreadId::from_string(&thread_id) { + Ok(id) => id, + Err(err) => { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("invalid thread id: {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + match find_thread_path_by_id_str( + &self.config.codex_home, + &existing_thread_id.to_string(), + ) + .await + { + Ok(Some(p)) => p, + Ok(None) => { + self.send_invalid_request_error( + request_id, + format!("no rollout found for thread id {existing_thread_id}"), + ) + .await; + return; + } + Err(err) => { + self.send_invalid_request_error( + request_id, + format!("failed to locate thread id {existing_thread_id}: {err}"), + ) + .await; + return; + } + } + }; + + let fallback_model_provider = config.model_provider_id.clone(); + + let NewThread { + thread_id, + session_configured, + .. + } = match self + .thread_manager + .fork_thread(usize::MAX, config, rollout_path.clone()) + .await + { + Ok(thread) => thread, + Err(err) => { + let (code, message) = match err { + CodexErr::Io(_) | CodexErr::Json(_) => ( + INVALID_REQUEST_ERROR_CODE, + format!("failed to load rollout `{}`: {err}", rollout_path.display()), + ), + CodexErr::InvalidRequest(message) => (INVALID_REQUEST_ERROR_CODE, message), + _ => (INTERNAL_ERROR_CODE, format!("error forking thread: {err}")), + }; + let error = JSONRPCErrorError { + code, + message, + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + let SessionConfiguredEvent { + rollout_path, + initial_messages, + .. + } = session_configured; + // Auto-attach a conversation listener when forking a thread. + if let Err(err) = self + .attach_conversation_listener(thread_id, false, ApiVersion::V2) + .await + { + tracing::warn!( + "failed to attach listener for thread {}: {}", + thread_id, + err.message + ); + } + + let mut thread = match read_summary_from_rollout( + rollout_path.as_path(), + fallback_model_provider.as_str(), + ) + .await + { + Ok(summary) => summary_to_thread(summary), + Err(err) => { + self.send_internal_error( + request_id, + format!( + "failed to load rollout `{}` for thread {thread_id}: {err}", + rollout_path.display() + ), + ) + .await; + return; + } + }; + thread.turns = initial_messages + .as_deref() + .map_or_else(Vec::new, build_turns_from_event_msgs); + + let response = ThreadForkResponse { + thread: thread.clone(), + model: session_configured.model, + model_provider: session_configured.model_provider_id, + cwd: session_configured.cwd, + approval_policy: session_configured.approval_policy.into(), + sandbox: session_configured.sandbox_policy.into(), + reasoning_effort: session_configured.reasoning_effort, + }; + + self.outgoing.send_response(request_id, response).await; + + let notif = ThreadStartedNotification { thread }; + self.outgoing + .send_server_notification(ServerNotification::ThreadStarted(notif)) + .await; + } + async fn get_thread_summary( &self, request_id: RequestId, @@ -2416,6 +2619,166 @@ impl CodexMessageProcessor { } } + async fn handle_fork_conversation( + &self, + request_id: RequestId, + params: ForkConversationParams, + ) { + let ForkConversationParams { + path, + conversation_id, + overrides, + } = params; + + // Derive a Config using the same logic as new conversation, honoring overrides if provided. + let config = match overrides { + Some(overrides) => { + let NewConversationParams { + model, + model_provider, + profile, + cwd, + approval_policy, + sandbox: sandbox_mode, + config: cli_overrides, + base_instructions, + developer_instructions, + compact_prompt, + include_apply_patch_tool, + } = overrides; + + // Persist windows sandbox feature. + let mut cli_overrides = cli_overrides.unwrap_or_default(); + if cfg!(windows) && self.config.features.enabled(Feature::WindowsSandbox) { + cli_overrides.insert( + "features.experimental_windows_sandbox".to_string(), + serde_json::json!(true), + ); + } + + let overrides = ConfigOverrides { + model, + config_profile: profile, + cwd: cwd.map(PathBuf::from), + approval_policy, + sandbox_mode, + model_provider, + codex_linux_sandbox_exe: self.codex_linux_sandbox_exe.clone(), + base_instructions, + developer_instructions, + compact_prompt, + include_apply_patch_tool, + ..Default::default() + }; + + derive_config_from_params(&self.cli_overrides, Some(cli_overrides), overrides).await + } + None => Ok(self.config.as_ref().clone()), + }; + let config = match config { + Ok(cfg) => cfg, + Err(err) => { + self.send_invalid_request_error( + request_id, + format!("error deriving config: {err}"), + ) + .await; + return; + } + }; + + let rollout_path = if let Some(path) = path { + path + } else if let Some(conversation_id) = conversation_id { + match find_thread_path_by_id_str(&self.config.codex_home, &conversation_id.to_string()) + .await + { + Ok(Some(found_path)) => found_path, + Ok(None) => { + self.send_invalid_request_error( + request_id, + format!("no rollout found for conversation id {conversation_id}"), + ) + .await; + return; + } + Err(err) => { + self.send_invalid_request_error( + request_id, + format!("failed to locate conversation id {conversation_id}: {err}"), + ) + .await; + return; + } + } + } else { + self.send_invalid_request_error( + request_id, + "either path or conversation id must be provided".to_string(), + ) + .await; + return; + }; + + let NewThread { + thread_id, + session_configured, + .. + } = match self + .thread_manager + .fork_thread(usize::MAX, config, rollout_path.clone()) + .await + { + Ok(thread) => thread, + Err(err) => { + let (code, message) = match err { + CodexErr::Io(_) | CodexErr::Json(_) => ( + INVALID_REQUEST_ERROR_CODE, + format!("failed to load rollout `{}`: {err}", rollout_path.display()), + ), + CodexErr::InvalidRequest(message) => (INVALID_REQUEST_ERROR_CODE, message), + _ => ( + INTERNAL_ERROR_CODE, + format!("error forking conversation: {err}"), + ), + }; + let error = JSONRPCErrorError { + code, + message, + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + self.outgoing + .send_server_notification(ServerNotification::SessionConfigured( + SessionConfiguredNotification { + session_id: session_configured.session_id, + model: session_configured.model.clone(), + reasoning_effort: session_configured.reasoning_effort, + history_log_id: session_configured.history_log_id, + history_entry_count: session_configured.history_entry_count, + initial_messages: session_configured.initial_messages.clone(), + rollout_path: session_configured.rollout_path.clone(), + }, + )) + .await; + let initial_messages = session_configured + .initial_messages + .map(|msgs| msgs.into_iter().collect()); + + // Reply with conversation id + model and initial messages (when present) + let response = ForkConversationResponse { + conversation_id: thread_id, + model: session_configured.model.clone(), + initial_messages, + rollout_path: session_configured.rollout_path.clone(), + }; + self.outgoing.send_response(request_id, response).await; + } + async fn send_invalid_request_error(&self, request_id: RequestId, message: String) { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, diff --git a/codex-rs/app-server/tests/common/mcp_process.rs b/codex-rs/app-server/tests/common/mcp_process.rs index f3ec682fb21..aba5da20f50 100644 --- a/codex-rs/app-server/tests/common/mcp_process.rs +++ b/codex-rs/app-server/tests/common/mcp_process.rs @@ -21,6 +21,7 @@ use codex_app_server_protocol::ConfigBatchWriteParams; use codex_app_server_protocol::ConfigReadParams; use codex_app_server_protocol::ConfigValueWriteParams; use codex_app_server_protocol::FeedbackUploadParams; +use codex_app_server_protocol::ForkConversationParams; use codex_app_server_protocol::GetAccountParams; use codex_app_server_protocol::GetAuthStatusParams; use codex_app_server_protocol::InitializeParams; @@ -43,6 +44,7 @@ use codex_app_server_protocol::SendUserTurnParams; use codex_app_server_protocol::ServerRequest; use codex_app_server_protocol::SetDefaultModelParams; use codex_app_server_protocol::ThreadArchiveParams; +use codex_app_server_protocol::ThreadForkParams; use codex_app_server_protocol::ThreadListParams; use codex_app_server_protocol::ThreadResumeParams; use codex_app_server_protocol::ThreadRollbackParams; @@ -308,6 +310,15 @@ impl McpProcess { self.send_request("thread/resume", params).await } + /// Send a `thread/fork` JSON-RPC request. + pub async fn send_thread_fork_request( + &mut self, + params: ThreadForkParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("thread/fork", params).await + } + /// Send a `thread/archive` JSON-RPC request. pub async fn send_thread_archive_request( &mut self, @@ -353,6 +364,15 @@ impl McpProcess { self.send_request("resumeConversation", params).await } + /// Send a `forkConversation` JSON-RPC request. + pub async fn send_fork_conversation_request( + &mut self, + params: ForkConversationParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("forkConversation", params).await + } + /// Send a `loginApiKey` JSON-RPC request. pub async fn send_login_api_key_request( &mut self, diff --git a/codex-rs/app-server/tests/suite/fork_thread.rs b/codex-rs/app-server/tests/suite/fork_thread.rs new file mode 100644 index 00000000000..17548fe041e --- /dev/null +++ b/codex-rs/app-server/tests/suite/fork_thread.rs @@ -0,0 +1,140 @@ +use anyhow::Result; +use app_test_support::McpProcess; +use app_test_support::create_fake_rollout; +use app_test_support::to_response; +use codex_app_server_protocol::ForkConversationParams; +use codex_app_server_protocol::ForkConversationResponse; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::NewConversationParams; // reused for overrides shape +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ServerNotification; +use codex_app_server_protocol::SessionConfiguredNotification; +use codex_core::protocol::EventMsg; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn fork_conversation_creates_new_rollout() -> Result<()> { + let codex_home = TempDir::new()?; + + let preview = "Hello A"; + let conversation_id = create_fake_rollout( + codex_home.path(), + "2025-01-02T12-00-00", + "2025-01-02T12:00:00Z", + preview, + Some("openai"), + None, + )?; + + let original_path = codex_home + .path() + .join("sessions") + .join("2025") + .join("01") + .join("02") + .join(format!( + "rollout-2025-01-02T12-00-00-{conversation_id}.jsonl" + )); + assert!( + original_path.exists(), + "expected original rollout to exist at {}", + original_path.display() + ); + let original_contents = std::fs::read_to_string(&original_path)?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let fork_req_id = mcp + .send_fork_conversation_request(ForkConversationParams { + path: Some(original_path.clone()), + conversation_id: None, + overrides: Some(NewConversationParams { + model: Some("o3".to_string()), + ..Default::default() + }), + }) + .await?; + + // Expect a sessionConfigured notification for the forked session. + let notification: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("sessionConfigured"), + ) + .await??; + let session_configured: ServerNotification = notification.try_into()?; + let ServerNotification::SessionConfigured(SessionConfiguredNotification { + model, + session_id, + rollout_path, + initial_messages: session_initial_messages, + .. + }) = session_configured + else { + unreachable!("expected sessionConfigured notification"); + }; + + assert_eq!(model, "o3"); + assert_ne!( + session_id.to_string(), + conversation_id, + "expected a new conversation id when forking" + ); + assert_ne!( + rollout_path, original_path, + "expected a new rollout path when forking" + ); + assert!( + rollout_path.exists(), + "expected forked rollout to exist at {}", + rollout_path.display() + ); + + let session_initial_messages = + session_initial_messages.expect("expected initial messages when forking from rollout"); + match session_initial_messages.as_slice() { + [EventMsg::UserMessage(message)] => { + assert_eq!(message.message, preview); + } + other => panic!("unexpected initial messages from rollout fork: {other:#?}"), + } + + // Then the response for forkConversation. + let fork_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(fork_req_id)), + ) + .await??; + let ForkConversationResponse { + conversation_id: forked_id, + model: forked_model, + initial_messages: response_initial_messages, + rollout_path: response_rollout_path, + } = to_response::(fork_resp)?; + + assert_eq!(forked_model, "o3"); + assert_eq!(response_rollout_path, rollout_path); + assert_ne!(forked_id.to_string(), conversation_id); + + let response_initial_messages = + response_initial_messages.expect("expected initial messages in fork response"); + match response_initial_messages.as_slice() { + [EventMsg::UserMessage(message)] => { + assert_eq!(message.message, preview); + } + other => panic!("unexpected initial messages in fork response: {other:#?}"), + } + + let after_contents = std::fs::read_to_string(&original_path)?; + assert_eq!( + after_contents, original_contents, + "fork should not mutate the original rollout file" + ); + + Ok(()) +} diff --git a/codex-rs/app-server/tests/suite/mod.rs b/codex-rs/app-server/tests/suite/mod.rs index 41d6f83b957..ae7e0cb438d 100644 --- a/codex-rs/app-server/tests/suite/mod.rs +++ b/codex-rs/app-server/tests/suite/mod.rs @@ -3,6 +3,7 @@ mod auth; mod codex_message_processor_flow; mod config; mod create_thread; +mod fork_thread; mod fuzzy_file_search; mod interrupt; mod list_resume; diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index 1ef00c6939d..92267357909 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -5,6 +5,7 @@ mod output_schema; mod rate_limits; mod review; mod thread_archive; +mod thread_fork; mod thread_list; mod thread_resume; mod thread_rollback; diff --git a/codex-rs/app-server/tests/suite/v2/thread_fork.rs b/codex-rs/app-server/tests/suite/v2/thread_fork.rs new file mode 100644 index 00000000000..c6ee2878d6b --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/thread_fork.rs @@ -0,0 +1,140 @@ +use anyhow::Result; +use app_test_support::McpProcess; +use app_test_support::create_fake_rollout; +use app_test_support::create_mock_chat_completions_server; +use app_test_support::to_response; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::SessionSource; +use codex_app_server_protocol::ThreadForkParams; +use codex_app_server_protocol::ThreadForkResponse; +use codex_app_server_protocol::ThreadItem; +use codex_app_server_protocol::ThreadStartedNotification; +use codex_app_server_protocol::TurnStatus; +use codex_app_server_protocol::UserInput; +use pretty_assertions::assert_eq; +use std::path::Path; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test] +async fn thread_fork_creates_new_thread_and_emits_started() -> Result<()> { + let server = create_mock_chat_completions_server(vec![]).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + + let preview = "Saved user message"; + let conversation_id = create_fake_rollout( + codex_home.path(), + "2025-01-05T12-00-00", + "2025-01-05T12:00:00Z", + preview, + Some("mock_provider"), + None, + )?; + + let original_path = codex_home + .path() + .join("sessions") + .join("2025") + .join("01") + .join("05") + .join(format!( + "rollout-2025-01-05T12-00-00-{conversation_id}.jsonl" + )); + assert!( + original_path.exists(), + "expected original rollout to exist at {}", + original_path.display() + ); + let original_contents = std::fs::read_to_string(&original_path)?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let fork_id = mcp + .send_thread_fork_request(ThreadForkParams { + thread_id: conversation_id.clone(), + ..Default::default() + }) + .await?; + let fork_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(fork_id)), + ) + .await??; + let ThreadForkResponse { thread, .. } = to_response::(fork_resp)?; + + let after_contents = std::fs::read_to_string(&original_path)?; + assert_eq!( + after_contents, original_contents, + "fork should not mutate the original rollout file" + ); + + assert_ne!(thread.id, conversation_id); + assert_eq!(thread.preview, preview); + assert_eq!(thread.model_provider, "mock_provider"); + assert!(thread.path.is_absolute()); + assert_ne!(thread.path, original_path); + assert!(thread.cwd.is_absolute()); + assert_eq!(thread.source, SessionSource::VsCode); + + assert_eq!( + thread.turns.len(), + 1, + "expected forked thread to include one turn" + ); + let turn = &thread.turns[0]; + assert_eq!(turn.status, TurnStatus::Completed); + assert_eq!(turn.items.len(), 1, "expected user message item"); + match &turn.items[0] { + ThreadItem::UserMessage { content, .. } => { + assert_eq!( + content, + &vec![UserInput::Text { + text: preview.to_string() + }] + ); + } + other => panic!("expected user message item, got {other:?}"), + } + + // A corresponding thread/started notification should arrive. + let notif: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("thread/started"), + ) + .await??; + let started: ThreadStartedNotification = + serde_json::from_value(notif.params.expect("params must be present"))?; + assert_eq!(started.thread, thread); + + Ok(()) +} + +// Helper to create a config.toml pointing at the mock model server. +fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/core/src/rollout/truncation.rs b/codex-rs/core/src/rollout/truncation.rs index b8127f0345b..cd222403246 100644 --- a/codex-rs/core/src/rollout/truncation.rs +++ b/codex-rs/core/src/rollout/truncation.rs @@ -45,12 +45,17 @@ pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec Vec { + if n_from_start == usize::MAX { + return items.to_vec(); + } + let user_positions = user_message_positions_in_rollout(items); // If fewer than or equal to n user messages exist, treat as empty (out of range). @@ -139,6 +144,22 @@ mod tests { assert_matches!(truncated2.as_slice(), []); } + #[test] + fn truncation_max_keeps_full_rollout() { + let rollout = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::ResponseItem(user_msg("u2")), + ]; + + let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, usize::MAX); + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&rollout).unwrap() + ); + } + #[test] fn truncates_rollout_from_start_applies_thread_rollback_markers() { let rollout_items = vec![ diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index d82b242fe5c..09fa12cb932 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -179,7 +179,7 @@ impl ThreadManager { /// Fork an existing thread by taking messages up to the given position (not including /// the message at the given position) and starting a new thread with identical /// configuration (unless overridden by the caller's `config`). The new thread will have - /// a fresh id. + /// a fresh id. Pass `usize::MAX` to keep the full rollout history. pub async fn fork_thread( &self, nth_user_message: usize,