Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Socket::recv to accept [MaybeUninit<u8>] #161

Merged
merged 1 commit into from
Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::fmt;
use std::io::{self, Read, Write};
#[cfg(not(target_os = "redox"))]
use std::io::{IoSlice, IoSliceMut};
use std::mem::MaybeUninit;
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
#[cfg(unix)]
use std::os::unix::io::{FromRawFd, IntoRawFd};
Expand Down Expand Up @@ -284,7 +285,20 @@ impl Socket {
/// This method might fail if the socket is not connected.
///
/// [`connect`]: Socket::connect
pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
///
/// # Safety
///
/// Normally casting a `&mut [u8]` to `&mut [MaybeUninit<u8>]` would be
/// unsound, as that allows us to write uninitialised bytes to the buffer.
/// However this implementation promises to not write uninitialised bytes to
/// the `buf`fer and passes it directly to `recv(2)` system call. This
/// promise ensures that this function can be called using a `buf`fer of
/// type `&mut [u8]`.
///
/// Note that the [`io::Read::read`] implementation calls this function with
/// a `buf`fer of type `&mut [u8]`, allowing initialised buffers to be used
/// without using `unsafe`.
pub fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.recv_with_flags(buf, 0)
}

Expand All @@ -295,15 +309,19 @@ impl Socket {
///
/// [`recv`]: Socket::recv
/// [`out_of_band_inline`]: Socket::out_of_band_inline
pub fn recv_out_of_band(&self, buf: &mut [u8]) -> io::Result<usize> {
pub fn recv_out_of_band(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.recv_with_flags(buf, sys::MSG_OOB)
}

/// Identical to [`recv`] but allows for specification of arbitrary flags to
/// the underlying `recv` call.
///
/// [`recv`]: Socket::recv
pub fn recv_with_flags(&self, buf: &mut [u8], flags: sys::c_int) -> io::Result<usize> {
pub fn recv_with_flags(
&self,
buf: &mut [MaybeUninit<u8>],
flags: sys::c_int,
) -> io::Result<usize> {
sys::recv(self.inner, buf, flags)
}

Expand Down Expand Up @@ -343,13 +361,27 @@ impl Socket {
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
///
/// # Safety
///
/// `peek` makes the same safety guarantees regarding the `buf`fer as
/// [`recv`].
///
/// [`recv`]: Socket::recv
pub fn peek(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.recv_with_flags(buf, sys::MSG_PEEK)
}

/// Receives data from the socket. On success, returns the number of bytes
/// read and the address from whence the data came.
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
///
/// # Safety
///
/// `recv_from` makes the same safety guarantees regarding the `buf`fer as
/// [`recv`].
///
/// [`recv`]: Socket::recv
pub fn recv_from(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<(usize, SockAddr)> {
self.recv_from_with_flags(buf, 0)
}

Expand All @@ -359,7 +391,7 @@ impl Socket {
/// [`recv_from`]: Socket::recv_from
pub fn recv_from_with_flags(
&self,
buf: &mut [u8],
buf: &mut [MaybeUninit<u8>],
flags: i32,
) -> io::Result<(usize, SockAddr)> {
sys::recv_from(self.inner, buf, flags)
Expand Down Expand Up @@ -398,7 +430,14 @@ impl Socket {
///
/// On success, returns the number of bytes peeked and the address from
/// whence the data came.
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
///
/// # Safety
///
/// `peek_from` makes the same safety guarantees regarding the `buf`fer as
/// [`recv`].
///
/// [`recv`]: Socket::recv
pub fn peek_from(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<(usize, SockAddr)> {
self.recv_from_with_flags(buf, sys::MSG_PEEK)
}

Expand Down Expand Up @@ -1241,6 +1280,9 @@ impl Socket {

impl Read for Socket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
// Safety: the `recv` implementation promises not to write uninitialised
// bytes to the `buf`fer, so this casting is safe.
let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit<u8>]) };
self.recv(buf)
}

Expand All @@ -1252,6 +1294,8 @@ impl Read for Socket {

impl<'a> Read for &'a Socket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
// Safety: see other `Read::read` impl.
let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit<u8>]) };
self.recv(buf)
}

Expand Down
8 changes: 6 additions & 2 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ pub(crate) fn shutdown(fd: Socket, how: Shutdown) -> io::Result<()> {
syscall!(shutdown(fd, how)).map(|_| ())
}

pub(crate) fn recv(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
pub(crate) fn recv(fd: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
syscall!(recv(
fd,
buf.as_mut_ptr().cast(),
Expand All @@ -426,7 +426,11 @@ pub(crate) fn recv(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result<usize
.map(|n| n as usize)
}

pub(crate) fn recv_from(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result<(usize, SockAddr)> {
pub(crate) fn recv_from(
fd: Socket,
buf: &mut [MaybeUninit<u8>],
flags: c_int,
) -> io::Result<(usize, SockAddr)> {
// Safety: `recvfrom` initialises the `SockAddr` for us.
unsafe {
SockAddr::init(|addr, addrlen| {
Expand Down
4 changes: 2 additions & 2 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
syscall!(shutdown(socket, how), PartialEq::eq, sock::SOCKET_ERROR).map(|_| ())
}

pub(crate) fn recv(socket: Socket, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
let res = syscall!(
recv(
socket,
Expand Down Expand Up @@ -325,7 +325,7 @@ pub(crate) fn recv_vectored(

pub(crate) fn recv_from(
socket: Socket,
buf: &mut [u8],
buf: &mut [MaybeUninit<u8>],
flags: c_int,
) -> io::Result<(usize, SockAddr)> {
// Safety: `recvfrom` initialises the `SockAddr` for us.
Expand Down
13 changes: 10 additions & 3 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::io::Read;
use std::io::Write;
#[cfg(not(target_os = "redox"))]
use std::io::{IoSlice, IoSliceMut};
use std::mem::MaybeUninit;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
Expand Down Expand Up @@ -359,14 +360,14 @@ fn out_of_band() {
// this from happening we'll sleep to ensure the data is present.
thread::sleep(Duration::from_millis(10));

let mut buf = [1; DATA.len() + 1];
let mut buf = [MaybeUninit::new(1); DATA.len() + 1];
let n = receiver.recv_out_of_band(&mut buf).unwrap();
assert_eq!(n, FIRST.len());
assert_eq!(&buf[..n], FIRST);
assert_eq!(unsafe { assume_init(&buf[..n]) }, FIRST);

let n = receiver.recv(&mut buf).unwrap();
assert_eq!(n, DATA.len());
assert_eq!(&buf[..n], DATA);
assert_eq!(unsafe { assume_init(&buf[..n]) }, DATA);
}

#[test]
Expand Down Expand Up @@ -643,6 +644,12 @@ fn any_ipv4() -> SockAddr {
SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0).into()
}

/// Assume the `buf`fer to be initialised.
// TODO: replace with `MaybeUninit::slice_assume_init_ref` once stable.
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
&*(buf as *const [MaybeUninit<u8>] as *const [u8])
}

/// Macro to create a simple test to set and get a socket option.
macro_rules! test {
// Test using the `arg`ument as expected return value.
Expand Down