Skip to content

Commit

Permalink
Add uninit buffer ancillary APIs
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Saveau <saveau.alexandre@gmail.com>
  • Loading branch information
SUPERCILEX committed Aug 15, 2024
1 parent 51a88c0 commit 70b188a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 39 deletions.
55 changes: 42 additions & 13 deletions src/net/send_recv/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::net::UCred;

use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem::{align_of, size_of, size_of_val, take};
use core::mem::{align_of, size_of, size_of_val, take, MaybeUninit};
#[cfg(linux_kernel)]
use core::ptr::addr_of;
use core::{ptr, slice};
Expand All @@ -24,25 +24,31 @@ use super::{RecvFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketAddrV6};
///
/// Allocate a buffer for a single file descriptor:
/// ```
/// # use core::mem::MaybeUninit;
/// # use rustix::cmsg_space;
/// let mut space = [0; rustix::cmsg_space!(ScmRights(1))];
/// let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(1))];
/// # let _: &[MaybeUninit<u8>] = space.as_slice();
/// ```
///
/// Allocate a buffer for credentials:
/// ```
/// # #[cfg(linux_kernel)]
/// # {
/// # use core::mem::MaybeUninit;
/// # use rustix::cmsg_space;
/// let mut space = [0; rustix::cmsg_space!(ScmCredentials(1))];
/// let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
/// # let _: &[MaybeUninit<u8>] = space.as_slice();
/// # }
/// ```
///
/// Allocate a buffer for two file descriptors and credentials:
/// ```
/// # #[cfg(linux_kernel)]
/// # {
/// # use core::mem::MaybeUninit;
/// # use rustix::cmsg_space;
/// let mut space = [0; rustix::cmsg_space!(ScmRights(2), ScmCredentials(1))];
/// let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(2), ScmCredentials(1))];
/// # let _: &[MaybeUninit<u8>] = space.as_slice();
/// # }
/// ```
#[macro_export]
Expand Down Expand Up @@ -164,7 +170,7 @@ pub enum RecvAncillaryMessage<'a> {
/// [`push`]: SendAncillaryBuffer::push
pub struct SendAncillaryBuffer<'buf, 'slice, 'fd> {
/// Raw byte buffer for messages.
buffer: &'buf mut [u8],
buffer: &'buf mut [MaybeUninit<u8>],

/// The amount of the buffer that is used.
length: usize,
Expand All @@ -179,6 +185,12 @@ impl<'buf> From<&'buf mut [u8]> for SendAncillaryBuffer<'buf, '_, '_> {
}
}

impl<'buf> From<&'buf mut [MaybeUninit<u8>]> for SendAncillaryBuffer<'buf, '_, '_> {
fn from(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self::new_(buffer)
}
}

impl Default for SendAncillaryBuffer<'_, '_, '_> {
fn default() -> Self {
Self {
Expand Down Expand Up @@ -231,6 +243,10 @@ impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
/// [`send`]: crate::net::send
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
Self::new_(unsafe { core::mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buffer) })
}

fn new_(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self {
buffer: align_for_cmsghdr(buffer),
length: 0,
Expand All @@ -248,7 +264,7 @@ impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
return core::ptr::null_mut();
}

self.buffer.as_mut_ptr()
self.buffer.as_mut_ptr().cast()
}

/// Returns the length of the message data.
Expand Down Expand Up @@ -301,7 +317,7 @@ impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
let buffer = leap!(self.buffer.get_mut(..new_length));

// Fill the new part of the buffer with zeroes.
buffer[self.length..new_length].fill(0);
buffer[self.length..new_length].fill(MaybeUninit::new(0));
self.length = new_length;

// Get the last header in the buffer.
Expand Down Expand Up @@ -339,7 +355,7 @@ impl<'slice, 'fd> Extend<SendAncillaryMessage<'slice, 'fd>>
#[derive(Default)]
pub struct RecvAncillaryBuffer<'buf> {
/// Raw byte buffer for messages.
buffer: &'buf mut [u8],
buffer: &'buf mut [MaybeUninit<u8>],

/// The portion of the buffer we've read from already.
read: usize,
Expand All @@ -354,6 +370,12 @@ impl<'buf> From<&'buf mut [u8]> for RecvAncillaryBuffer<'buf> {
}
}

impl<'buf> From<&'buf mut [MaybeUninit<u8>]> for RecvAncillaryBuffer<'buf> {
fn from(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self::new_(buffer)
}
}

impl<'buf> RecvAncillaryBuffer<'buf> {
/// Create a new, empty `RecvAncillaryBuffer` from a raw byte buffer.
///
Expand Down Expand Up @@ -396,6 +418,10 @@ impl<'buf> RecvAncillaryBuffer<'buf> {
/// [`recv`]: crate::net::recv
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
Self::new_(unsafe { core::mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buffer) })
}

fn new_(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self {
buffer: align_for_cmsghdr(buffer),
read: 0,
Expand All @@ -413,7 +439,7 @@ impl<'buf> RecvAncillaryBuffer<'buf> {
return core::ptr::null_mut();
}

self.buffer.as_mut_ptr()
self.buffer.as_mut_ptr().cast()
}

/// Returns the length of the message data.
Expand Down Expand Up @@ -454,7 +480,7 @@ impl Drop for RecvAncillaryBuffer<'_> {
/// Return a slice of `buffer` starting at the first `cmsghdr` alignment
/// boundary.
#[inline]
fn align_for_cmsghdr(buffer: &mut [u8]) -> &mut [u8] {
fn align_for_cmsghdr(buffer: &mut [MaybeUninit<u8>]) -> &mut [MaybeUninit<u8>] {
// If the buffer is empty, we won't be writing anything into it, so it
// doesn't need to be aligned.
if buffer.is_empty() {
Expand Down Expand Up @@ -486,7 +512,9 @@ impl<'buf> AncillaryDrain<'buf> {
/// The buffer must contain valid message data (or be empty).
pub unsafe fn parse(buffer: &'buf mut [u8]) -> Self {
Self {
messages: messages::Messages::new(buffer),
messages: messages::Messages::new(unsafe {
core::mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buffer)
}),
read_and_length: None,
}
}
Expand Down Expand Up @@ -920,6 +948,7 @@ mod messages {
use crate::backend::net::msghdr;
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::ptr::NonNull;

/// An iterator over the messages in an ancillary buffer.
Expand All @@ -933,12 +962,12 @@ mod messages {
header: Option<NonNull<c::cmsghdr>>,

/// Capture the original lifetime of the buffer.
_buffer: PhantomData<&'buf mut [u8]>,
_buffer: PhantomData<&'buf mut [MaybeUninit<u8>]>,
}

impl<'buf> Messages<'buf> {
/// Create a new iterator over messages from a byte buffer.
pub(super) fn new(buf: &'buf mut [u8]) -> Self {
pub(super) fn new(buf: &'buf mut [MaybeUninit<u8>]) -> Self {
let msghdr = {
let mut h = msghdr::zero_msghdr();
h.msg_control = buf.as_mut_ptr().cast();
Expand Down
30 changes: 16 additions & 14 deletions tests/net/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rustix::net::{
accept, bind_unix, connect_unix, listen, socket, AddressFamily, SocketAddrUnix, SocketType,
};
use rustix::path::DecInt;
use std::mem::MaybeUninit;
use std::path::Path;
use std::str::FromStr;
use std::sync::{Arc, Condvar, Mutex};
Expand Down Expand Up @@ -447,13 +448,13 @@ fn test_unix_msg_with_scm_rights() {
let mut pipe_end = None;

let mut buffer = [0; BUFFER_SIZE];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(1))];

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -555,8 +556,8 @@ fn test_unix_msg_with_scm_rights() {
// Format the CMSG.
let we = [write_end.as_fd()];
let msg = SendAncillaryMessage::ScmRights(&we);
let mut space = [0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::from(space.as_mut_slice());
assert!(cmsg_buffer.push(msg));

connect_unix(&data_socket, &addr).unwrap();
Expand Down Expand Up @@ -615,8 +616,8 @@ fn test_unix_peercred_explicit() {

let ucred = sockopt::get_socket_peercred(&send_sock).unwrap();
let msg = SendAncillaryMessage::ScmCredentials(ucred);
let mut space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = SendAncillaryBuffer::from(space.as_mut_slice());
assert!(cmsg_buffer.push(msg));

sendmsg(
Expand All @@ -627,8 +628,8 @@ fn test_unix_peercred_explicit() {
)
.unwrap();

let mut cmsg_space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());

let mut buffer = [0; BUFFER_SIZE];
recvmsg(
Expand Down Expand Up @@ -685,8 +686,8 @@ fn test_unix_peercred_implicit() {
)
.unwrap();

let mut cmsg_space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());

let mut buffer = [0; BUFFER_SIZE];
recvmsg(
Expand Down Expand Up @@ -737,13 +738,14 @@ fn test_unix_msg_with_combo() {
let mut yet_another_pipe_end = None;

let mut buffer = [0; BUFFER_SIZE];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(2), ScmRights(1))];
let mut cmsg_space =
[MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(2), ScmRights(1))];

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -859,8 +861,8 @@ fn test_unix_msg_with_combo() {

let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap();

let mut space = [0; rustix::cmsg_space!(ScmRights(2), ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(2), ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::from(space.as_mut_slice());

// Format a CMSG.
let we = [write_end.as_fd(), another_write_end.as_fd()];
Expand Down
27 changes: 15 additions & 12 deletions tests/net/unix_alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,13 +445,14 @@ fn test_unix_msg_with_scm_rights() {
let mut pipe_end = None;

let mut buffer = vec![0; BUFFER_SIZE];
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_space = Vec::with_capacity(rustix::cmsg_space!(ScmRights(1)));

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer =
RecvAncillaryBuffer::from(cmsg_space.spare_capacity_mut());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -553,8 +554,8 @@ fn test_unix_msg_with_scm_rights() {
// Format the CMSG.
let we = [write_end.as_fd()];
let msg = SendAncillaryMessage::ScmRights(&we);
let mut space = vec![0; msg.size()];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = Vec::with_capacity(msg.size());
let mut cmsg_buffer = SendAncillaryBuffer::from(space.spare_capacity_mut());
assert!(cmsg_buffer.push(msg));

connect_unix(&data_socket, &addr).unwrap();
Expand Down Expand Up @@ -618,8 +619,8 @@ fn test_unix_peercred() {
assert_eq!(ucred.gid, getgid());

let msg = SendAncillaryMessage::ScmCredentials(ucred);
let mut space = vec![0; msg.size()];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = Vec::with_capacity(msg.size());
let mut cmsg_buffer = SendAncillaryBuffer::from(space.spare_capacity_mut());
assert!(cmsg_buffer.push(msg));

sendmsg(
Expand All @@ -630,8 +631,8 @@ fn test_unix_peercred() {
)
.unwrap();

let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_space = Vec::with_capacity(rustix::cmsg_space!(ScmCredentials(1)));
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.spare_capacity_mut());

let mut buffer = vec![0; BUFFER_SIZE];
recvmsg(
Expand Down Expand Up @@ -682,13 +683,15 @@ fn test_unix_msg_with_combo() {
let mut yet_another_pipe_end = None;

let mut buffer = vec![0; BUFFER_SIZE];
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1), ScmRights(2))];
let mut cmsg_space =
Vec::with_capacity(rustix::cmsg_space!(ScmRights(1), ScmRights(2)));

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer =
RecvAncillaryBuffer::from(cmsg_space.spare_capacity_mut());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -804,8 +807,8 @@ fn test_unix_msg_with_combo() {

let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap();

let mut space = vec![0; rustix::cmsg_space!(ScmRights(1), ScmRights(2))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = Vec::with_capacity(rustix::cmsg_space!(ScmRights(1), ScmRights(2)));
let mut cmsg_buffer = SendAncillaryBuffer::from(space.spare_capacity_mut());

// Format a CMSG.
let we = [write_end.as_fd(), another_write_end.as_fd()];
Expand Down

0 comments on commit 70b188a

Please sign in to comment.