From 25f9bc25322ea5d3bd96ec1d4ee9bd42c03782d3 Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Fri, 6 Oct 2023 19:34:09 +0200 Subject: [PATCH] refactor(wasi-sockets): simplify UDP implementation This introduces quite a few changes compared to TCP, which should most probably be integrated there as well Signed-off-by: Roman Volosatovs --- crates/wasi/src/preview2/host/udp.rs | 194 +++++++++++---------------- crates/wasi/src/preview2/udp.rs | 13 +- 2 files changed, 85 insertions(+), 122 deletions(-) diff --git a/crates/wasi/src/preview2/host/udp.rs b/crates/wasi/src/preview2/host/udp.rs index 38dadc938808..8e1c420770f0 100644 --- a/crates/wasi/src/preview2/host/udp.rs +++ b/crates/wasi/src/preview2/host/udp.rs @@ -9,7 +9,7 @@ use crate::preview2::{ Table, }; use crate::preview2::{Pollable, SocketResult, WasiView}; -use cap_net_ext::PoolExt; +use cap_net_ext::{AddressFamily, PoolExt}; use io_lifetimes::AsSocketlike; use rustix::io::Errno; use rustix::net::sockopt; @@ -29,7 +29,10 @@ fn start_bind( let socket = table.get_resource(&this)?; match socket.udp_state { UdpState::Default => {} - _ => return Err(ErrorCode::NotInProgress.into()), + UdpState::BindStarted | UdpState::Connecting | UdpState::ConnectReady => { + return Err(ErrorCode::ConcurrencyConflict.into()) + } + UdpState::Bound | UdpState::Connected => return Err(ErrorCode::AlreadyBound.into()), } let network = table.get_resource(&network)?; @@ -51,56 +54,11 @@ fn start_bind( fn finish_bind(table: &mut Table, this: Resource) -> SocketResult<()> { let socket = table.get_resource_mut(&this)?; match socket.udp_state { - UdpState::BindStarted => {} - _ => return Err(ErrorCode::NotInProgress.into()), - } - - socket.udp_state = UdpState::Bound; - - Ok(()) -} - -fn address_family(table: &Table, this: Resource) -> SocketResult { - let socket = table.get_resource(&this)?; - - // If `SO_DOMAIN` is available, use it. - // - // TODO: OpenBSD also supports this; upstream PRs are posted. - #[cfg(not(any( - windows, - target_os = "ios", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd" - )))] - { - use rustix::net::AddressFamily; - - let family = sockopt::get_socket_domain(socket.udp_socket())?; - let family = match family { - AddressFamily::INET => IpAddressFamily::Ipv4, - AddressFamily::INET6 => IpAddressFamily::Ipv6, - _ => return Err(ErrorCode::NotSupported.into()), - }; - Ok(family) - } - - // When `SO_DOMAIN` is not available, emulate it. - #[cfg(any( - windows, - target_os = "ios", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd" - ))] - { - if let Ok(_) = sockopt::get_ipv6_unicast_hops(socket.udp_socket()) { - return Ok(IpAddressFamily::Ipv6); - } - if let Ok(_) = sockopt::get_ip_ttl(socket.udp_socket()) { - return Ok(IpAddressFamily::Ipv4); + UdpState::BindStarted => { + socket.udp_state = UdpState::Bound; + Ok(()) } - Err(ErrorCode::NotSupported.into()) + _ => Err(ErrorCode::NotInProgress.into()), } } @@ -127,70 +85,64 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { remote_address: IpSocketAddress, ) -> SocketResult<()> { let table = self.table_mut(); - let r = { - let socket = table.get_resource(&this)?; - match socket.udp_state { - UdpState::Default => { - let family = address_family(table, Resource::new_borrow(this.rep()))?; - let addr = match family { - IpAddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(), - IpAddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(), - }; - start_bind( - table, - Resource::new_borrow(this.rep()), - Resource::new_borrow(network.rep()), - SocketAddr::new(addr, 0).into(), - )?; - finish_bind(table, Resource::new_borrow(this.rep()))?; - } - UdpState::BindStarted => { - finish_bind(table, Resource::new_borrow(this.rep()))?; - } - UdpState::Bound => {} - UdpState::Connected => return Err(ErrorCode::AlreadyConnected.into()), - _ => return Err(ErrorCode::NotInProgress.into()), + let socket = table.get_resource(&this)?; + match socket.udp_state { + UdpState::Default => { + let addr = match socket.family { + AddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(), + AddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(), + }; + start_bind( + table, + Resource::new_borrow(this.rep()), + Resource::new_borrow(network.rep()), + SocketAddr::new(addr, 0).into(), + )?; + finish_bind(table, Resource::new_borrow(this.rep()))?; } - - let socket = table.get_resource(&this)?; - let network = table.get_resource(&network)?; - let connecter = network.pool.udp_connecter(remote_address)?; - - // Do an OS `connect`. Our socket is non-blocking, so it'll either... - { - let view = &*socket - .udp_socket() - .as_socketlike_view::(); - let r = connecter.connect_existing_udp_socket(view); - r + UdpState::Bound => {} + UdpState::BindStarted | UdpState::Connecting | UdpState::ConnectReady => { + return Err(ErrorCode::ConcurrencyConflict.into()) } - }; + UdpState::Connected => return Err(ErrorCode::AlreadyConnected.into()), + } - match r { + let socket = table.get_resource(&this)?; + let network = table.get_resource(&network)?; + let connecter = network.pool.udp_connecter(remote_address)?; + + // Do an OS `connect`. Our socket is non-blocking, so it'll either... + let res = connecter.connect_existing_udp_socket( + &*socket + .udp_socket() + .as_socketlike_view::(), + ); + match res { // succeed immediately, Ok(()) => { let socket = table.get_resource_mut(&this)?; socket.udp_state = UdpState::ConnectReady; - return Ok(()); + Ok(()) } // continue in progress, - Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => {} + Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => { + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::Connecting; + Ok(()) + } // or fail immediately. - Err(err) => return Err(err.into()), + Err(err) => Err(err.into()), } - - let socket = table.get_resource_mut(&this)?; - socket.udp_state = UdpState::Connecting; - - Ok(()) } fn finish_connect(&mut self, this: Resource) -> SocketResult<()> { let table = self.table_mut(); let socket = table.get_resource_mut(&this)?; - match socket.udp_state { - UdpState::ConnectReady => {} + UdpState::ConnectReady => { + socket.udp_state = UdpState::Connected; + Ok(()) + } UdpState::Connecting => { // Do a `poll` to test for completion, using a timeout of zero // to avoid blocking. @@ -202,21 +154,21 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { 0, ) { Ok(0) => return Err(ErrorCode::WouldBlock.into()), - Ok(_) => (), - Err(err) => Err(err).unwrap(), + Ok(_) => {} + Err(err) => return Err(err.into()), } // Check whether the connect succeeded. match sockopt::get_socket_error(socket.udp_socket()) { - Ok(Ok(())) => {} - Err(err) | Ok(Err(err)) => return Err(err.into()), + Ok(Ok(())) => { + socket.udp_state = UdpState::Connected; + Ok(()) + } + Err(err) | Ok(Err(err)) => Err(err.into()), } } - _ => return Err(ErrorCode::NotInProgress.into()), - }; - - socket.udp_state = UdpState::Connected; - Ok(()) + _ => Err(ErrorCode::NotInProgress.into()), + } } fn receive( @@ -232,7 +184,7 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { let socket = table.get_resource(&this)?; let udp_socket = socket.udp_socket(); - let mut datagrams = Vec::with_capacity(max_results.try_into().unwrap_or(usize::MAX)); + let mut datagrams = vec![]; let mut buf = [0; MAX_UDP_DATAGRAM_SIZE]; match socket.udp_state { UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::NotBound.into()), @@ -352,8 +304,12 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { &mut self, this: Resource, ) -> Result { - let family = address_family(self.table(), this)?; - Ok(family) + let table = self.table(); + let socket = table.get_resource(&this)?; + match socket.family { + AddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4), + AddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6), + } } fn ipv6_only(&mut self, this: Resource) -> SocketResult { @@ -477,12 +433,12 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { } } -// On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`. -// -#[cfg(not(windows))] -const INPROGRESS: Errno = Errno::INPROGRESS; - -// On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`. -// -#[cfg(windows)] -const INPROGRESS: Errno = Errno::WOULDBLOCK; +const INPROGRESS: Errno = if cfg!(windows) { + // On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`. + // + Errno::WOULDBLOCK +} else { + // On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`. + // + Errno::INPROGRESS +}; diff --git a/crates/wasi/src/preview2/udp.rs b/crates/wasi/src/preview2/udp.rs index b146a47ccd14..53a20572c19c 100644 --- a/crates/wasi/src/preview2/udp.rs +++ b/crates/wasi/src/preview2/udp.rs @@ -43,6 +43,9 @@ pub struct UdpSocket { /// The current state in the bind/connect progression. pub(crate) udp_state: UdpState, + + /// Socket address family. + pub(crate) family: AddressFamily, } #[async_trait] @@ -50,7 +53,7 @@ impl Subscribe for UdpSocket { async fn ready(&mut self) { // Some states are ready immediately. match self.udp_state { - UdpState::BindStarted | UdpState::ConnectReady => return, + UdpState::BindStarted => return, _ => {} } @@ -68,16 +71,20 @@ impl UdpSocket { // Create a new host socket and set it to non-blocking, which is needed // by our async implementation. let udp_socket = cap_std::net::UdpSocket::new(family, Blocking::No)?; - Self::from_udp_socket(udp_socket) + Self::from_udp_socket(udp_socket, family) } - pub fn from_udp_socket(udp_socket: cap_std::net::UdpSocket) -> io::Result { + pub fn from_udp_socket( + udp_socket: cap_std::net::UdpSocket, + family: AddressFamily, + ) -> io::Result { let fd = udp_socket.into_raw_socketlike(); let std_socket = unsafe { std::net::UdpSocket::from_raw_socketlike(fd) }; let socket = with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket))?; Ok(Self { inner: Arc::new(socket), udp_state: UdpState::Default, + family, }) }