diff --git a/.github/workflows/ci-preemptive.sh b/.github/workflows/ci-preemptive.sh index 4a2598b6..b5d2da90 100644 --- a/.github/workflows/ci-preemptive.sh +++ b/.github/workflows/ci-preemptive.sh @@ -34,3 +34,13 @@ if [ "${TARGET}" = "x86_64-unknown-linux-gnu" ]; then "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,preemptive,ci "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,preemptive,ci --release fi + +# test IOCP +if [ "${OS}" = "windows-latest" ]; then + cd "${PROJECT_DIR}"/core + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive --release + cd "${PROJECT_DIR}"/open-coroutine + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive --release +fi diff --git a/.github/workflows/ci.sh b/.github/workflows/ci.sh index 5bddad7e..d5e3d5c9 100644 --- a/.github/workflows/ci.sh +++ b/.github/workflows/ci.sh @@ -15,6 +15,13 @@ fi export RUST_TEST_THREADS=1 export RUST_BACKTRACE=1 +# todo remove this +if [ "${OS}" = "windows-latest" ]; then + cd "${PROJECT_DIR}"/open-coroutine + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp --release +fi + # test open-coroutine-core mod cd "${PROJECT_DIR}"/core "${CARGO}" test --target "${TARGET}" --features ci @@ -34,3 +41,13 @@ if [ "${TARGET}" = "x86_64-unknown-linux-gnu" ]; then "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,ci "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,ci --release fi + +# test IOCP +if [ "${OS}" = "windows-latest" ]; then + cd "${PROJECT_DIR}"/core + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp --release + cd "${PROJECT_DIR}"/open-coroutine + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp --release +fi diff --git a/core/Cargo.toml b/core/Cargo.toml index fb221583..325d350b 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -95,5 +95,11 @@ net = ["korosensei", "polling", "mio", "crossbeam-utils", "core_affinity"] # Provide io_uring adaptation, this feature only works in linux. io_uring = ["net", "io-uring"] +# Provide IOCP adaptation, this feature only works in windows. +iocp = ["net"] + +# Provide completion IOCP adaptation +completion_io = ["io_uring", "iocp"] + # Provide syscall implementation. syscall = ["net"] diff --git a/core/src/net/event_loop.rs b/core/src/net/event_loop.rs index 4e61f3da..b9d4ae41 100644 --- a/core/src/net/event_loop.rs +++ b/core/src/net/event_loop.rs @@ -24,16 +24,34 @@ cfg_if::cfg_if! { } } +cfg_if::cfg_if! { + if #[cfg(all(windows, feature = "iocp"))] { + use std::ffi::c_uint; + use windows_sys::core::{PCSTR, PSTR}; + use windows_sys::Win32::Networking::WinSock::{ + setsockopt, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SEND_RECV_FLAGS, SOCKADDR, SOCKET, SOL_SOCKET, + SO_UPDATE_ACCEPT_CONTEXT, WSABUF, + }; + use windows_sys::Win32::System::IO::OVERLAPPED; + } +} + #[repr(C)] #[derive(Debug)] pub(crate) struct EventLoop<'e> { stop: Arc<(Mutex, Condvar)>, shared_stop: Arc<(Mutex, Condvar)>, cpu: usize, - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] operator: crate::net::operator::Operator<'e>, #[allow(clippy::type_complexity)] - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] syscall_wait_table: DashMap>, Condvar)>>, selector: Poller, pool: CoroutinePool<'e>, @@ -87,9 +105,15 @@ impl<'e> EventLoop<'e> { stop: Arc::new((Mutex::new(false), Condvar::new())), shared_stop, cpu, - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] operator: crate::net::operator::Operator::new(cpu)?, - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] syscall_wait_table: DashMap::new(), selector: Poller::new()?, pool: CoroutinePool::new(name, stack_size, min_size, max_size, keep_alive_time), @@ -222,6 +246,8 @@ impl<'e> EventLoop<'e> { cfg_if::cfg_if! { if #[cfg(all(target_os = "linux", feature = "io_uring"))] { left_time = self.adapt_io_uring(left_time)?; + } else if #[cfg(all(windows, feature = "iocp"))] { + left_time = self.adapt_iocp(left_time)?; } } @@ -267,6 +293,51 @@ impl<'e> EventLoop<'e> { Ok(left_time) } + #[cfg(all(windows, feature = "iocp"))] + fn adapt_iocp(&self, mut left_time: Option) -> std::io::Result> { + // use IOCP + let (count, mut cq, left) = self.operator.select(left_time, 0)?; + if count > 0 { + for cqe in &mut cq { + let token = cqe.token; + let bytes_transferred = cqe.bytes_transferred; + // resolve completed read/write tasks + // todo refactor IOCP impl + let result = match cqe.syscall { + Syscall::accept => unsafe { + if setsockopt( + cqe.socket, + SOL_SOCKET, + SO_UPDATE_ACCEPT_CONTEXT, + std::ptr::from_ref(&cqe.from_fd).cast(), + c_int::try_from(size_of::()).expect("overflow"), + ) == 0 + { + cqe.socket.try_into().expect("result overflow") + } else { + -c_longlong::from(windows_sys::Win32::Foundation::GetLastError()) + } + }, + Syscall::recv | Syscall::WSARecv | Syscall::send | Syscall::WSASend => { + bytes_transferred.into() + } + _ => panic!("unsupported"), + }; + if let Some((_, pair)) = self.syscall_wait_table.remove(&token) { + let (lock, cvar) = &*pair; + let mut pending = lock.lock().expect("lock failed"); + *pending = Some(result); + cvar.notify_one(); + } + unsafe { self.resume(token) }; + } + } + if left != left_time { + left_time = Some(left.unwrap_or(Duration::ZERO)); + } + Ok(left_time) + } + unsafe fn resume(&self, token: usize) { if COROUTINE_TOKENS.remove(&token).is_none() { return; @@ -446,6 +517,34 @@ impl_io_uring!(mkdirat(dirfd: c_int, pathname: *const c_char, mode: mode_t) -> c impl_io_uring!(renameat(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char) -> c_int); impl_io_uring!(renameat2(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char, flags: c_uint) -> c_int); +macro_rules! impl_iocp { + ( $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + #[cfg(all(windows, feature = "iocp"))] + impl EventLoop<'_> { + #[allow(non_snake_case, clippy::too_many_arguments)] + pub(super) fn $syscall( + &self, + $($arg: $arg_type),* + ) -> std::io::Result>, Condvar)>> { + let token = EventLoop::token(Syscall::$syscall); + self.operator.$syscall(token, $($arg, )*)?; + let arc = Arc::new((Mutex::new(None), Condvar::new())); + assert!( + self.syscall_wait_table.insert(token, arc.clone()).is_none(), + "The previous token was not retrieved in a timely manner" + ); + Ok(arc) + } + } + } +} + +impl_iocp!(accept(fd: SOCKET, addr: *mut SOCKADDR, len: *mut c_int) -> c_int); +impl_iocp!(recv(fd: SOCKET, buf: PSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSARecv(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, lpflags : *mut c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); +impl_iocp!(send(fd: SOCKET, buf: PCSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSASend(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, dwflags : c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); + #[cfg(all(test, not(all(unix, feature = "preemptive"))))] mod tests { use crate::net::event_loop::EventLoop; diff --git a/core/src/net/mod.rs b/core/src/net/mod.rs index bcb2bf37..ba6afa73 100644 --- a/core/src/net/mod.rs +++ b/core/src/net/mod.rs @@ -18,6 +18,17 @@ cfg_if::cfg_if! { } } +cfg_if::cfg_if! { + if #[cfg(all(windows, feature = "iocp"))] { + use std::ffi::c_uint; + use windows_sys::core::{PCSTR, PSTR}; + use windows_sys::Win32::Networking::WinSock::{ + LPWSAOVERLAPPED_COMPLETION_ROUTINE, SEND_RECV_FLAGS, SOCKADDR, SOCKET, WSABUF, + }; + use windows_sys::Win32::System::IO::OVERLAPPED; + } +} + /// 做C兼容时会用到 pub type UserFunc = extern "C" fn(usize) -> usize; @@ -25,7 +36,11 @@ mod selector; #[allow(clippy::too_many_arguments)] #[cfg(all(target_os = "linux", feature = "io_uring"))] -mod operator; +#[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") +))] +pub(crate) mod operator; #[allow(missing_docs)] pub mod event_loop; @@ -280,3 +295,24 @@ impl_io_uring!(fsync(fd: c_int) -> c_int); impl_io_uring!(mkdirat(dirfd: c_int, pathname: *const c_char, mode: mode_t) -> c_int); impl_io_uring!(renameat(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char) -> c_int); impl_io_uring!(renameat2(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char, flags: c_uint) -> c_int); + +macro_rules! impl_iocp { + ( $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + #[allow(non_snake_case)] + #[cfg(all(windows, feature = "iocp"))] + impl EventLoops { + #[allow(missing_docs)] + pub fn $syscall( + $($arg: $arg_type),* + ) -> std::io::Result>, Condvar)>> { + Self::event_loop().$syscall($($arg, )*) + } + } + } +} + +impl_iocp!(accept(fd: SOCKET, addr: *mut SOCKADDR, len: *mut c_int) -> c_int); +impl_iocp!(recv(fd: SOCKET, buf: PSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSARecv(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, lpflags : *mut c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); +impl_iocp!(send(fd: SOCKET, buf: PCSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSASend(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, dwflags : c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); diff --git a/core/src/net/operator/mod.rs b/core/src/net/operator/mod.rs index 6a821a4d..1b246c8b 100644 --- a/core/src/net/operator/mod.rs +++ b/core/src/net/operator/mod.rs @@ -2,3 +2,9 @@ mod linux; #[cfg(all(target_os = "linux", feature = "io_uring"))] pub(crate) use linux::*; + +#[allow(non_snake_case)] +#[cfg(all(windows, feature = "iocp"))] +mod windows; +#[cfg(all(windows, feature = "iocp"))] +pub(crate) use windows::*; diff --git a/core/src/net/operator/windows/mod.rs b/core/src/net/operator/windows/mod.rs new file mode 100644 index 00000000..4952e626 --- /dev/null +++ b/core/src/net/operator/windows/mod.rs @@ -0,0 +1,443 @@ +use crate::common::constants::Syscall; +use crate::common::{get_timeout_time, now}; +use crate::impl_display_by_debug; +use dashmap::{DashMap, DashSet}; +use once_cell::sync::Lazy; +use std::ffi::{c_int, c_uint, c_void}; +use std::io::{Error, ErrorKind}; +use std::marker::PhantomData; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{Duration, Instant}; +use windows_sys::core::{PCSTR, PSTR}; +use windows_sys::Win32::Foundation::{FALSE, HANDLE, INVALID_HANDLE_VALUE}; +use windows_sys::Win32::Networking::WinSock::{ + AcceptEx, WSAGetLastError, WSARecv, WSASend, WSASocketW, INVALID_SOCKET, IPPROTO, + LPWSAOVERLAPPED_COMPLETION_ROUTINE, SEND_RECV_FLAGS, SOCKADDR, SOCKADDR_IN, SOCKET, + SOCKET_ERROR, WINSOCK_SOCKET_TYPE, WSABUF, WSA_FLAG_OVERLAPPED, WSA_IO_PENDING, +}; +use windows_sys::Win32::System::IO::{ + CreateIoCompletionPort, GetQueuedCompletionStatusEx, OVERLAPPED, OVERLAPPED_ENTRY, +}; + +#[cfg(test)] +mod tests; + +#[repr(C)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub(crate) struct SocketContext { + pub(crate) domain: c_int, + pub(crate) ty: WINSOCK_SOCKET_TYPE, + pub(crate) protocol: IPPROTO, +} + +pub(crate) static SOCKET_CONTEXT: Lazy> = + Lazy::new(Default::default); + +/// The overlapped struct we actually used for IOCP. +#[repr(C)] +#[derive(educe::Educe)] +#[educe(Debug)] +pub(crate) struct Overlapped { + /// The base [`OVERLAPPED`]. + #[educe(Debug(ignore))] + pub base: OVERLAPPED, + pub from_fd: SOCKET, + pub socket: SOCKET, + pub token: usize, + pub syscall: Syscall, + pub bytes_transferred: u32, +} + +impl Default for Overlapped { + fn default() -> Self { + unsafe { std::mem::zeroed() } + } +} + +impl_display_by_debug!(Overlapped); + +#[repr(C)] +#[derive(Debug)] +pub(crate) struct Operator<'o> { + cpu: usize, + iocp: HANDLE, + entering: AtomicBool, + handles: DashSet, + phantom_data: PhantomData<&'o Overlapped>, +} + +impl<'o> Operator<'o> { + pub(crate) fn new(cpu: usize) -> std::io::Result { + let iocp = + unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, std::ptr::null_mut(), 0, 0) }; + if iocp.is_null() { + return Err(Error::last_os_error()); + } + Ok(Self { + cpu, + iocp, + entering: AtomicBool::new(false), + handles: DashSet::default(), + phantom_data: PhantomData, + }) + } + + /// Associates a new `HANDLE` to this I/O completion port. + /// + /// This function will associate the given handle to this port with the + /// given `token` to be returned in status messages whenever it receives a + /// notification. + /// + /// Any object which is convertible to a `HANDLE` via the `AsRawHandle` + /// trait can be provided to this function, such as `std::fs::File` and + /// friends. + fn add_handle(&self, handle: HANDLE) -> std::io::Result<()> { + if self.handles.contains(&handle) { + return Ok(()); + } + let ret = unsafe { CreateIoCompletionPort(handle, self.iocp, self.cpu, 0) }; + if ret.is_null() { + return Err(Error::new( + ErrorKind::Other, + format!("bind handle:{} to IOCP failed", handle as usize), + )); + } + debug_assert_eq!(ret, self.iocp); + if unsafe { + windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes( + handle, + windows_sys::Win32::System::WindowsProgramming::FILE_SKIP_SET_EVENT_ON_HANDLE as u8, + ) + } == 0 + { + return Err(Error::last_os_error()); + } + Ok(()) + } + + pub(crate) fn select( + &self, + timeout: Option, + want: usize, + ) -> std::io::Result<(usize, Vec, Option)> { + if self + .entering + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { + return Ok((0, Vec::new(), timeout)); + } + let result = self.do_select(timeout, want); + self.entering.store(false, Ordering::Release); + result + } + + fn do_select( + &self, + timeout: Option, + want: usize, + ) -> std::io::Result<(usize, Vec, Option)> { + let start_time = Instant::now(); + let timeout_time = timeout.map_or(u64::MAX, get_timeout_time); + let mut cq = Vec::new(); + loop { + let left_ms = (timeout_time.saturating_sub(now()) / 1_000_000) + .try_into() + .expect("overflow"); + if left_ms == 0 { + break; + } + let mut entries: Vec = Vec::with_capacity(1024); + let uninit = entries.spare_capacity_mut(); + let mut recv_count = 0; + let ret = unsafe { + GetQueuedCompletionStatusEx( + self.iocp, + uninit.as_mut_ptr().cast(), + uninit.len().try_into().expect("overflow"), + &mut recv_count, + left_ms, + 0, + ) + }; + let e = Error::last_os_error(); + if FALSE == ret { + if ErrorKind::TimedOut == e.kind() { + continue; + } + return Err(e); + } + unsafe { entries.set_len(recv_count as _) }; + for entry in entries { + let mut overlapped = + unsafe { *Box::from_raw(entry.lpOverlapped.cast::()) }; + overlapped.bytes_transferred = entry.dwNumberOfBytesTransferred; + eprintln!("IOCP got Overlapped:{overlapped}"); + cq.push(overlapped); + } + if cq.len() >= want { + break; + } + } + let cost = Instant::now().saturating_duration_since(start_time); + Ok((cq.len(), cq, timeout.map(|t| t.saturating_sub(cost)))) + } + + pub(crate) fn accept( + &self, + user_data: usize, + fd: SOCKET, + _address: *mut SOCKADDR, + _address_len: *mut c_int, + ) -> std::io::Result<()> { + self.add_handle(fd as HANDLE)?; + let context = SOCKET_CONTEXT.get(&fd).expect("socket context not found"); + let ctx = context.value(); + unsafe { + let socket = WSASocketW( + ctx.domain, + ctx.ty, + ctx.protocol, + std::ptr::null(), + 0, + WSA_FLAG_OVERLAPPED, + ); + if INVALID_SOCKET == socket { + return Err(Error::new( + ErrorKind::WouldBlock, + "add accept operation failed", + )); + } + let size = size_of::() + .saturating_add(16) + .try_into() + .expect("size overflow"); + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.socket = socket; + overlapped.token = user_data; + overlapped.syscall = Syscall::accept; + let mut buf: Vec = Vec::with_capacity(size as usize * 2); + while AcceptEx( + fd, + socket, + buf.as_mut_ptr().cast(), + 0, + size, + size, + std::ptr::null_mut(), + std::ptr::from_mut(overlapped).cast(), + ) == FALSE + { + if WSA_IO_PENDING == WSAGetLastError() { + break; + } + } + eprintln!("add accept operation Overlapped:{overlapped}"); + } + Ok(()) + } + + pub(crate) fn recv( + &self, + user_data: usize, + fd: SOCKET, + buf: PSTR, + len: c_int, + flags: SEND_RECV_FLAGS, + ) -> std::io::Result<()> { + self.add_handle(fd as HANDLE)?; + unsafe { + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.token = user_data; + overlapped.syscall = Syscall::recv; + let buf = [WSABUF { + len: len.try_into().expect("len overflow"), + buf: buf.cast(), + }]; + if WSARecv( + fd, + buf.as_ptr(), + buf.len().try_into().expect("len overflow"), + std::ptr::null_mut(), + &mut u32::try_from(flags).expect("overflow"), + std::ptr::from_mut(overlapped).cast(), + None, + ) == SOCKET_ERROR + && WSA_IO_PENDING != WSAGetLastError() + { + return Err(Error::new( + ErrorKind::WouldBlock, + "add recv operation failed", + )); + } + eprintln!("add recv operation Overlapped:{overlapped}"); + } + Ok(()) + } + + pub(crate) fn WSARecv( + &self, + user_data: usize, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + lpflags: *mut c_uint, + lpoverlapped: *mut OVERLAPPED, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> std::io::Result<()> { + assert!( + lpoverlapped.is_null(), + "the WSARecv in Operator should be called without lpoverlapped! Please report bug to open-coroutine!" + ); + self.add_handle(fd as HANDLE)?; + unsafe { + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.token = user_data; + overlapped.syscall = Syscall::WSARecv; + if WSARecv( + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + lpflags, + std::ptr::from_mut(overlapped).cast(), + lpcompletionroutine, + ) == SOCKET_ERROR + && WSA_IO_PENDING != WSAGetLastError() + { + return Err(Error::new( + ErrorKind::WouldBlock, + "add WSARecv operation failed", + )); + } + eprintln!("add WSARecv operation Overlapped:{overlapped}"); + } + Ok(()) + } + + pub(crate) fn send( + &self, + user_data: usize, + fd: SOCKET, + buf: PCSTR, + len: c_int, + flags: SEND_RECV_FLAGS, + ) -> std::io::Result<()> { + self.add_handle(fd as HANDLE)?; + unsafe { + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.token = user_data; + overlapped.syscall = Syscall::send; + let buf = [WSABUF { + len: len.try_into().expect("len overflow"), + buf: buf.cast_mut(), + }]; + if WSASend( + fd, + buf.as_ptr(), + buf.len().try_into().expect("len overflow"), + std::ptr::null_mut(), + u32::try_from(flags).expect("overflow"), + std::ptr::from_mut(overlapped).cast(), + None, + ) == SOCKET_ERROR + && WSA_IO_PENDING != WSAGetLastError() + { + return Err(Error::new( + ErrorKind::WouldBlock, + "add send operation failed", + )); + } + eprintln!("add send operation Overlapped:{overlapped}"); + } + Ok(()) + } + + pub(crate) fn write( + &self, + user_data: usize, + fd: c_int, + buf: *const c_void, + count: c_uint, + ) -> std::io::Result<()> { + self.add_handle(fd as HANDLE)?; + unsafe { + if SOCKET_CONTEXT.get(&(fd as SOCKET)).is_some() { + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd as _; + overlapped.token = user_data; + overlapped.syscall = Syscall::write; + let buf = [WSABUF { + len: count, + buf: buf.cast_mut().cast(), + }]; + if WSASend( + fd as _, + buf.as_ptr(), + buf.len().try_into().expect("len overflow"), + std::ptr::null_mut(), + 0, + std::ptr::from_mut(overlapped).cast(), + None, + ) == SOCKET_ERROR + && WSA_IO_PENDING != WSAGetLastError() + { + return Err(Error::new( + ErrorKind::WouldBlock, + "add write operation failed", + )); + } + eprintln!("add write operation Overlapped:{overlapped}"); + return Ok(()); + } + } + todo!() + } + + pub(crate) fn WSASend( + &self, + user_data: usize, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + dwflags: c_uint, + lpoverlapped: *mut OVERLAPPED, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> std::io::Result<()> { + assert!( + lpoverlapped.is_null(), + "the WSASend in Operator should be called without lpoverlapped! Please report bug to open-coroutine!" + ); + self.add_handle(fd as HANDLE)?; + unsafe { + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.token = user_data; + overlapped.syscall = Syscall::WSASend; + if WSASend( + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + dwflags, + std::ptr::from_mut(overlapped).cast(), + lpcompletionroutine, + ) == SOCKET_ERROR + && WSA_IO_PENDING != WSAGetLastError() + { + return Err(Error::new( + ErrorKind::WouldBlock, + "add WSASend operation failed", + )); + } + eprintln!("add WSASend operation Overlapped:{overlapped}"); + } + Ok(()) + } +} diff --git a/core/src/net/operator/windows/tests.rs b/core/src/net/operator/windows/tests.rs new file mode 100644 index 00000000..b8feabf3 --- /dev/null +++ b/core/src/net/operator/windows/tests.rs @@ -0,0 +1,174 @@ +use crate::net::operator::Operator; +use slab::Slab; +use std::io::{BufRead, BufReader, Write}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream}; +use std::os::windows::io::AsRawSocket; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use windows_sys::Win32::Networking::WinSock::{closesocket, SOCKET}; + +#[derive(Clone, Debug)] +enum Token { + Accept, + Read { + fd: SOCKET, + buf_index: usize, + }, + Write { + fd: SOCKET, + buf_index: usize, + offset: usize, + len: usize, + }, +} + +fn crate_client(port: u16, server_started: Arc) { + //等服务端起来 + while !server_started.load(Ordering::Acquire) {} + let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port); + let mut stream = TcpStream::connect_timeout(&socket, Duration::from_secs(3)) + .unwrap_or_else(|_| panic!("connect to 127.0.0.1:{port} failed !")); + let mut data: [u8; 512] = [b'1'; 512]; + data[511] = b'\n'; + let mut buffer: Vec = Vec::with_capacity(512); + for _ in 0..3 { + //写入stream流,如果写入失败,提示"写入失败" + assert_eq!(512, stream.write(&data).expect("Failed to write!")); + print!("Client Send: {}", String::from_utf8_lossy(&data[..])); + + let mut reader = BufReader::new(&stream); + //一直读到换行为止(b'\n'中的b表示字节),读到buffer里面 + assert_eq!( + 512, + reader + .read_until(b'\n', &mut buffer) + .expect("Failed to read into buffer") + ); + print!("Client Received: {}", String::from_utf8_lossy(&buffer[..])); + assert_eq!(&data, &buffer as &[u8]); + buffer.clear(); + } + //发送终止符 + assert_eq!(1, stream.write(&[b'e']).expect("Failed to write!")); + println!("client closed"); +} + +fn crate_server2(port: u16, server_started: Arc) -> anyhow::Result<()> { + let operator = Operator::new(0)?; + let listener = TcpListener::bind(("127.0.0.1", port))?; + + let mut bufpool = Vec::with_capacity(64); + let mut buf_alloc = Slab::with_capacity(64); + let mut token_alloc = Slab::with_capacity(64); + + println!("listen {}", listener.local_addr()?); + server_started.store(true, Ordering::Release); + + operator.accept( + token_alloc.insert(Token::Accept), + listener.as_raw_socket() as _, + std::ptr::null_mut(), + std::ptr::null_mut(), + )?; + + loop { + let (_, mut cq, _) = operator.select(None, 1)?; + for cqe in &mut cq { + let token_index = cqe.token; + let token = &mut token_alloc[token_index]; + match token.clone() { + Token::Accept => { + println!("accept"); + let fd = cqe.socket; + let (buf_index, buf) = match bufpool.pop() { + Some(buf_index) => (buf_index, &mut buf_alloc[buf_index]), + None => { + let buf = vec![0u8; 2048].into_boxed_slice(); + let buf_entry = buf_alloc.vacant_entry(); + let buf_index = buf_entry.key(); + (buf_index, buf_entry.insert(buf)) + } + }; + *token = Token::Read { fd, buf_index }; + operator.recv(token_index, fd, buf.as_mut_ptr() as _, buf.len() as _, 0)?; + } + Token::Read { fd, buf_index } => { + let ret = cqe.bytes_transferred as _; + if ret == 0 { + bufpool.push(buf_index); + _ = token_alloc.remove(token_index); + println!("shutdown connection"); + _ = unsafe { closesocket(fd) }; + println!("Server closed"); + return Ok(()); + } else { + let len = ret; + let buf = &buf_alloc[buf_index]; + *token = Token::Write { + fd, + buf_index, + len, + offset: 0, + }; + operator.send(token_index, fd, buf.as_ptr() as _, len as _, 0)?; + } + } + Token::Write { + fd, + buf_index, + offset, + len, + } => { + let write_len = cqe.bytes_transferred as usize; + if offset + write_len >= len { + bufpool.push(buf_index); + let (buf_index, buf) = match bufpool.pop() { + Some(buf_index) => (buf_index, &mut buf_alloc[buf_index]), + None => { + let buf = vec![0u8; 2048].into_boxed_slice(); + let buf_entry = buf_alloc.vacant_entry(); + let buf_index = buf_entry.key(); + (buf_index, buf_entry.insert(buf)) + } + }; + *token = Token::Read { fd, buf_index }; + operator.recv(token_index, fd, buf.as_mut_ptr() as _, buf.len() as _, 0)?; + } else { + let offset = offset + write_len; + let len = len - offset; + let buf = &buf_alloc[buf_index][offset..]; + *token = Token::Write { + fd, + buf_index, + offset, + len, + }; + operator.write(token_index, fd as _, buf.as_ptr() as _, len as _)?; + }; + } + } + } + } +} + +#[test] +fn framework() -> anyhow::Result<()> { + #[cfg(feature = "log")] + let _ = tracing_subscriber::fmt() + .with_thread_names(true) + .with_line_number(true) + .with_timer(tracing_subscriber::fmt::time::OffsetTime::new( + time::UtcOffset::from_hms(8, 0, 0).expect("create UtcOffset failed !"), + time::format_description::well_known::Rfc2822, + )) + .try_init(); + let port = 7061; + let server_started = Arc::new(AtomicBool::new(false)); + let clone = server_started.clone(); + let handle = std::thread::spawn(move || crate_server2(port, clone)); + std::thread::spawn(move || crate_client(port, server_started)) + .join() + .expect("client has error"); + handle.join().expect("server has error") +} diff --git a/core/src/syscall/mod.rs b/core/src/syscall/mod.rs index cc4f184e..d7a65a72 100644 --- a/core/src/syscall/mod.rs +++ b/core/src/syscall/mod.rs @@ -16,6 +16,6 @@ mod unix; #[cfg(windows)] pub use windows::*; -#[allow(non_snake_case)] +#[allow(non_snake_case, dead_code)] #[cfg(windows)] mod windows; diff --git a/core/src/syscall/windows/WSARecv.rs b/core/src/syscall/windows/WSARecv.rs index f4a467bc..a216fcfd 100644 --- a/core/src/syscall/windows/WSARecv.rs +++ b/core/src/syscall/windows/WSARecv.rs @@ -2,6 +2,9 @@ use once_cell::sync::Lazy; use std::ffi::{c_int, c_uint}; use windows_sys::Win32::Networking::WinSock::{LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET, WSABUF}; use windows_sys::Win32::System::IO::OVERLAPPED; +use crate::common::constants::{CoroutineState, Syscall, SyscallState}; +use crate::{error, info}; +use crate::scheduler::SchedulableCoroutine; #[must_use] pub extern "system" fn WSARecv( @@ -24,6 +27,16 @@ pub extern "system" fn WSARecv( lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, ) -> c_int { + // cfg_if::cfg_if! { + // if #[cfg(all(windows, feature = "iocp"))] { + // static CHAIN: Lazy< + // WSARecvSyscallFacade>> + // > = Lazy::new(Default::default); + // } else { + // static CHAIN: Lazy>> = + // Lazy::new(Default::default); + // } + // } static CHAIN: Lazy>> = Lazy::new(Default::default); CHAIN.WSARecv( @@ -62,17 +75,179 @@ trait WSARecvSyscall { ) -> c_int; } -impl_facade!(WSARecvSyscallFacade, WSARecvSyscall, - WSARecv( + +#[repr(C)] +#[derive(Debug, Default)] +struct WSARecvSyscallFacade { + inner: I, +} + +impl WSARecvSyscall for WSARecvSyscallFacade { + extern "system" fn WSARecv( + &self, + fn_ptr: Option< + &extern "system" fn( + SOCKET, + *const WSABUF, + c_uint, + *mut c_uint, + *mut c_uint, + *mut OVERLAPPED, + LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int, + >, fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, lpflags: *mut c_uint, lpoverlapped: *mut OVERLAPPED, - lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE - ) -> c_int -); + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int { + let syscall = Syscall::WSARecv; + info!("enter syscall {}", syscall); + if let Some(co) = SchedulableCoroutine::current() { + _ = co.syscall((), syscall, SyscallState::Executing); + } + let r = self.inner.WSARecv( + fn_ptr, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + lpflags, + lpoverlapped, + lpcompletionroutine, + ); + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), Syscall::WSARecv, SyscallState::Executing) = + co.state() + { + if co.running().is_err() { + error!("{} change to running state failed !", co.name()); + } + } + } + info!("exit syscall {}", syscall); + r + } +} + +#[cfg(all(windows, feature = "iocp"))] +#[repr(C)] +#[derive(Debug, Default)] +struct IocpWSARecvSyscall { + inner: I, +} + +#[cfg(all(windows, feature = "iocp"))] +impl WSARecvSyscall for IocpWSARecvSyscall { + #[allow(clippy::too_many_lines)] + extern "system" fn WSARecv( + &self, + fn_ptr: Option< + &extern "system" fn( + SOCKET, + *const WSABUF, + c_uint, + *mut c_uint, + *mut c_uint, + *mut OVERLAPPED, + LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int, + >, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + lpflags: *mut c_uint, + lpoverlapped: *mut OVERLAPPED, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int { + use windows_sys::Win32::Networking::WinSock::{SOCKET_ERROR, WSAEWOULDBLOCK}; + use crate::net::EventLoops; + use crate::scheduler::SchedulableSuspender; + + if !lpoverlapped.is_null() { + return RawWSARecvSyscall::default().WSARecv( + fn_ptr, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + lpflags, + lpoverlapped, + lpcompletionroutine, + ); + } + match EventLoops::WSARecv(fd, buf, dwbuffercount, lpnumberofbytesrecvd, lpflags, lpoverlapped, lpcompletionroutine) { + Ok(arc) => { + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), syscall, SyscallState::Executing) = co.state() + { + let new_state = SyscallState::Suspend(u64::MAX); + if co.syscall((), syscall, new_state).is_err() { + error!( + "{} change to syscall {} {} failed !", + co.name(), syscall, new_state + ); + } + } + } + if let Some(suspender) = SchedulableSuspender::current() { + suspender.suspend(); + //回来的时候,系统调用已经执行完了 + } + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), syscall, SyscallState::Callback) = co.state() + { + let new_state = SyscallState::Executing; + if co.syscall((), syscall, new_state).is_err() { + error!( + "{} change to syscall {} {} failed !", + co.name(), syscall, new_state + ); + } + } + } + let (lock, cvar) = &*arc; + let syscall_result: c_int = cvar + .wait_while(lock.lock().expect("lock failed"), + |&mut result| result.is_none() + ) + .expect("lock failed") + .expect("no syscall result") + .try_into() + .expect("IOCP syscall result overflow"); + // fixme 错误处理 + // if syscall_result < 0 { + // let errno: std::ffi::c_int = (-syscall_result).try_into() + // .expect("IOCP errno overflow"); + // $crate::syscall::common::set_errno(errno); + // syscall_result = -1; + // } + syscall_result + } + Err(e) => { + if e.kind() == std::io::ErrorKind::Other { + self.inner.WSARecv( + fn_ptr, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + lpflags, + lpoverlapped, + lpcompletionroutine, + ) + } else { + crate::syscall::common::set_errno(WSAEWOULDBLOCK.try_into().expect("overflow")); + SOCKET_ERROR + } + } + } + } +} impl_nio_read_iovec!(NioWSARecvSyscall, WSARecvSyscall, WSARecv( diff --git a/core/src/syscall/windows/WSASend.rs b/core/src/syscall/windows/WSASend.rs index a5ef469a..84448cf4 100644 --- a/core/src/syscall/windows/WSASend.rs +++ b/core/src/syscall/windows/WSASend.rs @@ -2,6 +2,9 @@ use once_cell::sync::Lazy; use std::ffi::{c_int, c_uint}; use windows_sys::Win32::Networking::WinSock::{LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET, WSABUF}; use windows_sys::Win32::System::IO::OVERLAPPED; +use crate::common::constants::{CoroutineState, Syscall, SyscallState}; +use crate::{error, info}; +use crate::scheduler::SchedulableCoroutine; #[must_use] pub extern "system" fn WSASend( @@ -24,7 +27,17 @@ pub extern "system" fn WSASend( lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, ) -> c_int { - static CHAIN: Lazy>> = + // cfg_if::cfg_if! { + // if #[cfg(all(windows, feature = "iocp"))] { + // static CHAIN: Lazy< + // WSASendSyscallFacade>> + // > = Lazy::new(Default::default); + // } else { + // static CHAIN: Lazy>> = + // Lazy::new(Default::default); + // } + // } + static CHAIN: Lazy>> = Lazy::new(Default::default); CHAIN.WSASend( fn_ptr, @@ -38,7 +51,7 @@ pub extern "system" fn WSASend( ) } -trait WSARecvSyscall { +trait WSASendSyscall { extern "system" fn WSASend( &self, fn_ptr: Option< @@ -62,19 +75,179 @@ trait WSARecvSyscall { ) -> c_int; } -impl_facade!(WSARecvSyscallFacade, WSARecvSyscall, - WSASend( +#[repr(C)] +#[derive(Debug, Default)] +struct WSASendSyscallFacade { + inner: I, +} + +impl WSASendSyscall for WSASendSyscallFacade { + extern "system" fn WSASend( + &self, + fn_ptr: Option< + &extern "system" fn( + SOCKET, + *const WSABUF, + c_uint, + *mut c_uint, + c_uint, + *mut OVERLAPPED, + LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int, + >, fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, - lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE - ) -> c_int -); + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int { + let syscall = Syscall::WSASend; + info!("enter syscall {}", syscall); + if let Some(co) = SchedulableCoroutine::current() { + _ = co.syscall((), syscall, SyscallState::Executing); + } + let r = self.inner.WSASend( + fn_ptr, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + dwflags, + lpoverlapped, + lpcompletionroutine, + ); + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), Syscall::WSASend, SyscallState::Executing) = + co.state() + { + if co.running().is_err() { + error!("{} change to running state failed !", co.name()); + } + } + } + info!("exit syscall {}", syscall); + r + } +} + +#[cfg(all(windows, feature = "iocp"))] +#[repr(C)] +#[derive(Debug, Default)] +struct IocpWSASendSyscall { + inner: I, +} + +#[cfg(all(windows, feature = "iocp"))] +impl WSASendSyscall for IocpWSASendSyscall { + extern "system" fn WSASend( + &self, + fn_ptr: Option< + &extern "system" fn( + SOCKET, + *const WSABUF, + c_uint, + *mut c_uint, + c_uint, + *mut OVERLAPPED, + LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int, + >, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + dwflags: c_uint, + lpoverlapped: *mut OVERLAPPED, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> c_int { + use windows_sys::Win32::Networking::WinSock::{SOCKET_ERROR, WSAEWOULDBLOCK}; + use crate::net::EventLoops; + use crate::scheduler::SchedulableSuspender; + + if !lpoverlapped.is_null() { + return RawWSASendSyscall::default().WSASend( + fn_ptr, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + dwflags, + lpoverlapped, + lpcompletionroutine, + ); + } + match EventLoops::WSASend(fd, buf, dwbuffercount, lpnumberofbytesrecvd, dwflags, lpoverlapped, lpcompletionroutine) { + Ok(arc) => { + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), syscall, SyscallState::Executing) = co.state() + { + let new_state = SyscallState::Suspend(u64::MAX); + if co.syscall((), syscall, new_state).is_err() { + error!( + "{} change to syscall {} {} failed !", + co.name(), syscall, new_state + ); + } + } + } + if let Some(suspender) = SchedulableSuspender::current() { + suspender.suspend(); + //回来的时候,系统调用已经执行完了 + } + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), syscall, SyscallState::Callback) = co.state() + { + let new_state = SyscallState::Executing; + if co.syscall((), syscall, new_state).is_err() { + error!( + "{} change to syscall {} {} failed !", + co.name(), syscall, new_state + ); + } + } + } + let (lock, cvar) = &*arc; + let syscall_result: c_int = cvar + .wait_while(lock.lock().expect("lock failed"), + |&mut result| result.is_none() + ) + .expect("lock failed") + .expect("no syscall result") + .try_into() + .expect("IOCP syscall result overflow"); + // fixme 错误处理 + // if syscall_result < 0 { + // let errno: std::ffi::c_int = (-syscall_result).try_into() + // .expect("IOCP errno overflow"); + // $crate::syscall::common::set_errno(errno); + // syscall_result = -1; + // } + syscall_result + } + Err(e) => { + if e.kind() == std::io::ErrorKind::Other { + self.inner.WSASend( + fn_ptr, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + dwflags, + lpoverlapped, + lpcompletionroutine, + ) + } else { + crate::syscall::common::set_errno(WSAEWOULDBLOCK.try_into().expect("overflow")); + SOCKET_ERROR + } + } + } + } +} -impl_nio_write_iovec!(NioWSARecvSyscall, WSARecvSyscall, +impl_nio_write_iovec!(NioWSASendSyscall, WSASendSyscall, WSASend( fd: SOCKET, buf: *const WSABUF, @@ -86,7 +259,7 @@ impl_nio_write_iovec!(NioWSARecvSyscall, WSARecvSyscall, ) -> c_int ); -impl_raw!(RawWSARecvSyscall, WSARecvSyscall, windows_sys::Win32::Networking::WinSock, +impl_raw!(RawWSASendSyscall, WSASendSyscall, windows_sys::Win32::Networking::WinSock, WSASend( fd: SOCKET, buf: *const WSABUF, diff --git a/core/src/syscall/windows/WSASocketW.rs b/core/src/syscall/windows/WSASocketW.rs index fe0022ef..f653a217 100644 --- a/core/src/syscall/windows/WSASocketW.rs +++ b/core/src/syscall/windows/WSASocketW.rs @@ -1,8 +1,6 @@ use once_cell::sync::Lazy; use std::ffi::{c_int, c_uint}; -use windows_sys::Win32::Networking::WinSock::{ - IPPROTO, SOCKET, WINSOCK_SOCKET_TYPE, WSAPROTOCOL_INFOW, -}; +use windows_sys::Win32::Networking::WinSock::{IPPROTO, SOCKET, WINSOCK_SOCKET_TYPE, WSAPROTOCOL_INFOW}; #[must_use] pub extern "system" fn WSASocketW( @@ -23,7 +21,7 @@ pub extern "system" fn WSASocketW( g: c_uint, dw_flags: c_uint, ) -> SOCKET { - static CHAIN: Lazy> = Lazy::new(Default::default); + static CHAIN: Lazy>> = Lazy::new(Default::default); CHAIN.WSASocketW(fn_ptr, domain, ty, protocol, lpprotocolinfo, g, dw_flags) } @@ -60,6 +58,45 @@ impl_facade!(WSASocketWSyscallFacade, WSASocketWSyscall, ) -> SOCKET ); +#[repr(C)] +#[derive(Debug, Default)] +struct NioWSASocketWSyscall { + inner: I, +} + +impl WSASocketWSyscall for NioWSASocketWSyscall { + extern "system" fn WSASocketW( + &self, + fn_ptr: Option< + &extern "system" fn( + c_int, + WINSOCK_SOCKET_TYPE, + IPPROTO, + *const WSAPROTOCOL_INFOW, + c_uint, + c_uint, + ) -> SOCKET, + >, + domain: c_int, + ty: WINSOCK_SOCKET_TYPE, + protocol: IPPROTO, + lpprotocolinfo: *const WSAPROTOCOL_INFOW, + g: c_uint, + dw_flags: c_uint + ) -> SOCKET { + let r = self.inner.WSASocketW(fn_ptr, domain, ty, protocol, lpprotocolinfo, g, dw_flags); + #[cfg(feature = "iocp")] + if windows_sys::Win32::Networking::WinSock::INVALID_SOCKET != r { + _ = crate::net::operator::SOCKET_CONTEXT.insert(r,crate::net::operator::SocketContext{ + domain, + ty, + protocol, + }); + } + r + } +} + impl_raw!(RawWSASocketWSyscall, WSASocketWSyscall, windows_sys::Win32::Networking::WinSock, WSASocketW( domain: c_int, diff --git a/core/src/syscall/windows/accept.rs b/core/src/syscall/windows/accept.rs index dbf1dd60..92ae4166 100644 --- a/core/src/syscall/windows/accept.rs +++ b/core/src/syscall/windows/accept.rs @@ -9,6 +9,16 @@ pub extern "system" fn accept( address: *mut SOCKADDR, address_len: *mut c_int, ) -> SOCKET { + // cfg_if::cfg_if! { + // if #[cfg(feature = "iocp")] { + // static CHAIN: Lazy< + // AcceptSyscallFacade>> + // > = Lazy::new(Default::default); + // } else { + // static CHAIN: Lazy>> = + // Lazy::new(Default::default); + // } + // } static CHAIN: Lazy>> = Lazy::new(Default::default); CHAIN.accept(fn_ptr, fd, address, address_len) @@ -28,6 +38,10 @@ impl_facade!(AcceptSyscallFacade, AcceptSyscall, accept(fd: SOCKET, address: *mut SOCKADDR, address_len: *mut c_int) -> SOCKET ); +impl_iocp!(IocpAcceptSyscall, AcceptSyscall, + accept(fd: SOCKET, address: *mut SOCKADDR, address_len: *mut c_int) -> SOCKET +); + impl_nio_read!(NioAcceptSyscall, AcceptSyscall, accept(fd: SOCKET, address: *mut SOCKADDR, address_len: *mut c_int) -> SOCKET ); diff --git a/core/src/syscall/windows/mod.rs b/core/src/syscall/windows/mod.rs index d142634d..84757412 100644 --- a/core/src/syscall/windows/mod.rs +++ b/core/src/syscall/windows/mod.rs @@ -43,6 +43,97 @@ macro_rules! impl_facade { } } +macro_rules! impl_iocp { + ( $struct_name:ident, $trait_name: ident, $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + #[repr(C)] + #[derive(Debug, Default)] + #[cfg(all(windows, feature = "iocp"))] + struct $struct_name { + inner: I, + } + + #[cfg(all(windows, feature = "iocp"))] + impl $trait_name for $struct_name { + extern "system" fn $syscall( + &self, + fn_ptr: Option<&extern "system" fn($($arg_type),*) -> $result>, + $($arg: $arg_type),* + ) -> $result { + use $crate::common::constants::{CoroutineState, Syscall, SyscallState}; + use $crate::scheduler::{SchedulableCoroutine, SchedulableSuspender}; + + match $crate::net::EventLoops::$syscall($($arg, )*) { + Ok(arc) => { + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), syscall, SyscallState::Executing) = co.state() + { + let new_state = SyscallState::Suspend(u64::MAX); + if co.syscall((), syscall, new_state).is_err() { + $crate::error!( + "{} change to syscall {} {} failed !", + co.name(), + syscall, + new_state + ); + } + } + } + if let Some(suspender) = SchedulableSuspender::current() { + suspender.suspend(); + //回来的时候,系统调用已经执行完了 + } + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::SystemCall((), syscall, SyscallState::Callback) = co.state() + { + let new_state = SyscallState::Executing; + if co.syscall((), syscall, new_state).is_err() { + $crate::error!( + "{} change to syscall {} {} failed !", + co.name(), syscall, new_state + ); + } + } + } + let (lock, cvar) = &*arc; + let mut syscall_result = cvar + .wait_while(lock.lock().expect("lock failed"), + |&mut result| result.is_none() + ) + .expect("lock failed") + .expect("no syscall result"); + eprintln!( + "syscall:{} returns:{} e:{}", + Syscall::$syscall, syscall_result, std::io::Error::last_os_error() + ); + if syscall_result >= 0 { + $crate::syscall::common::reset_errno(); + } else { + let errno = -syscall_result; + $crate::syscall::common::set_errno(errno.try_into().expect("errno overflow")); + if Syscall::accept == Syscall::$syscall { + syscall_result = 0; + } else { + syscall_result = -1; + } + } + <$result>::try_from(syscall_result).expect("overflow") + } + Err(e) => { + if e.kind() == std::io::ErrorKind::Other { + self.inner.$syscall(fn_ptr, $($arg, )*) + } else { + $crate::syscall::common::set_errno( + windows_sys::Win32::Networking::WinSock::WSAEWOULDBLOCK.try_into().expect("overflow") + ); + windows_sys::Win32::Networking::WinSock::SOCKET_ERROR.try_into().expect("overflow") + } + } + } + } + } + } +} + macro_rules! impl_nio_read { ( $struct_name:ident, $trait_name: ident, $syscall: ident($fd: ident : $fd_type: ty, $($arg: ident : $arg_type: ty),*) -> $result: ty ) => { #[repr(C)] diff --git a/core/src/syscall/windows/recv.rs b/core/src/syscall/windows/recv.rs index b9d822f7..aee998dc 100644 --- a/core/src/syscall/windows/recv.rs +++ b/core/src/syscall/windows/recv.rs @@ -11,6 +11,16 @@ pub extern "system" fn recv( len: c_int, flags: SEND_RECV_FLAGS, ) -> c_int { + // cfg_if::cfg_if! { + // if #[cfg(all(windows, feature = "iocp"))] { + // static CHAIN: Lazy< + // RecvSyscallFacade>> + // > = Lazy::new(Default::default); + // } else { + // static CHAIN: Lazy>> = + // Lazy::new(Default::default); + // } + // } static CHAIN: Lazy>> = Lazy::new(Default::default); CHAIN.recv(fn_ptr, fd, buf, len, flags) @@ -31,6 +41,10 @@ impl_facade!(RecvSyscallFacade, RecvSyscall, recv(fd: SOCKET, buf: PSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int ); +impl_iocp!(IocpRecvSyscall, RecvSyscall, + recv(fd: SOCKET, buf: PSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int +); + impl_nio_read_buf!(NioRecvSyscall, RecvSyscall, recv(fd: SOCKET, buf: PSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int ); diff --git a/core/src/syscall/windows/send.rs b/core/src/syscall/windows/send.rs index 8b87be11..16d7af4f 100644 --- a/core/src/syscall/windows/send.rs +++ b/core/src/syscall/windows/send.rs @@ -11,8 +11,16 @@ pub extern "system" fn send( len: c_int, flags: SEND_RECV_FLAGS, ) -> c_int { - static CHAIN: Lazy>> = - Lazy::new(Default::default); + cfg_if::cfg_if! { + if #[cfg(all(windows, feature = "iocp"))] { + static CHAIN: Lazy< + SendSyscallFacade>> + > = Lazy::new(Default::default); + } else { + static CHAIN: Lazy>> = + Lazy::new(Default::default); + } + } CHAIN.send(fn_ptr, fd, buf, len, flags) } @@ -31,6 +39,10 @@ impl_facade!(SendSyscallFacade, SendSyscall, send(fd: SOCKET, buf: PCSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int ); +impl_iocp!(IocpSendSyscall, SendSyscall, + send(fd: SOCKET, buf: PCSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int +); + impl_nio_write_buf!(NioSendSyscall, SendSyscall, send(fd: SOCKET, buf: PCSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int ); diff --git a/core/src/syscall/windows/socket.rs b/core/src/syscall/windows/socket.rs index aff69705..b5ca20fc 100644 --- a/core/src/syscall/windows/socket.rs +++ b/core/src/syscall/windows/socket.rs @@ -9,7 +9,7 @@ pub extern "system" fn socket( ty: WINSOCK_SOCKET_TYPE, protocol: IPPROTO, ) -> SOCKET { - static CHAIN: Lazy> = Lazy::new(Default::default); + static CHAIN: Lazy>> = Lazy::new(Default::default); CHAIN.socket(fn_ptr, domain, ty, protocol) } @@ -27,6 +27,27 @@ impl_facade!(SocketSyscallFacade, SocketSyscall, socket(domain: c_int, ty: WINSOCK_SOCKET_TYPE, protocol: IPPROTO) -> SOCKET ); +#[repr(C)] +#[derive(Debug, Default)] +struct NioSocketSyscall { + inner: I, +} + +impl SocketSyscall for NioSocketSyscall { + extern "system" fn socket(&self, fn_ptr: Option<&extern "system" fn(c_int, WINSOCK_SOCKET_TYPE, IPPROTO) -> SOCKET>, domain: c_int, ty: WINSOCK_SOCKET_TYPE, protocol: IPPROTO) -> SOCKET { + let r = self.inner.socket(fn_ptr, domain, ty, protocol); + #[cfg(feature = "iocp")] + if windows_sys::Win32::Networking::WinSock::INVALID_SOCKET != r { + _ = crate::net::operator::SOCKET_CONTEXT.insert(r,crate::net::operator::SocketContext{ + domain, + ty, + protocol, + }); + } + r + } +} + impl_raw!(RawSocketSyscall, SocketSyscall, windows_sys::Win32::Networking::WinSock, socket(domain: c_int, ty: WINSOCK_SOCKET_TYPE, protocol: IPPROTO) -> SOCKET ); diff --git a/hook/Cargo.toml b/hook/Cargo.toml index e1c324cc..57098f77 100644 --- a/hook/Cargo.toml +++ b/hook/Cargo.toml @@ -49,6 +49,12 @@ net = ["open-coroutine-core/net"] # Provide io_uring adaptation, this feature only works in linux. io_uring = ["open-coroutine-core/io_uring"] +# Provide IOCP adaptation, this feature only works in windows. +iocp = ["open-coroutine-core/iocp"] + +# Provide completion IOCP adaptation +completion_io = ["open-coroutine-core/completion_io"] + # Provide syscall implementation. syscall = ["open-coroutine-core/syscall"] diff --git a/open-coroutine/Cargo.toml b/open-coroutine/Cargo.toml index dc948442..1740ba3e 100644 --- a/open-coroutine/Cargo.toml +++ b/open-coroutine/Cargo.toml @@ -60,5 +60,11 @@ net = ["open-coroutine-hook/net", "open-coroutine-core/net"] # This feature only works in linux. io_uring = ["open-coroutine-hook/io_uring", "open-coroutine-core/io_uring"] +# Provide IOCP adaptation, this feature only works in windows. +iocp = ["open-coroutine-hook/iocp", "open-coroutine-core/iocp"] + +# Provide completion IOCP adaptation +completion_io = ["open-coroutine-hook/completion_io", "open-coroutine-core/completion_io"] + # Provide syscall implementation. syscall = ["open-coroutine-hook/syscall", "open-coroutine-core/syscall"] diff --git a/open-coroutine/build.rs b/open-coroutine/build.rs index d2b167a6..72dd42ea 100644 --- a/open-coroutine/build.rs +++ b/open-coroutine/build.rs @@ -155,6 +155,12 @@ fn main() { if cfg!(feature = "io_uring") { features.push("io_uring"); } + if cfg!(feature = "iocp") { + features.push("iocp"); + } + if cfg!(feature = "completion_io") { + features.push("completion_io"); + } if cfg!(feature = "syscall") { features.push("syscall"); }