From 4606ec3b4652abf505a91fc8385c94ffb15a39b1 Mon Sep 17 00:00:00 2001 From: Thomas de Zeeuw Date: Mon, 21 Dec 2020 15:20:36 +0100 Subject: [PATCH] Change Socket::recv to accept [MaybeUninit] Allow uninitialised buffers to be used. Also in the following functions: * Socket::recv_out_of_band * Socket::recv_with_flags * Socket::peek * Socket::recv_from * Socket::recv_from_with_flags * Socket::peek_from --- src/socket.rs | 58 ++++++++++++++++++++++++++++++++++++++++------ src/sys/unix.rs | 8 +++++-- src/sys/windows.rs | 4 ++-- tests/socket.rs | 13 ++++++++--- 4 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/socket.rs b/src/socket.rs index c025cfa8..f8afa29d 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -10,6 +10,7 @@ use std::fmt; use std::io::{self, Read, Write}; #[cfg(not(target_os = "redox"))] use std::io::{IoSlice, IoSliceMut}; +use std::mem::MaybeUninit; use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown}; #[cfg(unix)] use std::os::unix::io::{FromRawFd, IntoRawFd}; @@ -284,7 +285,20 @@ impl Socket { /// This method might fail if the socket is not connected. /// /// [`connect`]: Socket::connect - pub fn recv(&self, buf: &mut [u8]) -> io::Result { + /// + /// # Safety + /// + /// Normally casting a `&mut [u8]` to `&mut [MaybeUninit]` would be + /// unsound, as that allows us to write uninitialised bytes to the buffer. + /// However this implementation promises to not write uninitialised bytes to + /// the `buf`fer and passes it directly to `recv(2)` system call. This + /// promise ensures that this function can be called using a `buf`fer of + /// type `&mut [u8]`. + /// + /// Note that the [`io::Read::read`] implementation calls this function with + /// a `buf`fer of type `&mut [u8]`, allowing initialised buffers to be used + /// without using `unsafe`. + pub fn recv(&self, buf: &mut [MaybeUninit]) -> io::Result { self.recv_with_flags(buf, 0) } @@ -295,7 +309,7 @@ impl Socket { /// /// [`recv`]: Socket::recv /// [`out_of_band_inline`]: Socket::out_of_band_inline - pub fn recv_out_of_band(&self, buf: &mut [u8]) -> io::Result { + pub fn recv_out_of_band(&self, buf: &mut [MaybeUninit]) -> io::Result { self.recv_with_flags(buf, sys::MSG_OOB) } @@ -303,7 +317,11 @@ impl Socket { /// the underlying `recv` call. /// /// [`recv`]: Socket::recv - pub fn recv_with_flags(&self, buf: &mut [u8], flags: sys::c_int) -> io::Result { + pub fn recv_with_flags( + &self, + buf: &mut [MaybeUninit], + flags: sys::c_int, + ) -> io::Result { sys::recv(self.inner, buf, flags) } @@ -343,13 +361,27 @@ impl Socket { /// /// Successive calls return the same data. This is accomplished by passing /// `MSG_PEEK` as a flag to the underlying `recv` system call. - pub fn peek(&self, buf: &mut [u8]) -> io::Result { + /// + /// # Safety + /// + /// `peek` makes the same safety guarantees regarding the `buf`fer as + /// [`recv`]. + /// + /// [`recv`]: Socket::recv + pub fn peek(&self, buf: &mut [MaybeUninit]) -> io::Result { self.recv_with_flags(buf, sys::MSG_PEEK) } /// Receives data from the socket. On success, returns the number of bytes /// read and the address from whence the data came. - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { + /// + /// # Safety + /// + /// `recv_from` makes the same safety guarantees regarding the `buf`fer as + /// [`recv`]. + /// + /// [`recv`]: Socket::recv + pub fn recv_from(&self, buf: &mut [MaybeUninit]) -> io::Result<(usize, SockAddr)> { self.recv_from_with_flags(buf, 0) } @@ -359,7 +391,7 @@ impl Socket { /// [`recv_from`]: Socket::recv_from pub fn recv_from_with_flags( &self, - buf: &mut [u8], + buf: &mut [MaybeUninit], flags: i32, ) -> io::Result<(usize, SockAddr)> { sys::recv_from(self.inner, buf, flags) @@ -398,7 +430,14 @@ impl Socket { /// /// On success, returns the number of bytes peeked and the address from /// whence the data came. - pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { + /// + /// # Safety + /// + /// `peek_from` makes the same safety guarantees regarding the `buf`fer as + /// [`recv`]. + /// + /// [`recv`]: Socket::recv + pub fn peek_from(&self, buf: &mut [MaybeUninit]) -> io::Result<(usize, SockAddr)> { self.recv_from_with_flags(buf, sys::MSG_PEEK) } @@ -1241,6 +1280,9 @@ impl Socket { impl Read for Socket { fn read(&mut self, buf: &mut [u8]) -> io::Result { + // Safety: the `recv` implementation promises not to write uninitialised + // bytes to the `buf`fer, so this casting is safe. + let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit]) }; self.recv(buf) } @@ -1252,6 +1294,8 @@ impl Read for Socket { impl<'a> Read for &'a Socket { fn read(&mut self, buf: &mut [u8]) -> io::Result { + // Safety: see other `Read::read` impl. + let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit]) }; self.recv(buf) } diff --git a/src/sys/unix.rs b/src/sys/unix.rs index 84699b78..99270780 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -416,7 +416,7 @@ pub(crate) fn shutdown(fd: Socket, how: Shutdown) -> io::Result<()> { syscall!(shutdown(fd, how)).map(|_| ()) } -pub(crate) fn recv(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result { +pub(crate) fn recv(fd: Socket, buf: &mut [MaybeUninit], flags: c_int) -> io::Result { syscall!(recv( fd, buf.as_mut_ptr().cast(), @@ -426,7 +426,11 @@ pub(crate) fn recv(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result io::Result<(usize, SockAddr)> { +pub(crate) fn recv_from( + fd: Socket, + buf: &mut [MaybeUninit], + flags: c_int, +) -> io::Result<(usize, SockAddr)> { // Safety: `recvfrom` initialises the `SockAddr` for us. unsafe { SockAddr::init(|addr, addrlen| { diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 3720eefc..e595b58e 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -273,7 +273,7 @@ pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> { syscall!(shutdown(socket, how), PartialEq::eq, sock::SOCKET_ERROR).map(|_| ()) } -pub(crate) fn recv(socket: Socket, buf: &mut [u8], flags: c_int) -> io::Result { +pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit], flags: c_int) -> io::Result { let res = syscall!( recv( socket, @@ -325,7 +325,7 @@ pub(crate) fn recv_vectored( pub(crate) fn recv_from( socket: Socket, - buf: &mut [u8], + buf: &mut [MaybeUninit], flags: c_int, ) -> io::Result<(usize, SockAddr)> { // Safety: `recvfrom` initialises the `SockAddr` for us. diff --git a/tests/socket.rs b/tests/socket.rs index 5200fc42..2b44b60e 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -7,6 +7,7 @@ use std::io::Read; use std::io::Write; #[cfg(not(target_os = "redox"))] use std::io::{IoSlice, IoSliceMut}; +use std::mem::MaybeUninit; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; #[cfg(unix)] use std::os::unix::io::AsRawFd; @@ -359,14 +360,14 @@ fn out_of_band() { // this from happening we'll sleep to ensure the data is present. thread::sleep(Duration::from_millis(10)); - let mut buf = [1; DATA.len() + 1]; + let mut buf = [MaybeUninit::new(1); DATA.len() + 1]; let n = receiver.recv_out_of_band(&mut buf).unwrap(); assert_eq!(n, FIRST.len()); - assert_eq!(&buf[..n], FIRST); + assert_eq!(unsafe { assume_init(&buf[..n]) }, FIRST); let n = receiver.recv(&mut buf).unwrap(); assert_eq!(n, DATA.len()); - assert_eq!(&buf[..n], DATA); + assert_eq!(unsafe { assume_init(&buf[..n]) }, DATA); } #[test] @@ -643,6 +644,12 @@ fn any_ipv4() -> SockAddr { SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0).into() } +/// Assume the `buf`fer to be initialised. +// TODO: replace with `MaybeUninit::slice_assume_init_ref` once stable. +unsafe fn assume_init(buf: &[MaybeUninit]) -> &[u8] { + &*(buf as *const [MaybeUninit] as *const [u8]) +} + /// Macro to create a simple test to set and get a socket option. macro_rules! test { // Test using the `arg`ument as expected return value.