Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the Linux prctl syscall #1550

Merged
merged 1 commit into from
Aug 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ This project adheres to [Semantic Versioning](https://semver.org/).
([#2014](https://github.com/nix-rust/nix/pull/2014))
- Added `SO_TS_CLOCK` for FreeBSD to `nix::sys::socket::sockopt`.
([#2093](https://github.com/nix-rust/nix/pull/2093))
- Added support for prctl in Linux.
(#[1550](https://github.com/nix-rust/nix/pull/1550))


### Changed

Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,7 @@ path = "test/test_clearenv.rs"
name = "test-mount"
path = "test/test_mount.rs"
harness = false

[[test]]
name = "test-prctl"
path = "test/sys/test_prctl.rs"
6 changes: 6 additions & 0 deletions src/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ feature! {
pub mod personality;
}

#[cfg(target_os = "linux")]
feature! {
#![feature = "process"]
pub mod prctl;
}

feature! {
#![feature = "pthread"]
pub mod pthread;
Expand Down
208 changes: 208 additions & 0 deletions src/sys/prctl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
//! prctl is a Linux-only API for performing operations on a process or thread.
//!
//! Note that careless use of some prctl() operations can confuse the user-space run-time
//! environment, so these operations should be used with care.
//!
//! For more documentation, please read [prctl(2)](https://man7.org/linux/man-pages/man2/prctl.2.html).

use crate::errno::Errno;
use crate::sys::signal::Signal;
use crate::Result;

use libc::{c_int, c_ulong};
use std::convert::TryFrom;
use std::ffi::{CStr, CString};

libc_enum! {
/// The type of hardware memory corruption kill policy for the thread.

#[repr(i32)]
#[non_exhaustive]
#[allow(non_camel_case_types)]
pub enum PrctlMCEKillPolicy {
/// The thread will receive SIGBUS as soon as a memory corruption is detected.
PR_MCE_KILL_EARLY,
/// The process is killed only when it accesses a corrupted page.
PR_MCE_KILL_LATE,
/// Uses the system-wide default.
PR_MCE_KILL_DEFAULT,
}
impl TryFrom<i32>
}

fn prctl_set_bool(option: c_int, status: bool) -> Result<()> {
let res = unsafe { libc::prctl(option, status as c_ulong, 0, 0, 0) };
Errno::result(res).map(drop)
}

fn prctl_get_bool(option: c_int) -> Result<bool> {
let res = unsafe { libc::prctl(option, 0, 0, 0, 0) };
Errno::result(res).map(|res| res != 0)
}

/// Set the "child subreaper" attribute for this process
pub fn set_child_subreaper(attribute: bool) -> Result<()> {
prctl_set_bool(libc::PR_SET_CHILD_SUBREAPER, attribute)
}

/// Get the "child subreaper" attribute for this process
pub fn get_child_subreaper() -> Result<bool> {
// prctl writes into this var
let mut subreaper: c_int = 0;

let res = unsafe { libc::prctl(libc::PR_GET_CHILD_SUBREAPER, &mut subreaper, 0, 0, 0) };

Errno::result(res).map(|_| subreaper != 0)
}

/// Set the dumpable attribute which determines if core dumps are created for this process.
pub fn set_dumpable(attribute: bool) -> Result<()> {
prctl_set_bool(libc::PR_SET_DUMPABLE, attribute)
}

/// Get the dumpable attribute for this process.
pub fn get_dumpable() -> Result<bool> {
prctl_get_bool(libc::PR_GET_DUMPABLE)
}

/// Set the "keep capabilities" attribute for this process. This causes the thread to retain
/// capabilities even if it switches its UID to a nonzero value.
pub fn set_keepcaps(attribute: bool) -> Result<()> {
prctl_set_bool(libc::PR_SET_KEEPCAPS, attribute)
}

/// Get the "keep capabilities" attribute for this process
pub fn get_keepcaps() -> Result<bool> {
prctl_get_bool(libc::PR_GET_KEEPCAPS)
}

/// Clear the thread memory corruption kill policy and use the system-wide default
pub fn clear_mce_kill() -> Result<()> {
let res = unsafe { libc::prctl(libc::PR_MCE_KILL, libc::PR_MCE_KILL_CLEAR, 0, 0, 0) };

Errno::result(res).map(drop)
}

/// Set the thread memory corruption kill policy
pub fn set_mce_kill(policy: PrctlMCEKillPolicy) -> Result<()> {
let res = unsafe {
libc::prctl(
libc::PR_MCE_KILL,
libc::PR_MCE_KILL_SET,
policy as c_ulong,
0,
0,
)
};

Errno::result(res).map(drop)
}

/// Get the thread memory corruption kill policy
pub fn get_mce_kill() -> Result<PrctlMCEKillPolicy> {
let res = unsafe { libc::prctl(libc::PR_MCE_KILL_GET, 0, 0, 0, 0) };

match Errno::result(res) {
Ok(val) => Ok(PrctlMCEKillPolicy::try_from(val)?),
Err(e) => Err(e),
}
}

/// Set the parent-death signal of the calling process. This is the signal that the calling process
/// will get when its parent dies.
pub fn set_pdeathsig<T: Into<Option<Signal>>>(signal: T) -> Result<()> {
let sig = match signal.into() {
Some(s) => s as c_int,
None => 0,
};

let res = unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, sig, 0, 0, 0) };

Errno::result(res).map(drop)
}

/// Returns the current parent-death signal
pub fn get_pdeathsig() -> Result<Option<Signal>> {
// prctl writes into this var
let mut sig: c_int = 0;

let res = unsafe { libc::prctl(libc::PR_GET_PDEATHSIG, &mut sig, 0, 0, 0) };

match Errno::result(res) {
Ok(_) => Ok(match sig {
0 => None,
_ => Some(Signal::try_from(sig)?),
}),
Err(e) => Err(e),
}
}

/// Set the name of the calling thread. Strings longer than 15 bytes will be truncated.
pub fn set_name(name: &CStr) -> Result<()> {
let res = unsafe { libc::prctl(libc::PR_SET_NAME, name.as_ptr(), 0, 0, 0) };

Errno::result(res).map(drop)
}

/// Return the name of the calling thread
pub fn get_name() -> Result<CString> {
// Size of buffer determined by linux/sched.h TASK_COMM_LEN
let buf = [0u8; 16];

let res = unsafe { libc::prctl(libc::PR_GET_NAME, &buf, 0, 0, 0) };

let len = buf.iter().position(|&c| c == 0).unwrap_or(buf.len());
let name = CStr::from_bytes_with_nul(&buf[..=len]).map_err(|_| Errno::EINVAL)?;
Comment on lines +154 to +155
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally you could do this instead. But that would require raising the MSRV to 1.69.0, so probably we shouldn't do it.

Suggested change
let len = buf.iter().position(|&c| c == 0).unwrap_or(buf.len());
let name = CStr::from_bytes_with_nul(&buf[..=len]).map_err(|_| Errno::EINVAL)?;
let name = CStr::from_bytes_until_nul(&buf[..]);


Errno::result(res).map(|_| name.to_owned())
}

/// Sets the timer slack value for the calling thread. Timer slack is used by the kernel to group
/// timer expirations and make them the supplied amount of nanoseconds late.
pub fn set_timerslack(ns: u64) -> Result<()> {
let res = unsafe { libc::prctl(libc::PR_SET_TIMERSLACK, ns, 0, 0, 0) };

Errno::result(res).map(drop)
}

/// Get the timerslack for the calling thread.
pub fn get_timerslack() -> Result<i32> {
let res = unsafe { libc::prctl(libc::PR_GET_TIMERSLACK, 0, 0, 0, 0) };

Errno::result(res)
}

/// Disable all performance counters attached to the calling process.
pub fn task_perf_events_disable() -> Result<()> {
let res = unsafe { libc::prctl(libc::PR_TASK_PERF_EVENTS_DISABLE, 0, 0, 0, 0) };

Errno::result(res).map(drop)
}

/// Enable all performance counters attached to the calling process.
pub fn task_perf_events_enable() -> Result<()> {
let res = unsafe { libc::prctl(libc::PR_TASK_PERF_EVENTS_ENABLE, 0, 0, 0, 0) };

Errno::result(res).map(drop)
}

/// Set the calling threads "no new privs" attribute. Once set this option can not be unset.
pub fn set_no_new_privs() -> Result<()> {
prctl_set_bool(libc::PR_SET_NO_NEW_PRIVS, true) // Cannot be unset
}

/// Get the "no new privs" attribute for the calling thread.
pub fn get_no_new_privs() -> Result<bool> {
prctl_get_bool(libc::PR_GET_NO_NEW_PRIVS)
}

/// Set the state of the "THP disable" flag for the calling thread. Setting this disables
/// transparent huge pages.
pub fn set_thp_disable(flag: bool) -> Result<()> {
prctl_set_bool(libc::PR_SET_THP_DISABLE, flag)
}

/// Get the "THP disable" flag for the calling thread.
pub fn get_thp_disable() -> Result<bool> {
prctl_get_bool(libc::PR_GET_THP_DISABLE)
}
125 changes: 125 additions & 0 deletions test/sys/test_prctl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#[cfg(target_os = "linux")]
#[cfg(feature = "process")]
mod test_prctl {
use std::ffi::CStr;

use nix::sys::prctl;

#[cfg_attr(qemu, ignore)]
#[test]
fn test_get_set_subreaper() {
let original = prctl::get_child_subreaper().unwrap();

prctl::set_child_subreaper(true).unwrap();
let subreaper = prctl::get_child_subreaper().unwrap();
assert!(subreaper);

prctl::set_child_subreaper(original).unwrap();
}

#[test]
fn test_get_set_dumpable() {
let original = prctl::get_dumpable().unwrap();

prctl::set_dumpable(false).unwrap();
let dumpable = prctl::get_dumpable().unwrap();
assert!(!dumpable);

prctl::set_dumpable(original).unwrap();
}

#[test]
fn test_get_set_keepcaps() {
let original = prctl::get_keepcaps().unwrap();

prctl::set_keepcaps(true).unwrap();
let keepcaps = prctl::get_keepcaps().unwrap();
assert!(keepcaps);

prctl::set_keepcaps(original).unwrap();
}

#[test]
fn test_get_set_clear_mce_kill() {
use prctl::PrctlMCEKillPolicy::*;

prctl::set_mce_kill(PR_MCE_KILL_LATE).unwrap();
let mce = prctl::get_mce_kill().unwrap();
assert_eq!(mce, PR_MCE_KILL_LATE);

prctl::clear_mce_kill().unwrap();
let mce = prctl::get_mce_kill().unwrap();
assert_eq!(mce, PR_MCE_KILL_DEFAULT);
}

#[cfg_attr(qemu, ignore)]
#[test]
fn test_get_set_pdeathsig() {
use nix::sys::signal::Signal;

let original = prctl::get_pdeathsig().unwrap();

prctl::set_pdeathsig(Signal::SIGUSR1).unwrap();
let sig = prctl::get_pdeathsig().unwrap();
assert_eq!(sig, Some(Signal::SIGUSR1));

prctl::set_pdeathsig(original).unwrap();
}

#[test]
fn test_get_set_name() {
let original = prctl::get_name().unwrap();

let long_name =
CStr::from_bytes_with_nul(b"0123456789abcdefghijklmn\0").unwrap();
prctl::set_name(long_name).unwrap();
let res = prctl::get_name().unwrap();

// name truncated by kernel to TASK_COMM_LEN
assert_eq!(&long_name.to_str().unwrap()[..15], res.to_str().unwrap());

let short_name = CStr::from_bytes_with_nul(b"01234567\0").unwrap();
prctl::set_name(short_name).unwrap();
let res = prctl::get_name().unwrap();
assert_eq!(short_name.to_str().unwrap(), res.to_str().unwrap());

prctl::set_name(&original).unwrap();
}

#[cfg_attr(qemu, ignore)]
#[test]
fn test_get_set_timerslack() {
let original = prctl::get_timerslack().unwrap();

let slack = 60_000;
prctl::set_timerslack(slack).unwrap();
let res = prctl::get_timerslack().unwrap();
assert_eq!(slack, res as u64);

prctl::set_timerslack(original as u64).unwrap();
}

#[test]
fn test_disable_enable_perf_events() {
prctl::task_perf_events_disable().unwrap();
prctl::task_perf_events_enable().unwrap();
}

#[test]
fn test_get_set_no_new_privs() {
prctl::set_no_new_privs().unwrap();
let no_new_privs = prctl::get_no_new_privs().unwrap();
assert!(no_new_privs);
}

#[test]
fn test_get_set_thp_disable() {
let original = prctl::get_thp_disable().unwrap();

prctl::set_thp_disable(true).unwrap();
let thp_disable = prctl::get_thp_disable().unwrap();
assert!(thp_disable);

prctl::set_thp_disable(original).unwrap();
}
}
DOhlsson marked this conversation as resolved.
Show resolved Hide resolved