diff --git a/compio-net/src/unix.rs b/compio-net/src/unix.rs index ae3ff654..0eff2b02 100644 --- a/compio-net/src/unix.rs +++ b/compio-net/src/unix.rs @@ -52,6 +52,13 @@ impl UnixListener { /// the specified file path. The file path cannot yet exist, and will be /// cleaned up upon dropping [`UnixListener`] pub fn bind_addr(addr: &SockAddr) -> io::Result { + if !addr.is_unix() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "addr is not unix socket address", + )); + } + let socket = Socket::bind(addr, Type::STREAM, None)?; socket.listen(1024)?; Ok(UnixListener { inner: socket }) @@ -129,27 +136,16 @@ impl UnixStream { /// [`UnixListener`] or equivalent listening on the corresponding Unix /// domain socket to successfully connect and return a `UnixStream`. pub async fn connect_addr(addr: &SockAddr) -> io::Result { + if !addr.is_unix() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "addr is not unix socket address", + )); + } + #[cfg(windows)] let socket = { - use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN}; - - let new_addr = unsafe { - SockAddr::try_init(|addr, len| { - let addr: *mut SOCKADDR_UN = addr.cast(); - std::ptr::write( - addr, - SOCKADDR_UN { - sun_family: AF_UNIX, - sun_path: [0; 108], - }, - ); - std::ptr::write(len, 3); - Ok(()) - }) - } - // it is always Ok - .unwrap() - .1; + let new_addr = empty_unix_socket(); Socket::bind(&new_addr, Type::STREAM, None)? }; #[cfg(unix)] @@ -181,21 +177,9 @@ impl UnixStream { pub fn peer_addr(&self) -> io::Result { #[allow(unused_mut)] let mut addr = self.inner.peer_addr()?; - // The peer addr returned after ConnectEx is buggy. It contains bytes that - // should not belong to the address. Luckily a unix path should not contain `\0` - // until the end. We can determine the path ending by that. #[cfg(windows)] { - use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN; - - let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() }; - let addr_len = match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) { - Ok(str) => str.to_bytes_with_nul().len() + 2, - Err(_) => std::mem::size_of::(), - }; - unsafe { - addr.set_length(addr_len as _); - } + fix_unix_socket_length(&mut addr); } Ok(addr) } @@ -296,3 +280,47 @@ impl AsyncWrite for &UnixStream { impl_try_as_raw_fd!(UnixStream, inner); impl_attachable!(UnixStream, inner); + +#[cfg(windows)] +#[inline] +fn empty_unix_socket() -> SockAddr { + use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN}; + + // SAFETY: the length is correct + unsafe { + SockAddr::try_init(|addr, len| { + let addr: *mut SOCKADDR_UN = addr.cast(); + std::ptr::write( + addr, + SOCKADDR_UN { + sun_family: AF_UNIX, + sun_path: [0; 108], + }, + ); + std::ptr::write(len, 3); + Ok(()) + }) + } + // it is always Ok + .unwrap() + .1 +} + +// The peer addr returned after ConnectEx is buggy. It contains bytes that +// should not belong to the address. Luckily a unix path should not contain `\0` +// until the end. We can determine the path ending by that. +#[cfg(windows)] +#[inline] +fn fix_unix_socket_length(addr: &mut SockAddr) { + use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN; + + // SAFETY: cannot construct non-unix socket address in safe way. + let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() }; + let addr_len = match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) { + Ok(str) => str.to_bytes_with_nul().len() + 2, + Err(_) => std::mem::size_of::(), + }; + unsafe { + addr.set_length(addr_len as _); + } +}