diff --git a/codex-rs/core/src/environment_context.rs b/codex-rs/core/src/environment_context.rs index e7b2e19ffb..4339511daf 100644 --- a/codex-rs/core/src/environment_context.rs +++ b/codex-rs/core/src/environment_context.rs @@ -20,6 +20,14 @@ pub enum NetworkAccess { Restricted, Enabled, } + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct OperatingSystemInfo { + pub name: String, + pub version: String, + pub is_likely_windows_subsystem_for_linux: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename = "environment_context", rename_all = "snake_case")] pub(crate) struct EnvironmentContext { @@ -29,6 +37,7 @@ pub(crate) struct EnvironmentContext { pub network_access: Option, pub writable_roots: Option>, pub shell: Option, + pub operating_system: Option, } impl EnvironmentContext { @@ -70,6 +79,7 @@ impl EnvironmentContext { _ => None, }, shell, + operating_system: Self::operating_system_info(), } } @@ -83,6 +93,7 @@ impl EnvironmentContext { sandbox_mode, network_access, writable_roots, + operating_system, // should compare all fields except shell shell: _, } = other; @@ -92,6 +103,7 @@ impl EnvironmentContext { && self.sandbox_mode == *sandbox_mode && self.network_access == *network_access && self.writable_roots == *writable_roots + && self.operating_system == *operating_system } pub fn diff(before: &TurnContext, after: &TurnContext) -> Self { @@ -110,7 +122,12 @@ impl EnvironmentContext { } else { None }; - EnvironmentContext::new(cwd, approval_policy, sandbox_policy, None) + // Diff messages should only include fields that changed between turns. + // Operating system is a static property of the host and should not be + // emitted as part of a per-turn diff. + let mut ec = EnvironmentContext::new(cwd, approval_policy, sandbox_policy, None); + ec.operating_system = None; + ec } } @@ -141,10 +158,11 @@ impl EnvironmentContext { /// ... /// /// ``` - pub fn serialize_to_xml(self) -> String { + pub fn serialize_to_xml(&self) -> String { let mut lines = vec![ENVIRONMENT_CONTEXT_OPEN_TAG.to_string()]; - if let Some(cwd) = self.cwd { - lines.push(format!(" {}", cwd.to_string_lossy())); + if let Some(cwd) = self.cwd.as_ref() { + let cwd = cwd.to_string_lossy(); + lines.push(format!(" {cwd}")); } if let Some(approval_policy) = self.approval_policy { lines.push(format!( @@ -154,29 +172,44 @@ impl EnvironmentContext { if let Some(sandbox_mode) = self.sandbox_mode { lines.push(format!(" {sandbox_mode}")); } - if let Some(network_access) = self.network_access { + if let Some(network_access) = self.network_access.as_ref() { lines.push(format!( " {network_access}" )); } - if let Some(writable_roots) = self.writable_roots { + if let Some(writable_roots) = self.writable_roots.as_ref() { lines.push(" ".to_string()); for writable_root in writable_roots { - lines.push(format!( - " {}", - writable_root.to_string_lossy() - )); + let writable_root = writable_root.to_string_lossy(); + lines.push(format!(" {writable_root}")); } lines.push(" ".to_string()); } - if let Some(shell) = self.shell + if let Some(shell) = self.shell.as_ref() && let Some(shell_name) = shell.name() { lines.push(format!(" {shell_name}")); } + if let Some(operating_system) = self.operating_system.as_ref() { + lines.push(" ".to_string()); + let name = operating_system.name.as_str(); + lines.push(format!(" {name}")); + let version = operating_system.version.as_str(); + lines.push(format!(" {version}")); + if let Some(is_wsl) = operating_system.is_likely_windows_subsystem_for_linux { + lines.push(format!( + " {is_wsl}" + )); + } + lines.push(" ".to_string()); + } lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string()); lines.join("\n") } + + fn operating_system_info() -> Option { + operating_system_info_impl() + } } impl From for ResponseItem { @@ -191,6 +224,47 @@ impl From for ResponseItem { } } +// Restrict Operating System Info to Windows and Linux inside WSL for now +#[cfg(target_os = "windows")] +fn operating_system_info_impl() -> Option { + let info = os_info::get(); + Some(OperatingSystemInfo { + name: info.os_type().to_string(), + version: info.version().to_string(), + is_likely_windows_subsystem_for_linux: Some(has_wsl_env_markers()), + }) +} + +#[cfg(all(unix, not(target_os = "macos")))] +fn operating_system_info_impl() -> Option { + let info = os_info::get(); + match has_wsl_env_markers() { + true => Some(OperatingSystemInfo { + name: info.os_type().to_string(), + version: info.version().to_string(), + is_likely_windows_subsystem_for_linux: Some(true), + }), + false => None, + } +} + +#[cfg(target_os = "macos")] +fn operating_system_info_impl() -> Option { + None +} + +#[cfg(not(target_os = "macos"))] +fn has_wsl_env_markers() -> bool { + // Cache detection result since env vars are stable across process lifetime + // and this function may be called multiple times. + static CACHE: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHE.get_or_init(|| { + std::env::var_os("WSL_INTEROP").is_some() + || std::env::var_os("WSLENV").is_some() + || std::env::var_os("WSL_DISTRO_NAME").is_some() + }) +} + #[cfg(test)] mod tests { use crate::shell::BashShell; @@ -198,6 +272,58 @@ mod tests { use super::*; use pretty_assertions::assert_eq; + fn expected_environment_context(mut body_lines: Vec) -> String { + let mut lines = vec!["".to_string()]; + lines.append(&mut body_lines); + if let Some(os) = EnvironmentContext::operating_system_info() { + lines.push(" ".to_string()); + lines.push(format!(" {}", os.name)); + lines.push(format!(" {}", os.version)); + if let Some(is_wsl) = os.is_likely_windows_subsystem_for_linux { + lines.push(format!( + " {is_wsl}" + )); + } + lines.push(" ".to_string()); + } + lines.push("".to_string()); + lines.join("\n") + } + + #[cfg(target_os = "windows")] + #[test] + fn operating_system_info_on_windows_includes_os_details() { + let info = operating_system_info_impl().expect("expected Windows operating system info"); + let os_details = os_info::get(); + + assert_eq!(info.name, os_details.os_type().to_string()); + assert_eq!(info.version, os_details.version().to_string()); + assert_eq!( + info.is_likely_windows_subsystem_for_linux, + Some(has_wsl_env_markers()) + ); + } + + #[cfg(all(unix, not(target_os = "macos")))] + #[test] + fn operating_system_info_matches_wsl_detection_on_unix() { + let info = operating_system_info_impl(); + let os_details = os_info::get(); + if has_wsl_env_markers() { + let info = info.expect("expected WSL operating system info"); + assert_eq!(info.name, os_details.os_type().to_string()); + assert_eq!(info.version, os_details.version().to_string()); + assert_eq!(info.is_likely_windows_subsystem_for_linux, Some(true)); + } else { + assert_eq!(info, None); + } + } + + #[cfg(target_os = "macos")] + #[test] + fn operating_system_info_is_none_on_macos() { + assert_eq!(operating_system_info_impl(), None); + } fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy { SandboxPolicy::WorkspaceWrite { @@ -217,16 +343,16 @@ mod tests { None, ); - let expected = r#" - /repo - on-request - workspace-write - restricted - - /repo - /tmp - -"#; + let expected = expected_environment_context(vec![ + " /repo".to_string(), + " on-request".to_string(), + " workspace-write".to_string(), + " restricted".to_string(), + " ".to_string(), + " /repo".to_string(), + " /tmp".to_string(), + " ".to_string(), + ]); assert_eq!(context.serialize_to_xml(), expected); } @@ -240,11 +366,11 @@ mod tests { None, ); - let expected = r#" - never - read-only - restricted -"#; + let expected = expected_environment_context(vec![ + " never".to_string(), + " read-only".to_string(), + " restricted".to_string(), + ]); assert_eq!(context.serialize_to_xml(), expected); } @@ -258,11 +384,11 @@ mod tests { None, ); - let expected = r#" - on-failure - danger-full-access - enabled -"#; + let expected = expected_environment_context(vec![ + " on-failure".to_string(), + " danger-full-access".to_string(), + " enabled".to_string(), + ]); assert_eq!(context.serialize_to_xml(), expected); } diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs index fee57784d7..df3b92aafc 100644 --- a/codex-rs/core/tests/suite/prompt_caching.rs +++ b/codex-rs/core/tests/suite/prompt_caching.rs @@ -36,19 +36,53 @@ fn text_user_input(text: String) -> serde_json::Value { }) } +#[allow(dead_code)] +fn has_wsl_env_markers() -> bool { + std::env::var_os("WSL_INTEROP").is_some() + || std::env::var_os("WSLENV").is_some() + || std::env::var_os("WSL_DISTRO_NAME").is_some() +} + +fn operating_system_context_block() -> String { + #[cfg(target_os = "windows")] + { + let info = os_info::get(); + let name = info.os_type().to_string(); + let version = info.version().to_string(); + let is_wsl = has_wsl_env_markers(); + format!( + " \n {name}\n {version}\n {is_wsl}\n \n" + ) + } + + #[cfg(all(unix, not(target_os = "macos")))] + { + if has_wsl_env_markers() { + " \n {name}\n \n true\n \n".to_string() + } else { + String::new() + } + } + + #[cfg(target_os = "macos")] + { + String::new() + } +} + fn default_env_context_str(cwd: &str, shell: &Shell) -> String { + let shell_line = match shell.name() { + Some(name) => format!(" {name}\n"), + None => String::new(), + }; + let os_block = operating_system_context_block(); format!( r#" - {} + {cwd} on-request read-only restricted -{}"#, - cwd, - match shell.name() { - Some(name) => format!(" {name}\n"), - None => String::new(), - } +{shell_line}{os_block}"# ) } @@ -341,22 +375,10 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests let shell = default_user_shell().await; - let expected_env_text = format!( - r#" - {} - on-request - read-only - restricted -{}"#, - cwd.path().to_string_lossy(), - match shell.name() { - Some(name) => format!(" {name}\n"), - None => String::new(), - } - ); + let cwd_str = cwd.path().to_string_lossy().into_owned(); + let expected_env_text = default_env_context_str(&cwd_str, &shell); let expected_ui_text = format!( - "# AGENTS.md instructions for {}\n\n\nbe consistent and helpful\n", - cwd.path().to_string_lossy() + "# AGENTS.md instructions for {cwd_str}\n\n\nbe consistent and helpful\n" ); let expected_env_msg = serde_json::json!({