Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions codex-rs/core/src/shell_snapshot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io::ErrorKind;
use std::path::Path;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
Expand Down Expand Up @@ -181,6 +182,7 @@ async fn run_script_with_timeout(
// returns a ref of handler.
let mut handler = Command::new(&args[0]);
handler.args(&args[1..]);
handler.stdin(Stdio::null());
#[cfg(unix)]
unsafe {
handler.pre_exec(|| {
Expand Down Expand Up @@ -473,6 +475,62 @@ mod tests {

use tempfile::tempdir;

#[cfg(unix)]
struct BlockingStdinPipe {
original: i32,
write_end: i32,
}

#[cfg(unix)]
impl BlockingStdinPipe {
fn install() -> Result<Self> {
let mut fds = [0i32; 2];
if unsafe { libc::pipe(fds.as_mut_ptr()) } == -1 {
return Err(std::io::Error::last_os_error()).context("create stdin pipe");
}

let original = unsafe { libc::dup(libc::STDIN_FILENO) };
if original == -1 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(fds[0]);
libc::close(fds[1]);
}
return Err(err).context("dup stdin");
}

if unsafe { libc::dup2(fds[0], libc::STDIN_FILENO) } == -1 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(fds[0]);
libc::close(fds[1]);
libc::close(original);
}
return Err(err).context("replace stdin");
}

unsafe {
libc::close(fds[0]);
}

Ok(Self {
original,
write_end: fds[1],
})
}
}

#[cfg(unix)]
impl Drop for BlockingStdinPipe {
fn drop(&mut self) {
unsafe {
libc::dup2(self.original, libc::STDIN_FILENO);
libc::close(self.original);
libc::close(self.write_end);
}
}
}

#[cfg(not(target_os = "windows"))]
fn assert_posix_snapshot_sections(snapshot: &str) {
assert!(snapshot.contains("# Snapshot file"));
Expand Down Expand Up @@ -553,6 +611,38 @@ mod tests {
Ok(())
}

#[cfg(unix)]
#[tokio::test]
async fn snapshot_shell_does_not_inherit_stdin() -> Result<()> {
let _stdin_guard = BlockingStdinPipe::install()?;

let dir = tempdir()?;
let home = dir.path();
fs::write(home.join(".bashrc"), "read -r ignored\n").await?;

let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};

let home_display = home.display();
let script = format!(
"HOME=\"{home_display}\"; export HOME; {}",
bash_snapshot_script()
);
let output = run_script_with_timeout(&shell, &script, Duration::from_millis(500), true)
.await
.context("run snapshot command")?;

assert!(
output.contains("# Snapshot file"),
"expected snapshot marker in output; output={output:?}"
);

Ok(())
}

#[cfg(target_os = "linux")]
#[tokio::test]
async fn timed_out_snapshot_shell_is_terminated() -> Result<()> {
Expand Down
Loading