diff --git a/CHANGELOG.md b/CHANGELOG.md index 43da4107a5..2a8fd2e044 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ This project adheres to [Semantic Versioning](https://semver.org/). ([#2014](https://github.com/nix-rust/nix/pull/2014)) - Added `SO_TS_CLOCK` for FreeBSD to `nix::sys::socket::sockopt`. ([#2093](https://github.com/nix-rust/nix/pull/2093)) +- Added support for prctl in Linux. + (#[1550](https://github.com/nix-rust/nix/pull/1550)) + ### Changed diff --git a/Cargo.toml b/Cargo.toml index 010c9f895b..9cc9027fd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,3 +99,7 @@ path = "test/test_clearenv.rs" name = "test-mount" path = "test/test_mount.rs" harness = false + +[[test]] +name = "test-prctl" +path = "test/sys/test_prctl.rs" diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 9bba1c4e9e..bf047b3dda 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -67,6 +67,12 @@ feature! { pub mod personality; } +#[cfg(target_os = "linux")] +feature! { + #![feature = "process"] + pub mod prctl; +} + feature! { #![feature = "pthread"] pub mod pthread; diff --git a/src/sys/prctl.rs b/src/sys/prctl.rs new file mode 100644 index 0000000000..995382cb0c --- /dev/null +++ b/src/sys/prctl.rs @@ -0,0 +1,208 @@ +//! prctl is a Linux-only API for performing operations on a process or thread. +//! +//! Note that careless use of some prctl() operations can confuse the user-space run-time +//! environment, so these operations should be used with care. +//! +//! For more documentation, please read [prctl(2)](https://man7.org/linux/man-pages/man2/prctl.2.html). + +use crate::errno::Errno; +use crate::sys::signal::Signal; +use crate::Result; + +use libc::{c_int, c_ulong}; +use std::convert::TryFrom; +use std::ffi::{CStr, CString}; + +libc_enum! { + /// The type of hardware memory corruption kill policy for the thread. + + #[repr(i32)] + #[non_exhaustive] + #[allow(non_camel_case_types)] + pub enum PrctlMCEKillPolicy { + /// The thread will receive SIGBUS as soon as a memory corruption is detected. + PR_MCE_KILL_EARLY, + /// The process is killed only when it accesses a corrupted page. + PR_MCE_KILL_LATE, + /// Uses the system-wide default. + PR_MCE_KILL_DEFAULT, + } + impl TryFrom +} + +fn prctl_set_bool(option: c_int, status: bool) -> Result<()> { + let res = unsafe { libc::prctl(option, status as c_ulong, 0, 0, 0) }; + Errno::result(res).map(drop) +} + +fn prctl_get_bool(option: c_int) -> Result { + let res = unsafe { libc::prctl(option, 0, 0, 0, 0) }; + Errno::result(res).map(|res| res != 0) +} + +/// Set the "child subreaper" attribute for this process +pub fn set_child_subreaper(attribute: bool) -> Result<()> { + prctl_set_bool(libc::PR_SET_CHILD_SUBREAPER, attribute) +} + +/// Get the "child subreaper" attribute for this process +pub fn get_child_subreaper() -> Result { + // prctl writes into this var + let mut subreaper: c_int = 0; + + let res = unsafe { libc::prctl(libc::PR_GET_CHILD_SUBREAPER, &mut subreaper, 0, 0, 0) }; + + Errno::result(res).map(|_| subreaper != 0) +} + +/// Set the dumpable attribute which determines if core dumps are created for this process. +pub fn set_dumpable(attribute: bool) -> Result<()> { + prctl_set_bool(libc::PR_SET_DUMPABLE, attribute) +} + +/// Get the dumpable attribute for this process. +pub fn get_dumpable() -> Result { + prctl_get_bool(libc::PR_GET_DUMPABLE) +} + +/// Set the "keep capabilities" attribute for this process. This causes the thread to retain +/// capabilities even if it switches its UID to a nonzero value. +pub fn set_keepcaps(attribute: bool) -> Result<()> { + prctl_set_bool(libc::PR_SET_KEEPCAPS, attribute) +} + +/// Get the "keep capabilities" attribute for this process +pub fn get_keepcaps() -> Result { + prctl_get_bool(libc::PR_GET_KEEPCAPS) +} + +/// Clear the thread memory corruption kill policy and use the system-wide default +pub fn clear_mce_kill() -> Result<()> { + let res = unsafe { libc::prctl(libc::PR_MCE_KILL, libc::PR_MCE_KILL_CLEAR, 0, 0, 0) }; + + Errno::result(res).map(drop) +} + +/// Set the thread memory corruption kill policy +pub fn set_mce_kill(policy: PrctlMCEKillPolicy) -> Result<()> { + let res = unsafe { + libc::prctl( + libc::PR_MCE_KILL, + libc::PR_MCE_KILL_SET, + policy as c_ulong, + 0, + 0, + ) + }; + + Errno::result(res).map(drop) +} + +/// Get the thread memory corruption kill policy +pub fn get_mce_kill() -> Result { + let res = unsafe { libc::prctl(libc::PR_MCE_KILL_GET, 0, 0, 0, 0) }; + + match Errno::result(res) { + Ok(val) => Ok(PrctlMCEKillPolicy::try_from(val)?), + Err(e) => Err(e), + } +} + +/// Set the parent-death signal of the calling process. This is the signal that the calling process +/// will get when its parent dies. +pub fn set_pdeathsig>>(signal: T) -> Result<()> { + let sig = match signal.into() { + Some(s) => s as c_int, + None => 0, + }; + + let res = unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, sig, 0, 0, 0) }; + + Errno::result(res).map(drop) +} + +/// Returns the current parent-death signal +pub fn get_pdeathsig() -> Result> { + // prctl writes into this var + let mut sig: c_int = 0; + + let res = unsafe { libc::prctl(libc::PR_GET_PDEATHSIG, &mut sig, 0, 0, 0) }; + + match Errno::result(res) { + Ok(_) => Ok(match sig { + 0 => None, + _ => Some(Signal::try_from(sig)?), + }), + Err(e) => Err(e), + } +} + +/// Set the name of the calling thread. Strings longer than 15 bytes will be truncated. +pub fn set_name(name: &CStr) -> Result<()> { + let res = unsafe { libc::prctl(libc::PR_SET_NAME, name.as_ptr(), 0, 0, 0) }; + + Errno::result(res).map(drop) +} + +/// Return the name of the calling thread +pub fn get_name() -> Result { + // Size of buffer determined by linux/sched.h TASK_COMM_LEN + let buf = [0u8; 16]; + + let res = unsafe { libc::prctl(libc::PR_GET_NAME, &buf, 0, 0, 0) }; + + let len = buf.iter().position(|&c| c == 0).unwrap_or(buf.len()); + let name = CStr::from_bytes_with_nul(&buf[..=len]).map_err(|_| Errno::EINVAL)?; + + Errno::result(res).map(|_| name.to_owned()) +} + +/// Sets the timer slack value for the calling thread. Timer slack is used by the kernel to group +/// timer expirations and make them the supplied amount of nanoseconds late. +pub fn set_timerslack(ns: u64) -> Result<()> { + let res = unsafe { libc::prctl(libc::PR_SET_TIMERSLACK, ns, 0, 0, 0) }; + + Errno::result(res).map(drop) +} + +/// Get the timerslack for the calling thread. +pub fn get_timerslack() -> Result { + let res = unsafe { libc::prctl(libc::PR_GET_TIMERSLACK, 0, 0, 0, 0) }; + + Errno::result(res) +} + +/// Disable all performance counters attached to the calling process. +pub fn task_perf_events_disable() -> Result<()> { + let res = unsafe { libc::prctl(libc::PR_TASK_PERF_EVENTS_DISABLE, 0, 0, 0, 0) }; + + Errno::result(res).map(drop) +} + +/// Enable all performance counters attached to the calling process. +pub fn task_perf_events_enable() -> Result<()> { + let res = unsafe { libc::prctl(libc::PR_TASK_PERF_EVENTS_ENABLE, 0, 0, 0, 0) }; + + Errno::result(res).map(drop) +} + +/// Set the calling threads "no new privs" attribute. Once set this option can not be unset. +pub fn set_no_new_privs() -> Result<()> { + prctl_set_bool(libc::PR_SET_NO_NEW_PRIVS, true) // Cannot be unset +} + +/// Get the "no new privs" attribute for the calling thread. +pub fn get_no_new_privs() -> Result { + prctl_get_bool(libc::PR_GET_NO_NEW_PRIVS) +} + +/// Set the state of the "THP disable" flag for the calling thread. Setting this disables +/// transparent huge pages. +pub fn set_thp_disable(flag: bool) -> Result<()> { + prctl_set_bool(libc::PR_SET_THP_DISABLE, flag) +} + +/// Get the "THP disable" flag for the calling thread. +pub fn get_thp_disable() -> Result { + prctl_get_bool(libc::PR_GET_THP_DISABLE) +} diff --git a/test/sys/test_prctl.rs b/test/sys/test_prctl.rs new file mode 100644 index 0000000000..351213b7ef --- /dev/null +++ b/test/sys/test_prctl.rs @@ -0,0 +1,125 @@ +#[cfg(target_os = "linux")] +#[cfg(feature = "process")] +mod test_prctl { + use std::ffi::CStr; + + use nix::sys::prctl; + + #[cfg_attr(qemu, ignore)] + #[test] + fn test_get_set_subreaper() { + let original = prctl::get_child_subreaper().unwrap(); + + prctl::set_child_subreaper(true).unwrap(); + let subreaper = prctl::get_child_subreaper().unwrap(); + assert!(subreaper); + + prctl::set_child_subreaper(original).unwrap(); + } + + #[test] + fn test_get_set_dumpable() { + let original = prctl::get_dumpable().unwrap(); + + prctl::set_dumpable(false).unwrap(); + let dumpable = prctl::get_dumpable().unwrap(); + assert!(!dumpable); + + prctl::set_dumpable(original).unwrap(); + } + + #[test] + fn test_get_set_keepcaps() { + let original = prctl::get_keepcaps().unwrap(); + + prctl::set_keepcaps(true).unwrap(); + let keepcaps = prctl::get_keepcaps().unwrap(); + assert!(keepcaps); + + prctl::set_keepcaps(original).unwrap(); + } + + #[test] + fn test_get_set_clear_mce_kill() { + use prctl::PrctlMCEKillPolicy::*; + + prctl::set_mce_kill(PR_MCE_KILL_LATE).unwrap(); + let mce = prctl::get_mce_kill().unwrap(); + assert_eq!(mce, PR_MCE_KILL_LATE); + + prctl::clear_mce_kill().unwrap(); + let mce = prctl::get_mce_kill().unwrap(); + assert_eq!(mce, PR_MCE_KILL_DEFAULT); + } + + #[cfg_attr(qemu, ignore)] + #[test] + fn test_get_set_pdeathsig() { + use nix::sys::signal::Signal; + + let original = prctl::get_pdeathsig().unwrap(); + + prctl::set_pdeathsig(Signal::SIGUSR1).unwrap(); + let sig = prctl::get_pdeathsig().unwrap(); + assert_eq!(sig, Some(Signal::SIGUSR1)); + + prctl::set_pdeathsig(original).unwrap(); + } + + #[test] + fn test_get_set_name() { + let original = prctl::get_name().unwrap(); + + let long_name = + CStr::from_bytes_with_nul(b"0123456789abcdefghijklmn\0").unwrap(); + prctl::set_name(long_name).unwrap(); + let res = prctl::get_name().unwrap(); + + // name truncated by kernel to TASK_COMM_LEN + assert_eq!(&long_name.to_str().unwrap()[..15], res.to_str().unwrap()); + + let short_name = CStr::from_bytes_with_nul(b"01234567\0").unwrap(); + prctl::set_name(short_name).unwrap(); + let res = prctl::get_name().unwrap(); + assert_eq!(short_name.to_str().unwrap(), res.to_str().unwrap()); + + prctl::set_name(&original).unwrap(); + } + + #[cfg_attr(qemu, ignore)] + #[test] + fn test_get_set_timerslack() { + let original = prctl::get_timerslack().unwrap(); + + let slack = 60_000; + prctl::set_timerslack(slack).unwrap(); + let res = prctl::get_timerslack().unwrap(); + assert_eq!(slack, res as u64); + + prctl::set_timerslack(original as u64).unwrap(); + } + + #[test] + fn test_disable_enable_perf_events() { + prctl::task_perf_events_disable().unwrap(); + prctl::task_perf_events_enable().unwrap(); + } + + #[test] + fn test_get_set_no_new_privs() { + prctl::set_no_new_privs().unwrap(); + let no_new_privs = prctl::get_no_new_privs().unwrap(); + assert!(no_new_privs); + } + + #[test] + fn test_get_set_thp_disable() { + let original = prctl::get_thp_disable().unwrap(); + + prctl::set_thp_disable(true).unwrap(); + let thp_disable = prctl::get_thp_disable().unwrap(); + assert!(thp_disable); + + prctl::set_thp_disable(original).unwrap(); + } +}