Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async DNS resolve #106

Merged
merged 17 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion compio-driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ socket2 = { version = "0.5", features = ["all"] }
# Windows specific dependencies
[target.'cfg(windows)'.dependencies]
compio-buf = { workspace = true, features = ["arrayvec"] }
# may be excluded from linking if the unstable equivalent is used
aligned-array = "1"
once_cell = "1"
windows-sys = { version = "0.48", features = [
Expand Down
10 changes: 10 additions & 0 deletions compio-driver/src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use windows_sys::Win32::{
RtlNtStatusToDosError, ERROR_HANDLE_EOF, ERROR_IO_INCOMPLETE, ERROR_NO_DATA,
ERROR_OPERATION_ABORTED, INVALID_HANDLE_VALUE, NTSTATUS, STATUS_PENDING, STATUS_SUCCESS,
},
Networking::WinSock::{WSACleanup, WSAStartup, WSADATA},
Storage::FileSystem::SetFileCompletionNotificationModes,
System::{
Threading::INFINITE,
Expand Down Expand Up @@ -134,6 +135,9 @@ impl Driver {
const DEFAULT_CAPACITY: usize = 1024;

pub fn new(_entries: u32) -> io::Result<Self> {
let mut data: WSADATA = unsafe { std::mem::zeroed() };
syscall!(SOCKET, WSAStartup(0x202, &mut data))?;

let port = syscall!(BOOL, CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 0))?;
let port = unsafe { OwnedHandle::from_raw_handle(port as _) };
Ok(Self {
Expand Down Expand Up @@ -271,6 +275,12 @@ impl AsRawFd for Driver {
}
}

impl Drop for Driver {
fn drop(&mut self) {
syscall!(SOCKET, WSACleanup()).ok();
}
}

/// The overlapped struct we actually used for IOCP.
#[repr(C)]
pub struct Overlapped<T: ?Sized> {
Expand Down
9 changes: 6 additions & 3 deletions compio-fs/src/file.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use std::{fs::Metadata, io, path::Path};

#[cfg(all(feature = "runtime", unix))]
use compio_driver::op::{ReadVectoredAt, WriteVectoredAt};
use compio_driver::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
#[cfg(feature = "runtime")]
use {
compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut},
compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut},
compio_driver::op::{BufResultExt, ReadAt, Sync, WriteAt},
compio_io::{AsyncReadAt, AsyncWriteAt},
compio_runtime::{submit, Attachable, Attacher},
};
#[cfg(all(feature = "runtime", unix))]
use {
compio_buf::{IoVectoredBuf, IoVectoredBufMut},
compio_driver::op::{ReadVectoredAt, WriteVectoredAt},
};

use crate::OpenOptions;

Expand Down
12 changes: 12 additions & 0 deletions compio-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,20 @@ compio-driver = { workspace = true }
compio-io = { workspace = true, optional = true }
compio-runtime = { workspace = true, optional = true }

either = "1"
socket2 = { version = "0.5", features = ["all"] }

[target.'cfg(windows)'.dependencies]
widestring = "1"
windows-sys = { version = "0.48", features = [
"Win32_Foundation",
"Win32_Networking_WinSock",
"Win32_System_IO",
] }

[target.'cfg(all(target_os = "linux", target_env = "gnu"))'.dependencies]
libc = "0.2"

# Shared dev dependencies for all platforms
[dev-dependencies]
futures-util = "0.3"
Expand Down
159 changes: 70 additions & 89 deletions compio-net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
#![warn(missing_docs)]

mod resolve;
mod socket;
mod tcp;
mod udp;
Expand All @@ -13,38 +14,36 @@ mod unix;
use std::{
future::Future,
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
};

use compio_buf::BufResult;
use either::Either;
pub(crate) use socket::*;
use socket2::SockAddr;
pub use tcp::*;
pub use udp::*;
pub use unix::*;

/// A trait for objects which can be converted or resolved to one or more
/// [`SockAddr`] values.
/// [`SocketAddr`] values.
///
/// See [`ToSocketAddrs`].
pub trait ToSockAddrs {
/// See [`ToSocketAddrs::Iter`].
type Iter: Iterator<Item = SockAddr>;

/// See [`ToSocketAddrs::to_socket_addrs`].
fn to_sock_addrs(&self) -> io::Result<Self::Iter>;
/// See [`std::net::ToSocketAddrs`].
#[allow(async_fn_in_trait)]
pub trait ToSocketAddrsAsync {
/// See [`std::net::ToSocketAddrs::Iter`].
type Iter: Iterator<Item = SocketAddr>;

/// See [`std::net::ToSocketAddrs::to_socket_addrs`].
async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter>;
Berrysoft marked this conversation as resolved.
Show resolved Hide resolved
}

// impl_to_sock_addrs_for_into_socket_addr
macro_rules! itsafisa {
($t:ty) => {
impl ToSockAddrs for $t {
type Iter =
std::iter::Map<<$t as std::net::ToSocketAddrs>::Iter, fn(SocketAddr) -> SockAddr>;
impl ToSocketAddrsAsync for $t {
type Iter = std::iter::Once<SocketAddr>;

fn to_sock_addrs(&self) -> io::Result<Self::Iter> {
std::net::ToSocketAddrs::to_socket_addrs(self)
.map(|iter| iter.map(SockAddr::from as _))
async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
Ok(std::iter::once(SocketAddr::from(*self)))
}
}
};
Expand All @@ -53,71 +52,84 @@ macro_rules! itsafisa {
itsafisa!(SocketAddr);
itsafisa!(SocketAddrV4);
itsafisa!(SocketAddrV6);
itsafisa!(str);
itsafisa!(String);
itsafisa!((IpAddr, u16));
itsafisa!((Ipv4Addr, u16));
itsafisa!((Ipv6Addr, u16));
itsafisa!((String, u16));

impl ToSockAddrs for (&str, u16) {
type Iter = std::iter::Map<std::vec::IntoIter<SocketAddr>, fn(SocketAddr) -> SockAddr>;
impl ToSocketAddrsAsync for (&str, u16) {
type Iter = Either<std::iter::Once<SocketAddr>, std::vec::IntoIter<SocketAddr>>;

async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
let (host, port) = self;
if let Ok(addr) = host.parse::<Ipv4Addr>() {
return Ok(Either::Left(std::iter::once(SocketAddr::from((
addr, *port,
)))));
}
if let Ok(addr) = host.parse::<Ipv6Addr>() {
return Ok(Either::Left(std::iter::once(SocketAddr::from((
addr, *port,
)))));
}

fn to_sock_addrs(&self) -> io::Result<Self::Iter> {
ToSocketAddrs::to_socket_addrs(self).map(|iter| iter.map(SockAddr::from as _))
resolve::resolve_sock_addrs(host, *port)
.await
.map(Either::Right)
}
}

impl ToSockAddrs for SockAddr {
type Iter = std::option::IntoIter<SockAddr>;
impl ToSocketAddrsAsync for (String, u16) {
type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;

fn to_sock_addrs(&self) -> io::Result<Self::Iter> {
Ok(Some(self.clone()).into_iter())
async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
(&*self.0, self.1).to_socket_addrs_async().await
}
}

impl<'a> ToSockAddrs for &'a [SockAddr] {
type Iter = std::iter::Cloned<std::slice::Iter<'a, SockAddr>>;
impl ToSocketAddrsAsync for str {
type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;

fn to_sock_addrs(&self) -> io::Result<Self::Iter> {
Ok(self.iter().cloned())
async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
if let Ok(addr) = self.parse::<SocketAddr>() {
return Ok(Either::Left(std::iter::once(addr)));
}

let (host, port_str) = self.rsplit_once(':').expect("invalid socket address");
let port: u16 = port_str.parse().expect("invalid port value");
(host, port).to_socket_addrs_async().await
}
}

impl<T: ToSockAddrs + ?Sized> ToSockAddrs for &T {
type Iter = T::Iter;
impl ToSocketAddrsAsync for String {
type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;

fn to_sock_addrs(&self) -> io::Result<Self::Iter> {
(**self).to_sock_addrs()
async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
self.as_str().to_socket_addrs_async().await
}
}

fn each_addr<T>(
addr: impl ToSockAddrs,
mut f: impl FnMut(SockAddr) -> io::Result<T>,
) -> io::Result<T> {
let addrs = addr.to_sock_addrs()?;
let mut last_err = None;
for addr in addrs {
match f(addr) {
Ok(l) => return Ok(l),
Err(e) => last_err = Some(e),
}
impl<'a> ToSocketAddrsAsync for &'a [SocketAddr] {
type Iter = std::iter::Copied<std::slice::Iter<'a, SocketAddr>>;

async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
Ok(self.iter().copied())
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
)
}))
}

#[allow(dead_code)]
async fn each_addr_async<T, F: Future<Output = io::Result<T>>>(
addr: impl ToSockAddrs,
mut f: impl FnMut(SockAddr) -> F,
impl<T: ToSocketAddrsAsync + ?Sized> ToSocketAddrsAsync for &T {
type Iter = T::Iter;

async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
(**self).to_socket_addrs_async().await
}
}

#[cfg(feature = "runtime")]
async fn each_addr<T, F: Future<Output = io::Result<T>>>(
addr: impl ToSocketAddrsAsync,
mut f: impl FnMut(SocketAddr) -> F,
) -> io::Result<T> {
let addrs = addr.to_sock_addrs()?;
let addrs = addr.to_socket_addrs_async().await?;
let mut last_err = None;
for addr in addrs {
match f(addr).await {
Expand All @@ -132,34 +144,3 @@ async fn each_addr_async<T, F: Future<Output = io::Result<T>>>(
)
}))
}

#[allow(dead_code)]
async fn each_addr_async_buf<T, B, F: Future<Output = BufResult<T, B>>>(
addr: impl ToSockAddrs,
mut buffer: B,
mut f: impl FnMut(SockAddr, B) -> F,
) -> BufResult<T, B> {
match addr.to_sock_addrs() {
Ok(addrs) => {
let mut last_err = None;
let mut res;
for addr in addrs {
BufResult(res, buffer) = f(addr, buffer).await;
match res {
Ok(l) => return BufResult(Ok(l), buffer),
Err(e) => last_err = Some(e),
}
}
BufResult(
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
)
})),
buffer,
)
}
Err(e) => BufResult(Err(e), buffer),
}
}
Loading