diff --git a/CHANGELOG.md b/CHANGELOG.md index 297c0cd..7f6c09d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +main: + * Do not add a trailing null byte to abstract namespace socket paths on Linux and Android in `bind()` and `connect()`. + * Do not strip a traling null byte from abstract namespace socket paths on Linux and Android in `local_addr()`. + v0.7.1 - 2023-11-15: * Ensure proper alginment of control message buffer in the writer. * Fix compilation on Illumos and Solaris platforms. diff --git a/src/sys.rs b/src/sys.rs index ce06ded..5a236d1 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -261,7 +261,13 @@ fn path_to_sockaddr(path: &Path) -> std::io::Result<(libc::sockaddr_un, usize)> core::ptr::copy_nonoverlapping(path.as_ptr(), sockaddr.sun_path.as_mut_ptr() as *mut u8, path.len()); sockaddr.sun_path[path.len()] = 0; let path_offset = sockaddr.sun_path.as_ptr() as usize - (&sockaddr as *const _ as usize); - Ok((sockaddr, path_offset + path.len() + 1)) + + // Do not add trailing zero byte to abstract UNIX socket paths on Linux and Android. + if cfg!(any(target_os = "linux", target_os = "android")) && path.first() == Some(&0) { + Ok((sockaddr, path_offset + path.len())) + } else { + Ok((sockaddr, path_offset + path.len() + 1)) + } } } @@ -284,13 +290,13 @@ fn sockaddr_to_path(address: &libc::sockaddr_un, len: libc::socklen_t) -> std::i let offset = sun_path.offset_from(address as *const _ as *const u8); let path = core::slice::from_raw_parts(sun_path, len as usize - offset as usize); - // Some platforms include a trailing null byte in the path length. - let path = if path.last() == Some(&0) { - &path[..path.len() - 1] + // Do not strip trailing zero byte from abstract UNIX socket paths on Linux and Android. + if cfg!(any(target_os = "linux", target_os = "android")) && path.first() == Some(&0) { + Ok(Path::new(OsStr::from_bytes(path))) } else { - path - }; - Ok(Path::new(OsStr::from_bytes(path))) + let path = path.strip_suffix(&[0]).unwrap_or(path); + Ok(Path::new(OsStr::from_bytes(path))) + } } } } diff --git a/tests/abstract_namespace.rs b/tests/abstract_namespace.rs new file mode 100644 index 0000000..df70b77 --- /dev/null +++ b/tests/abstract_namespace.rs @@ -0,0 +1,82 @@ +#![cfg(any(target_os = "linux", target_os = "android"))] + +use std::path::PathBuf; + +use assert2::{assert, let_assert}; +use tokio_seqpacket::{UnixSeqpacket, UnixSeqpacketListener}; + +#[track_caller] +fn random_abstract_name(suffix: &str) -> PathBuf { + use std::io::Read; + use std::ffi::OsString; + use std::os::unix::ffi::OsStringExt; + + let_assert!(Ok(mut urandom) = std::fs::File::open("/dev/urandom")); + let mut buffer = Vec::with_capacity(63 + suffix.len()); + buffer.resize(63, 0); + assert!(let Ok(()) = urandom.read_exact(&mut buffer[1..])); + for byte in &mut buffer[1..] { + let c = *byte % (10 + 26 + 26); + if c < 10 { + *byte = b'0' + c; + } else if c < 10 + 26 { + *byte = b'A' + c - 10; + } else { + *byte = b'a' + c - 10 - 26; + } + } + buffer.extend(suffix.bytes()); + OsString::from_vec(buffer).into() +} + +/// Create a listening socket with an abstract name, connect to it and exchange a message. +/// +/// Use an abstract socket path without terminating null byte. +#[tokio::test] +async fn address_without_null_byte() { + let name = random_abstract_name("\x01"); + assert!(name.as_os_str().as_encoded_bytes().ends_with(&[1]), "{name:?}"); + + let_assert!(Ok(mut listener) = UnixSeqpacketListener::bind(&name)); + let_assert!(Ok(local_addr) = listener.local_addr()); + assert!(local_addr == name); + + let (server_socket, client_socket) = tokio::join!( + listener.accept(), + UnixSeqpacket::connect(name), + ); + let_assert!(Ok(server_socket) = server_socket); + let_assert!(Ok(client_socket) = client_socket); + + assert!(let Ok(12) = client_socket.send(b"Hello world!").await); + + let mut buffer = [0u8; 128]; + assert!(let Ok(12) = server_socket.recv(&mut buffer).await); + assert!(&buffer[..12] == b"Hello world!"); +} + +/// Create a listening socket with an abstract name, connect to it and exchange a message. +/// +/// Use an abstract socket path with terminating null byte. +#[tokio::test] +async fn address_ending_with_null_byte() { + let name = random_abstract_name("\x00"); + assert!(name.as_os_str().as_encoded_bytes().ends_with(&[0]), "{name:?}"); + + let_assert!(Ok(mut listener) = UnixSeqpacketListener::bind(&name)); + let_assert!(Ok(local_addr) = listener.local_addr()); + assert!(local_addr == name); + + let (server_socket, client_socket) = tokio::join!( + listener.accept(), + UnixSeqpacket::connect(name), + ); + let_assert!(Ok(server_socket) = server_socket); + let_assert!(Ok(client_socket) = client_socket); + + assert!(let Ok(12) = client_socket.send(b"Hello world!").await); + + let mut buffer = [0u8; 128]; + assert!(let Ok(12) = server_socket.recv(&mut buffer).await); + assert!(&buffer[..12] == b"Hello world!"); +}