Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ancillary data support #275

Merged
merged 16 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 148 additions & 5 deletions compio-driver/src/iocp/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use std::{
};

use aligned_array::{Aligned, A8};
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_buf::{
BufResult, IntoInner, IoBuf, IoBufMut, IoSlice, IoSliceMut, IoVectoredBuf, IoVectoredBufMut,
};
#[cfg(not(feature = "once_cell_try"))]
use once_cell::sync::OnceCell as OnceLock;
use socket2::SockAddr;
Expand All @@ -25,10 +27,11 @@ use windows_sys::{
},
Networking::WinSock::{
closesocket, setsockopt, shutdown, socklen_t, WSAIoctl, WSARecv, WSARecvFrom, WSASend,
WSASendTo, LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_GETACCEPTEXSOCKADDRS, SD_BOTH,
SD_RECEIVE, SD_SEND, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE,
SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSAID_ACCEPTEX,
WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS,
WSASendMsg, WSASendTo, CMSGHDR, LPFN_ACCEPTEX, LPFN_CONNECTEX,
LPFN_GETACCEPTEXSOCKADDRS, LPFN_WSARECVMSG, SD_BOTH, SD_RECEIVE, SD_SEND,
SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE, SOL_SOCKET,
SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSABUF, WSAID_ACCEPTEX,
WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG, WSAMSG,
},
Storage::FileSystem::{FlushFileBuffers, ReadFile, WriteFile},
System::{
Expand Down Expand Up @@ -774,6 +777,146 @@ impl<T: IoVectoredBuf, S: AsRawFd> OpCode for SendToVectored<T, S> {
}
}

static WSA_RECVMSG: OnceLock<LPFN_WSARECVMSG> = OnceLock::new();

/// Receive data and source address with ancillary data into vectored buffer.
pub struct RecvMsg<T: IoVectoredBufMut, C: IoBufMut, S> {
addr: SOCKADDR_STORAGE,
addr_len: socklen_t,
fd: SharedFd<S>,
buffer: T,
control: C,
_p: PhantomPinned,
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
/// Create [`RecvMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C) -> Self {
AsakuraMizu marked this conversation as resolved.
Show resolved Hide resolved
assert!(
control.as_buf_ptr().cast::<CMSGHDR>().is_aligned(),
"misaligned control message buffer"
);
Self {
addr: unsafe { std::mem::zeroed() },
addr_len: std::mem::size_of::<SOCKADDR_STORAGE>() as _,
fd,
buffer,
control,
_p: PhantomPinned,
}
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.addr_len)
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
let recvmsg_fn = WSA_RECVMSG
.get_or_try_init(|| get_wsa_fn(self.fd.as_raw_fd(), WSAID_WSARECVMSG))?
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve WSARecvMsg")
})?;

let this = self.get_unchecked_mut();
let mut slices = this.buffer.as_io_slices_mut();
let mut msg = WSAMSG {
name: &mut this.addr as *mut _ as _,
namelen: this.addr_len,
lpBuffers: slices.as_mut_ptr() as _,
dwBufferCount: slices.len() as _,
Control: std::mem::transmute::<IoSliceMut, WSABUF>(this.control.as_io_slice_mut()),
dwFlags: 0,
};

let mut received = 0;
let res = recvmsg_fn(
this.fd.as_raw_fd() as _,
&mut msg,
&mut received,
optr,
None,
);
winsock_result(res, received)
}

unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
cancel(self.fd.as_raw_fd(), optr)
}
}

/// Send data to specified address accompanied by ancillary data from vectored
/// buffer.
pub struct SendMsg<T: IoVectoredBuf, C: IoBuf, S> {
fd: SharedFd<S>,
buffer: T,
control: C,
addr: SockAddr,
_p: PhantomPinned,
}

impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
/// Create [`SendMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C, addr: SockAddr) -> Self {
AsakuraMizu marked this conversation as resolved.
Show resolved Hide resolved
assert!(
control.as_buf_ptr().cast::<CMSGHDR>().is_aligned(),
"misaligned control message buffer"
);
Self {
fd,
buffer,
control,
addr,
_p: PhantomPinned,
}
}
}

impl<T: IoVectoredBuf, C: IoBuf, S> IntoInner for SendMsg<T, C, S> {
type Inner = (T, C);

fn into_inner(self) -> Self::Inner {
(self.buffer, self.control)
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
let this = self.get_unchecked_mut();

let slices = this.buffer.as_io_slices();
let msg = WSAMSG {
name: this.addr.as_ptr() as _,
namelen: this.addr.len(),
lpBuffers: slices.as_ptr() as _,
dwBufferCount: slices.len() as _,
Control: std::mem::transmute::<IoSlice, WSABUF>(this.control.as_io_slice()),
dwFlags: 0,
};

let mut sent = 0;
let res = WSASendMsg(this.fd.as_raw_fd() as _, &msg, 0, &mut sent, optr, None);
winsock_result(res, sent)
}

unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
cancel(self.fd.as_raw_fd(), optr)
}
}

/// Connect a named pipe server.
pub struct ConnectNamedPipe<S> {
pub(crate) fd: SharedFd<S>,
Expand Down
20 changes: 20 additions & 0 deletions compio-driver/src/iour/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,26 @@ impl<T: IoVectoredBuf, S> IntoInner for SendToVectored<T, S> {
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
opcode::RecvMsg::new(Fd(this.fd.as_raw_fd()), &mut this.msg)
.build()
.into()
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
opcode::SendMsg::new(Fd(this.fd.as_raw_fd()), &this.msg)
.build()
.into()
}
}

impl<S: AsRawFd> OpCode for PollOnce<S> {
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
let flags = match self.interest {
Expand Down
4 changes: 2 additions & 2 deletions compio-driver/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use socket2::SockAddr;
#[cfg(windows)]
pub use crate::sys::op::ConnectNamedPipe;
pub use crate::sys::op::{
Accept, Recv, RecvFrom, RecvFromVectored, RecvVectored, Send, SendTo, SendToVectored,
SendVectored,
Accept, Recv, RecvFrom, RecvFromVectored, RecvMsg, RecvVectored, Send, SendMsg, SendTo,
SendToVectored, SendVectored,
};
#[cfg(unix)]
pub use crate::sys::op::{
Expand Down
41 changes: 41 additions & 0 deletions compio-driver/src/poll/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,47 @@ impl<T: IoVectoredBuf, S> IntoInner for SendToVectored<T, S> {
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> RecvMsg<T, C, S> {
unsafe fn call(&mut self) -> libc::ssize_t {
libc::recvmsg(self.fd.as_raw_fd(), &mut self.msg, 0)
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
fn pre_submit(self: Pin<&mut Self>) -> io::Result<Decision> {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
syscall!(this.call(), wait_readable(this.fd.as_raw_fd()))
}

fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll<io::Result<usize>> {
debug_assert!(event.readable);

let this = unsafe { self.get_unchecked_mut() };
syscall!(break this.call())
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> SendMsg<T, C, S> {
unsafe fn call(&self) -> libc::ssize_t {
libc::sendmsg(self.fd.as_raw_fd(), &self.msg, 0)
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
fn pre_submit(self: Pin<&mut Self>) -> io::Result<Decision> {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
syscall!(this.call(), wait_writable(this.fd.as_raw_fd()))
}

fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll<io::Result<usize>> {
debug_assert!(event.writable);

syscall!(break self.call())
}
}

impl<S: AsRawFd> OpCode for PollOnce<S> {
fn pre_submit(self: Pin<&mut Self>) -> io::Result<Decision> {
Ok(Decision::wait_for(self.fd.as_raw_fd(), self.interest))
Expand Down
107 changes: 107 additions & 0 deletions compio-driver/src/unix/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,113 @@ impl<T: IoVectoredBuf, S> IntoInner for SendVectored<T, S> {
}
}

/// Receive data and source address with ancillary data into vectored buffer.
pub struct RecvMsg<T: IoVectoredBufMut, C: IoBufMut, S> {
pub(crate) msg: libc::msghdr,
pub(crate) addr: sockaddr_storage,
pub(crate) fd: SharedFd<S>,
pub(crate) buffer: T,
pub(crate) control: C,
pub(crate) slices: Vec<IoSliceMut>,
_p: PhantomPinned,
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
/// Create [`RecvMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C) -> Self {
assert!(
control.as_buf_ptr().cast::<libc::cmsghdr>().is_aligned(),
"misaligned control message buffer"
);
Self {
addr: unsafe { std::mem::zeroed() },
msg: unsafe { std::mem::zeroed() },
fd,
buffer,
control,
slices: vec![],
_p: PhantomPinned,
}
}

pub(crate) unsafe fn set_msg(&mut self) {
self.slices = self.buffer.as_io_slices_mut();

self.msg.msg_name = std::ptr::addr_of_mut!(self.addr) as _;
self.msg.msg_namelen = std::mem::size_of_val(&self.addr) as _;
self.msg.msg_iov = self.slices.as_mut_ptr() as _;
self.msg.msg_iovlen = self.slices.len() as _;
self.msg.msg_control = self.control.as_buf_mut_ptr() as _;
self.msg.msg_controllen = self.control.buf_len() as _;
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), sockaddr_storage, socklen_t);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.msg.msg_namelen)
}
}

/// Send data to specified address accompanied by ancillary data from vectored
/// buffer.
pub struct SendMsg<T: IoVectoredBuf, C: IoBuf, S> {
pub(crate) msg: libc::msghdr,
pub(crate) fd: SharedFd<S>,
pub(crate) buffer: T,
pub(crate) control: C,
pub(crate) addr: SockAddr,
pub(crate) slices: Vec<IoSlice>,
_p: PhantomPinned,
}

impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
/// Create [`SendMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C, addr: SockAddr) -> Self {
assert!(
control.as_buf_ptr().cast::<libc::cmsghdr>().is_aligned(),
"misaligned control message buffer"
);
Self {
msg: unsafe { std::mem::zeroed() },
fd,
buffer,
control,
addr,
slices: vec![],
_p: PhantomPinned,
}
}

pub(crate) unsafe fn set_msg(&mut self) {
self.slices = self.buffer.as_io_slices();

self.msg.msg_name = self.addr.as_ptr() as _;
self.msg.msg_namelen = self.addr.len();
self.msg.msg_iov = self.slices.as_ptr() as _;
self.msg.msg_iovlen = self.slices.len() as _;
self.msg.msg_control = self.control.as_buf_ptr() as _;
self.msg.msg_controllen = self.control.buf_len() as _;
}
}

impl<T: IoVectoredBuf, C: IoBuf, S> IntoInner for SendMsg<T, C, S> {
type Inner = (T, C);

fn into_inner(self) -> Self::Inner {
(self.buffer, self.control)
}
}

/// The interest to poll a file descriptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interest {
Expand Down
Loading