diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index a725709e..2b7212f2 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -11,7 +11,9 @@ use std::{ }; use aligned_array::{Aligned, A8}; -use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use compio_buf::{ + BufResult, IntoInner, IoBuf, IoBufMut, IoSlice, IoSliceMut, IoVectoredBuf, IoVectoredBufMut, +}; #[cfg(not(feature = "once_cell_try"))] use once_cell::sync::OnceCell as OnceLock; use socket2::SockAddr; @@ -25,10 +27,11 @@ use windows_sys::{ }, Networking::WinSock::{ closesocket, setsockopt, shutdown, socklen_t, WSAIoctl, WSARecv, WSARecvFrom, WSASend, - WSASendTo, LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_GETACCEPTEXSOCKADDRS, SD_BOTH, - SD_RECEIVE, SD_SEND, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE, - SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSAID_ACCEPTEX, - WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, + WSASendMsg, WSASendTo, CMSGHDR, LPFN_ACCEPTEX, LPFN_CONNECTEX, + LPFN_GETACCEPTEXSOCKADDRS, LPFN_WSARECVMSG, SD_BOTH, SD_RECEIVE, SD_SEND, + SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE, SOL_SOCKET, + SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSABUF, WSAID_ACCEPTEX, + WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG, WSAMSG, }, Storage::FileSystem::{FlushFileBuffers, ReadFile, WriteFile}, System::{ @@ -774,6 +777,146 @@ impl OpCode for SendToVectored { } } +static WSA_RECVMSG: OnceLock = OnceLock::new(); + +/// Receive data and source address with ancillary data into vectored buffer. +pub struct RecvMsg { + addr: SOCKADDR_STORAGE, + addr_len: socklen_t, + fd: SharedFd, + buffer: T, + control: C, + _p: PhantomPinned, +} + +impl RecvMsg { + /// Create [`RecvMsg`]. + /// + /// # Panics + /// + /// This function will panic if the control message buffer is misaligned. + pub fn new(fd: SharedFd, buffer: T, control: C) -> Self { + assert!( + control.as_buf_ptr().cast::().is_aligned(), + "misaligned control message buffer" + ); + Self { + addr: unsafe { std::mem::zeroed() }, + addr_len: std::mem::size_of::() as _, + fd, + buffer, + control, + _p: PhantomPinned, + } + } +} + +impl IntoInner for RecvMsg { + type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t); + + fn into_inner(self) -> Self::Inner { + ((self.buffer, self.control), self.addr, self.addr_len) + } +} + +impl OpCode for RecvMsg { + unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { + let recvmsg_fn = WSA_RECVMSG + .get_or_try_init(|| get_wsa_fn(self.fd.as_raw_fd(), WSAID_WSARECVMSG))? + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve WSARecvMsg") + })?; + + let this = self.get_unchecked_mut(); + let mut slices = this.buffer.as_io_slices_mut(); + let mut msg = WSAMSG { + name: &mut this.addr as *mut _ as _, + namelen: this.addr_len, + lpBuffers: slices.as_mut_ptr() as _, + dwBufferCount: slices.len() as _, + Control: std::mem::transmute::(this.control.as_io_slice_mut()), + dwFlags: 0, + }; + + let mut received = 0; + let res = recvmsg_fn( + this.fd.as_raw_fd() as _, + &mut msg, + &mut received, + optr, + None, + ); + winsock_result(res, received) + } + + unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> { + cancel(self.fd.as_raw_fd(), optr) + } +} + +/// Send data to specified address accompanied by ancillary data from vectored +/// buffer. +pub struct SendMsg { + fd: SharedFd, + buffer: T, + control: C, + addr: SockAddr, + _p: PhantomPinned, +} + +impl SendMsg { + /// Create [`SendMsg`]. + /// + /// # Panics + /// + /// This function will panic if the control message buffer is misaligned. + pub fn new(fd: SharedFd, buffer: T, control: C, addr: SockAddr) -> Self { + assert!( + control.as_buf_ptr().cast::().is_aligned(), + "misaligned control message buffer" + ); + Self { + fd, + buffer, + control, + addr, + _p: PhantomPinned, + } + } +} + +impl IntoInner for SendMsg { + type Inner = (T, C); + + fn into_inner(self) -> Self::Inner { + (self.buffer, self.control) + } +} + +impl OpCode for SendMsg { + unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { + let this = self.get_unchecked_mut(); + + let slices = this.buffer.as_io_slices(); + let msg = WSAMSG { + name: this.addr.as_ptr() as _, + namelen: this.addr.len(), + lpBuffers: slices.as_ptr() as _, + dwBufferCount: slices.len() as _, + Control: std::mem::transmute::(this.control.as_io_slice()), + dwFlags: 0, + }; + + let mut sent = 0; + let res = WSASendMsg(this.fd.as_raw_fd() as _, &msg, 0, &mut sent, optr, None); + winsock_result(res, sent) + } + + unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> { + cancel(self.fd.as_raw_fd(), optr) + } +} + /// Connect a named pipe server. pub struct ConnectNamedPipe { pub(crate) fd: SharedFd, diff --git a/compio-driver/src/iour/op.rs b/compio-driver/src/iour/op.rs index 75ba51d4..c3a38a3b 100644 --- a/compio-driver/src/iour/op.rs +++ b/compio-driver/src/iour/op.rs @@ -556,6 +556,26 @@ impl IntoInner for SendToVectored { } } +impl OpCode for RecvMsg { + fn create_entry(self: Pin<&mut Self>) -> OpEntry { + let this = unsafe { self.get_unchecked_mut() }; + unsafe { this.set_msg() }; + opcode::RecvMsg::new(Fd(this.fd.as_raw_fd()), &mut this.msg) + .build() + .into() + } +} + +impl OpCode for SendMsg { + fn create_entry(self: Pin<&mut Self>) -> OpEntry { + let this = unsafe { self.get_unchecked_mut() }; + unsafe { this.set_msg() }; + opcode::SendMsg::new(Fd(this.fd.as_raw_fd()), &this.msg) + .build() + .into() + } +} + impl OpCode for PollOnce { fn create_entry(self: Pin<&mut Self>) -> OpEntry { let flags = match self.interest { diff --git a/compio-driver/src/op.rs b/compio-driver/src/op.rs index b453275d..89b60d10 100644 --- a/compio-driver/src/op.rs +++ b/compio-driver/src/op.rs @@ -11,8 +11,8 @@ use socket2::SockAddr; #[cfg(windows)] pub use crate::sys::op::ConnectNamedPipe; pub use crate::sys::op::{ - Accept, Recv, RecvFrom, RecvFromVectored, RecvVectored, Send, SendTo, SendToVectored, - SendVectored, + Accept, Recv, RecvFrom, RecvFromVectored, RecvMsg, RecvVectored, Send, SendMsg, SendTo, + SendToVectored, SendVectored, }; #[cfg(unix)] pub use crate::sys::op::{ diff --git a/compio-driver/src/poll/op.rs b/compio-driver/src/poll/op.rs index 3107b1d2..3a05849c 100644 --- a/compio-driver/src/poll/op.rs +++ b/compio-driver/src/poll/op.rs @@ -749,6 +749,47 @@ impl IntoInner for SendToVectored { } } +impl RecvMsg { + unsafe fn call(&mut self) -> libc::ssize_t { + libc::recvmsg(self.fd.as_raw_fd(), &mut self.msg, 0) + } +} + +impl OpCode for RecvMsg { + fn pre_submit(self: Pin<&mut Self>) -> io::Result { + let this = unsafe { self.get_unchecked_mut() }; + unsafe { this.set_msg() }; + syscall!(this.call(), wait_readable(this.fd.as_raw_fd())) + } + + fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll> { + debug_assert!(event.readable); + + let this = unsafe { self.get_unchecked_mut() }; + syscall!(break this.call()) + } +} + +impl SendMsg { + unsafe fn call(&self) -> libc::ssize_t { + libc::sendmsg(self.fd.as_raw_fd(), &self.msg, 0) + } +} + +impl OpCode for SendMsg { + fn pre_submit(self: Pin<&mut Self>) -> io::Result { + let this = unsafe { self.get_unchecked_mut() }; + unsafe { this.set_msg() }; + syscall!(this.call(), wait_writable(this.fd.as_raw_fd())) + } + + fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll> { + debug_assert!(event.writable); + + syscall!(break self.call()) + } +} + impl OpCode for PollOnce { fn pre_submit(self: Pin<&mut Self>) -> io::Result { Ok(Decision::wait_for(self.fd.as_raw_fd(), self.interest)) diff --git a/compio-driver/src/unix/op.rs b/compio-driver/src/unix/op.rs index b10f0377..aff5899a 100644 --- a/compio-driver/src/unix/op.rs +++ b/compio-driver/src/unix/op.rs @@ -370,6 +370,113 @@ impl IntoInner for SendVectored { } } +/// Receive data and source address with ancillary data into vectored buffer. +pub struct RecvMsg { + pub(crate) msg: libc::msghdr, + pub(crate) addr: sockaddr_storage, + pub(crate) fd: SharedFd, + pub(crate) buffer: T, + pub(crate) control: C, + pub(crate) slices: Vec, + _p: PhantomPinned, +} + +impl RecvMsg { + /// Create [`RecvMsg`]. + /// + /// # Panics + /// + /// This function will panic if the control message buffer is misaligned. + pub fn new(fd: SharedFd, buffer: T, control: C) -> Self { + assert!( + control.as_buf_ptr().cast::().is_aligned(), + "misaligned control message buffer" + ); + Self { + addr: unsafe { std::mem::zeroed() }, + msg: unsafe { std::mem::zeroed() }, + fd, + buffer, + control, + slices: vec![], + _p: PhantomPinned, + } + } + + pub(crate) unsafe fn set_msg(&mut self) { + self.slices = self.buffer.as_io_slices_mut(); + + self.msg.msg_name = std::ptr::addr_of_mut!(self.addr) as _; + self.msg.msg_namelen = std::mem::size_of_val(&self.addr) as _; + self.msg.msg_iov = self.slices.as_mut_ptr() as _; + self.msg.msg_iovlen = self.slices.len() as _; + self.msg.msg_control = self.control.as_buf_mut_ptr() as _; + self.msg.msg_controllen = self.control.buf_len() as _; + } +} + +impl IntoInner for RecvMsg { + type Inner = ((T, C), sockaddr_storage, socklen_t); + + fn into_inner(self) -> Self::Inner { + ((self.buffer, self.control), self.addr, self.msg.msg_namelen) + } +} + +/// Send data to specified address accompanied by ancillary data from vectored +/// buffer. +pub struct SendMsg { + pub(crate) msg: libc::msghdr, + pub(crate) fd: SharedFd, + pub(crate) buffer: T, + pub(crate) control: C, + pub(crate) addr: SockAddr, + pub(crate) slices: Vec, + _p: PhantomPinned, +} + +impl SendMsg { + /// Create [`SendMsg`]. + /// + /// # Panics + /// + /// This function will panic if the control message buffer is misaligned. + pub fn new(fd: SharedFd, buffer: T, control: C, addr: SockAddr) -> Self { + assert!( + control.as_buf_ptr().cast::().is_aligned(), + "misaligned control message buffer" + ); + Self { + msg: unsafe { std::mem::zeroed() }, + fd, + buffer, + control, + addr, + slices: vec![], + _p: PhantomPinned, + } + } + + pub(crate) unsafe fn set_msg(&mut self) { + self.slices = self.buffer.as_io_slices(); + + self.msg.msg_name = self.addr.as_ptr() as _; + self.msg.msg_namelen = self.addr.len(); + self.msg.msg_iov = self.slices.as_ptr() as _; + self.msg.msg_iovlen = self.slices.len() as _; + self.msg.msg_control = self.control.as_buf_ptr() as _; + self.msg.msg_controllen = self.control.buf_len() as _; + } +} + +impl IntoInner for SendMsg { + type Inner = (T, C); + + fn into_inner(self) -> Self::Inner { + (self.buffer, self.control) + } +} + /// The interest to poll a file descriptor. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Interest { diff --git a/compio-net/src/cmsg/mod.rs b/compio-net/src/cmsg/mod.rs new file mode 100644 index 00000000..8614fde0 --- /dev/null +++ b/compio-net/src/cmsg/mod.rs @@ -0,0 +1,130 @@ +use std::marker::PhantomData; + +cfg_if::cfg_if! { + if #[cfg(windows)] { + #[path = "windows.rs"] + mod sys; + } else if #[cfg(unix)] { + #[path = "unix.rs"] + mod sys; + } +} + +/// Reference to a control message. +pub struct CMsgRef<'a>(sys::CMsgRef<'a>); + +impl<'a> CMsgRef<'a> { + /// Returns the level of the control message. + pub fn level(&self) -> i32 { + self.0.level() + } + + /// Returns the type of the control message. + pub fn ty(&self) -> i32 { + self.0.ty() + } + + /// Returns the length of the control message. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.0.len() as _ + } + + /// Returns a reference to the data of the control message. + /// + /// # Safety + /// + /// The data part must be properly aligned and contains an initialized + /// instance of `T`. + pub unsafe fn data(&self) -> &T { + self.0.data() + } +} + +/// An iterator for control messages. +pub struct CMsgIter<'a> { + inner: sys::CMsgIter, + _p: PhantomData<&'a ()>, +} + +impl<'a> CMsgIter<'a> { + /// Create [`CMsgIter`] with the given buffer. + /// + /// # Panics + /// + /// This function will panic if the buffer is too short or not properly + /// aligned. + /// + /// # Safety + /// + /// The buffer should contain valid control messages. + pub unsafe fn new(buffer: &'a [u8]) -> Self { + Self { + inner: sys::CMsgIter::new(buffer.as_ptr(), buffer.len()), + _p: PhantomData, + } + } +} + +impl<'a> Iterator for CMsgIter<'a> { + type Item = CMsgRef<'a>; + + fn next(&mut self) -> Option { + unsafe { + let cmsg = self.inner.current(); + self.inner.next(); + cmsg.map(CMsgRef) + } + } +} + +/// Helper to construct control message. +pub struct CMsgBuilder<'a> { + inner: sys::CMsgIter, + len: usize, + _p: PhantomData<&'a mut ()>, +} + +impl<'a> CMsgBuilder<'a> { + /// Create [`CMsgBuilder`] with the given buffer. The buffer will be zeroed + /// on creation. + /// + /// # Panics + /// + /// This function will panic if the buffer is too short or not properly + /// aligned. + pub fn new(buffer: &'a mut [u8]) -> Self { + buffer.fill(0); + Self { + inner: sys::CMsgIter::new(buffer.as_ptr(), buffer.len()), + len: 0, + _p: PhantomData, + } + } + + /// Finishes building, returns length of the control message. + pub fn finish(self) -> usize { + self.len + } + + /// Try to append a control message entry into the buffer. If the buffer + /// does not have enough space or is not properly aligned with the value + /// type, returns `None`. + pub fn try_push(&mut self, level: i32, ty: i32, value: T) -> Option<()> { + if !self.inner.is_aligned::() || !self.inner.is_space_enough::() { + return None; + } + + // SAFETY: the buffer is zeroed and the pointer is valid and aligned + unsafe { + let mut cmsg = self.inner.current_mut()?; + cmsg.set_level(level); + cmsg.set_ty(ty); + self.len += cmsg.set_data(value); + + self.inner.next(); + } + + Some(()) + } +} diff --git a/compio-net/src/cmsg/unix.rs b/compio-net/src/cmsg/unix.rs new file mode 100644 index 00000000..1a1c04bf --- /dev/null +++ b/compio-net/src/cmsg/unix.rs @@ -0,0 +1,89 @@ +use libc::{c_int, cmsghdr, msghdr, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE}; + +pub(crate) struct CMsgRef<'a>(&'a cmsghdr); + +impl<'a> CMsgRef<'a> { + pub(crate) fn level(&self) -> c_int { + self.0.cmsg_level + } + + pub(crate) fn ty(&self) -> c_int { + self.0.cmsg_type + } + + pub(crate) fn len(&self) -> usize { + self.0.cmsg_len as _ + } + + pub(crate) unsafe fn data(&self) -> &T { + let data_ptr = CMSG_DATA(self.0); + data_ptr.cast::().as_ref().unwrap() + } +} + +pub(crate) struct CMsgMut<'a>(&'a mut cmsghdr); + +impl<'a> CMsgMut<'a> { + pub(crate) fn set_level(&mut self, level: c_int) { + self.0.cmsg_level = level; + } + + pub(crate) fn set_ty(&mut self, ty: c_int) { + self.0.cmsg_type = ty; + } + + pub(crate) unsafe fn set_data(&mut self, data: T) -> usize { + self.0.cmsg_len = CMSG_LEN(std::mem::size_of::() as _) as _; + let data_ptr = CMSG_DATA(self.0); + std::ptr::write(data_ptr.cast::(), data); + CMSG_SPACE(std::mem::size_of::() as _) as _ + } +} + +pub(crate) struct CMsgIter { + msg: msghdr, + cmsg: *mut cmsghdr, +} + +impl CMsgIter { + pub(crate) fn new(ptr: *const u8, len: usize) -> Self { + assert!(len >= unsafe { CMSG_SPACE(0) as _ }, "buffer too short"); + assert!(ptr.cast::().is_aligned(), "misaligned buffer"); + + let mut msg: msghdr = unsafe { std::mem::zeroed() }; + msg.msg_control = ptr as _; + msg.msg_controllen = len as _; + // SAFETY: msg is initialized and valid + let cmsg = unsafe { CMSG_FIRSTHDR(&msg) }; + Self { msg, cmsg } + } + + pub(crate) unsafe fn current<'a>(&self) -> Option> { + self.cmsg.as_ref().map(CMsgRef) + } + + pub(crate) unsafe fn next(&mut self) { + if !self.cmsg.is_null() { + self.cmsg = CMSG_NXTHDR(&self.msg, self.cmsg); + } + } + + pub(crate) unsafe fn current_mut<'a>(&self) -> Option> { + self.cmsg.as_mut().map(CMsgMut) + } + + pub(crate) fn is_aligned(&self) -> bool { + self.msg.msg_control.cast::().is_aligned() + } + + pub(crate) fn is_space_enough(&self) -> bool { + if !self.cmsg.is_null() { + let space = unsafe { CMSG_SPACE(std::mem::size_of::() as _) as usize }; + #[allow(clippy::unnecessary_cast)] + let max = self.msg.msg_control as usize + self.msg.msg_controllen as usize; + self.cmsg as usize + space <= max + } else { + false + } + } +} diff --git a/compio-net/src/cmsg/windows.rs b/compio-net/src/cmsg/windows.rs new file mode 100644 index 00000000..7efd452b --- /dev/null +++ b/compio-net/src/cmsg/windows.rs @@ -0,0 +1,143 @@ +use std::{ + mem::{align_of, size_of}, + ptr::null_mut, +}; + +use windows_sys::Win32::Networking::WinSock::{CMSGHDR, WSABUF, WSAMSG}; + +// Macros from https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h +#[inline] +const fn wsa_cmsghdr_align(length: usize) -> usize { + (length + align_of::() - 1) & !(align_of::() - 1) +} + +// WSA_CMSGDATA_ALIGN(sizeof(CMSGHDR)) +const WSA_CMSGDATA_OFFSET: usize = + (size_of::() + align_of::() - 1) & !(align_of::() - 1); + +#[inline] +unsafe fn wsa_cmsg_firsthdr(msg: *const WSAMSG) -> *mut CMSGHDR { + if (*msg).Control.len as usize >= size_of::() { + (*msg).Control.buf as _ + } else { + null_mut() + } +} + +#[inline] +unsafe fn wsa_cmsg_nxthdr(msg: *const WSAMSG, cmsg: *const CMSGHDR) -> *mut CMSGHDR { + if cmsg.is_null() { + wsa_cmsg_firsthdr(msg) + } else { + let next = cmsg as usize + wsa_cmsghdr_align((*cmsg).cmsg_len); + if next + size_of::() > (*msg).Control.buf as usize + (*msg).Control.len as usize { + null_mut() + } else { + next as _ + } + } +} + +#[inline] +unsafe fn wsa_cmsg_data(cmsg: *const CMSGHDR) -> *mut u8 { + (cmsg as usize + WSA_CMSGDATA_OFFSET) as _ +} + +#[inline] +const fn wsa_cmsg_space(length: usize) -> usize { + WSA_CMSGDATA_OFFSET + wsa_cmsghdr_align(length) +} + +#[inline] +const fn wsa_cmsg_len(length: usize) -> usize { + WSA_CMSGDATA_OFFSET + length +} + +pub struct CMsgRef<'a>(&'a CMSGHDR); + +impl<'a> CMsgRef<'a> { + pub fn level(&self) -> i32 { + self.0.cmsg_level + } + + pub fn ty(&self) -> i32 { + self.0.cmsg_type + } + + pub fn len(&self) -> usize { + self.0.cmsg_len + } + + pub unsafe fn data(&self) -> &T { + let data_ptr = wsa_cmsg_data(self.0); + data_ptr.cast::().as_ref().unwrap() + } +} + +pub(crate) struct CMsgMut<'a>(&'a mut CMSGHDR); + +impl<'a> CMsgMut<'a> { + pub(crate) fn set_level(&mut self, level: i32) { + self.0.cmsg_level = level; + } + + pub(crate) fn set_ty(&mut self, ty: i32) { + self.0.cmsg_type = ty; + } + + pub(crate) unsafe fn set_data(&mut self, data: T) -> usize { + self.0.cmsg_len = wsa_cmsg_len(size_of::() as _) as _; + let data_ptr = wsa_cmsg_data(self.0); + std::ptr::write(data_ptr.cast::(), data); + wsa_cmsg_space(size_of::() as _) + } +} + +pub(crate) struct CMsgIter { + msg: WSAMSG, + cmsg: *mut CMSGHDR, +} + +impl CMsgIter { + pub(crate) fn new(ptr: *const u8, len: usize) -> Self { + assert!(len >= wsa_cmsg_space(0) as _, "buffer too short"); + assert!(ptr.cast::().is_aligned(), "misaligned buffer"); + + let mut msg: WSAMSG = unsafe { std::mem::zeroed() }; + msg.Control = WSABUF { + len: len as _, + buf: ptr as _, + }; + // SAFETY: msg is initialized and valid + let cmsg = unsafe { wsa_cmsg_firsthdr(&msg) }; + Self { msg, cmsg } + } + + pub(crate) unsafe fn current<'a>(&self) -> Option> { + self.cmsg.as_ref().map(CMsgRef) + } + + pub(crate) unsafe fn next(&mut self) { + if !self.cmsg.is_null() { + self.cmsg = wsa_cmsg_nxthdr(&self.msg, self.cmsg); + } + } + + pub(crate) unsafe fn current_mut<'a>(&self) -> Option> { + self.cmsg.as_mut().map(CMsgMut) + } + + pub(crate) fn is_aligned(&self) -> bool { + self.msg.Control.buf.cast::().is_aligned() + } + + pub(crate) fn is_space_enough(&self) -> bool { + if !self.cmsg.is_null() { + let space = wsa_cmsg_space(size_of::() as _); + let max = self.msg.Control.buf as usize + self.msg.Control.len as usize; + self.cmsg as usize + space <= max + } else { + false + } + } +} diff --git a/compio-net/src/lib.rs b/compio-net/src/lib.rs index 143adc2e..e34be159 100644 --- a/compio-net/src/lib.rs +++ b/compio-net/src/lib.rs @@ -5,6 +5,7 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![warn(missing_docs)] +mod cmsg; mod poll_fd; mod resolve; mod socket; @@ -13,6 +14,7 @@ mod tcp; mod udp; mod unix; +pub use cmsg::*; pub use poll_fd::*; pub use resolve::ToSocketAddrsAsync; pub(crate) use resolve::{each_addr, first_addr_buf}; diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 523ac36a..0f84dd08 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -6,10 +6,11 @@ use compio_driver::op::CreateSocket; use compio_driver::{ impl_raw_fd, op::{ - Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromVectored, - RecvResultExt, RecvVectored, Send, SendTo, SendToVectored, SendVectored, ShutdownSocket, + Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromVectored, RecvMsg, + RecvResultExt, RecvVectored, Send, SendMsg, SendTo, SendToVectored, SendVectored, + ShutdownSocket, }, - ToSharedFd, + syscall, AsRawFd, ToSharedFd, }; use compio_runtime::Attacher; use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type}; @@ -256,6 +257,36 @@ impl Socket { .map_advanced() } + pub async fn recv_msg( + &self, + buffer: T, + control: C, + ) -> BufResult<(usize, SockAddr), (T, C)> { + self.recv_msg_vectored([buffer], control) + .await + .map_buffer(|([buffer], control)| (buffer, control)) + } + + pub async fn recv_msg_vectored( + &self, + buffer: T, + control: C, + ) -> BufResult<(usize, SockAddr), (T, C)> { + let fd = self.to_shared_fd(); + let op = RecvMsg::new(fd, buffer, control); + compio_runtime::submit(op) + .await + .into_inner() + .map_addr() + .map(|(init, obj), (mut buffer, control)| { + // SAFETY: The number of bytes received would not bypass the buffer capacity. + unsafe { + buffer.set_buf_init(init); + } + ((init, obj), (buffer, control)) + }) + } + pub async fn send_to(&self, buffer: T, addr: &SockAddr) -> BufResult { let fd = self.to_shared_fd(); let op = SendTo::new(fd, buffer, addr.clone()); @@ -271,6 +302,55 @@ impl Socket { let op = SendToVectored::new(fd, buffer, addr.clone()); compio_runtime::submit(op).await.into_inner() } + + pub async fn send_msg( + &self, + buffer: T, + control: C, + addr: &SockAddr, + ) -> BufResult { + self.send_msg_vectored([buffer], control, addr) + .await + .map_buffer(|([buffer], control)| (buffer, control)) + } + + pub async fn send_msg_vectored( + &self, + buffer: T, + control: C, + addr: &SockAddr, + ) -> BufResult { + let fd = self.to_shared_fd(); + let op = SendMsg::new(fd, buffer, control, addr.clone()); + compio_runtime::submit(op).await.into_inner() + } + + #[cfg(unix)] + pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + syscall!(libc::setsockopt( + self.socket.as_raw_fd(), + level, + name, + value as *const _ as _, + std::mem::size_of::() as _ + )) + .map(|_| ()) + } + + #[cfg(windows)] + pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + syscall!( + SOCKET, + windows_sys::Win32::Networking::WinSock::setsockopt( + self.socket.as_raw_fd() as _, + level, + name, + value as *const _ as _, + std::mem::size_of::() as _ + ) + ) + .map(|_| ()) + } } impl_raw_fd!(Socket, Socket2, socket, socket); diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index f3b86f9a..1063eed9 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -222,6 +222,32 @@ impl UdpSocket { .map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr"))) } + /// Receives a single datagram message and ancillary data on the socket. On + /// success, returns the number of bytes received and the origin. + pub async fn recv_msg( + &self, + buffer: T, + control: C, + ) -> BufResult<(usize, SocketAddr), (T, C)> { + self.inner + .recv_msg(buffer, control) + .await + .map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr"))) + } + + /// Receives a single datagram message and ancillary data on the socket. On + /// success, returns the number of bytes received and the origin. + pub async fn recv_msg_vectored( + &self, + buffer: T, + control: C, + ) -> BufResult<(usize, SocketAddr), (T, C)> { + self.inner + .recv_msg_vectored(buffer, control) + .await + .map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr"))) + } + /// Sends data on the socket to the given address. On success, returns the /// number of bytes sent. pub async fn send_to( @@ -249,6 +275,51 @@ impl UdpSocket { }) .await } + + /// Sends data on the socket to the given address accompanied by ancillary + /// data. On success, returns the number of bytes sent. + pub async fn send_msg( + &self, + buffer: T, + control: C, + addr: impl ToSocketAddrsAsync, + ) -> BufResult { + super::first_addr_buf( + addr, + (buffer, control), + |addr, (buffer, control)| async move { + self.inner + .send_msg(buffer, control, &SockAddr::from(addr)) + .await + }, + ) + .await + } + + /// Sends data on the socket to the given address accompanied by ancillary + /// data. On success, returns the number of bytes sent. + pub async fn send_msg_vectored( + &self, + buffer: T, + control: C, + addr: impl ToSocketAddrsAsync, + ) -> BufResult { + super::first_addr_buf( + addr, + (buffer, control), + |addr, (buffer, control)| async move { + self.inner + .send_msg_vectored(buffer, control, &SockAddr::from(addr)) + .await + }, + ) + .await + } + + /// Sets a socket option. + pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + self.inner.set_socket_option(level, name, value) + } } impl_raw_fd!(UdpSocket, socket2::Socket, inner, socket); diff --git a/compio-net/tests/cmsg.rs b/compio-net/tests/cmsg.rs new file mode 100644 index 00000000..60a86224 --- /dev/null +++ b/compio-net/tests/cmsg.rs @@ -0,0 +1,47 @@ +use compio_buf::IoBuf; +use compio_net::{CMsgBuilder, CMsgIter}; + +#[test] +fn test_cmsg() { + let mut buf = [0u8; 64]; + let mut builder = CMsgBuilder::new(&mut buf); + + builder.try_push(0, 0, ()).unwrap(); // 16 / 12 + builder.try_push(1, 1, u32::MAX).unwrap(); // 16 + 4 + 4 / 12 + 4 + builder.try_push(2, 2, i64::MIN).unwrap(); // 16 + 8 / 12 + 8 + let len = builder.finish(); + assert!(len == 64 || len == 48); + + unsafe { + let buf = buf.slice(..len); + let mut iter = CMsgIter::new(&buf); + + let cmsg = iter.next().unwrap(); + assert_eq!((cmsg.level(), cmsg.ty(), cmsg.data::<()>()), (0, 0, &())); + let cmsg = iter.next().unwrap(); + assert_eq!( + (cmsg.level(), cmsg.ty(), cmsg.data::()), + (1, 1, &u32::MAX) + ); + let cmsg = iter.next().unwrap(); + assert_eq!( + (cmsg.level(), cmsg.ty(), cmsg.data::()), + (2, 2, &i64::MIN) + ); + assert!(iter.next().is_none()); + } +} + +#[test] +#[should_panic] +fn invalid_buffer_length() { + let mut buf = [0u8; 1]; + CMsgBuilder::new(&mut buf); +} + +#[test] +#[should_panic] +fn invalid_buffer_alignment() { + let mut buf = [0u8; 64]; + CMsgBuilder::new(&mut buf[1..]); +} diff --git a/compio-net/tests/udp.rs b/compio-net/tests/udp.rs index 699dec25..fe7dc263 100644 --- a/compio-net/tests/udp.rs +++ b/compio-net/tests/udp.rs @@ -1,4 +1,4 @@ -use compio_net::UdpSocket; +use compio_net::{CMsgBuilder, CMsgIter, UdpSocket}; #[compio_macros::test] async fn connect() { @@ -64,3 +64,53 @@ async fn send_to() { active_addr ); } + +#[compio_macros::test] +async fn send_msg_with_ipv6_ecn() { + #[cfg(unix)] + use libc::{IPPROTO_IPV6, IPV6_RECVTCLASS, IPV6_TCLASS}; + #[cfg(windows)] + use windows_sys::Win32::Networking::WinSock::{ + IPPROTO_IPV6, IPV6_ECN, IPV6_RECVTCLASS, IPV6_TCLASS, + }; + + const MSG: &str = "foo bar baz"; + + let passive = UdpSocket::bind("[::1]:0").await.unwrap(); + let passive_addr = passive.local_addr().unwrap(); + + passive + .set_socket_option(IPPROTO_IPV6, IPV6_RECVTCLASS, &1) + .unwrap(); + + let active = UdpSocket::bind("[::1]:0").await.unwrap(); + let active_addr = active.local_addr().unwrap(); + + let mut control = vec![0u8; 32]; + let mut builder = CMsgBuilder::new(&mut control); + + const ECN_BITS: i32 = 0b10; + + #[cfg(unix)] + builder + .try_push(IPPROTO_IPV6, IPV6_TCLASS, ECN_BITS) + .unwrap(); + #[cfg(windows)] + builder.try_push(IPPROTO_IPV6, IPV6_ECN, ECN_BITS).unwrap(); + + let len = builder.finish(); + control.truncate(len); + + active.send_msg(MSG, control, passive_addr).await.unwrap(); + + let res = passive.recv_msg(Vec::with_capacity(20), [0u8; 32]).await; + assert_eq!(res.0.unwrap().1, active_addr); + assert_eq!(res.1.0, MSG.as_bytes()); + unsafe { + let mut iter = CMsgIter::new(&res.1.1); + let cmsg = iter.next().unwrap(); + assert_eq!(cmsg.level(), IPPROTO_IPV6); + assert_eq!(cmsg.ty(), IPV6_TCLASS); + assert_eq!(cmsg.data::(), &ECN_BITS); + } +}