Skip to content

Commit

Permalink
refactor(wasi-sockets): simplify UDP implementation
Browse files Browse the repository at this point in the history
This introduces quite a few changes compared to TCP, which should most probably be integrated there as well

Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net>
  • Loading branch information
rvolosatovs committed Oct 6, 2023
1 parent 48c8f01 commit 25f9bc2
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 122 deletions.
194 changes: 75 additions & 119 deletions crates/wasi/src/preview2/host/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::preview2::{
Table,
};
use crate::preview2::{Pollable, SocketResult, WasiView};
use cap_net_ext::PoolExt;
use cap_net_ext::{AddressFamily, PoolExt};
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
Expand All @@ -29,7 +29,10 @@ fn start_bind(
let socket = table.get_resource(&this)?;
match socket.udp_state {
UdpState::Default => {}
_ => return Err(ErrorCode::NotInProgress.into()),
UdpState::BindStarted | UdpState::Connecting | UdpState::ConnectReady => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
UdpState::Bound | UdpState::Connected => return Err(ErrorCode::AlreadyBound.into()),
}

let network = table.get_resource(&network)?;
Expand All @@ -51,56 +54,11 @@ fn start_bind(
fn finish_bind(table: &mut Table, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
let socket = table.get_resource_mut(&this)?;
match socket.udp_state {
UdpState::BindStarted => {}
_ => return Err(ErrorCode::NotInProgress.into()),
}

socket.udp_state = UdpState::Bound;

Ok(())
}

fn address_family(table: &Table, this: Resource<udp::UdpSocket>) -> SocketResult<IpAddressFamily> {
let socket = table.get_resource(&this)?;

// If `SO_DOMAIN` is available, use it.
//
// TODO: OpenBSD also supports this; upstream PRs are posted.
#[cfg(not(any(
windows,
target_os = "ios",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd"
)))]
{
use rustix::net::AddressFamily;

let family = sockopt::get_socket_domain(socket.udp_socket())?;
let family = match family {
AddressFamily::INET => IpAddressFamily::Ipv4,
AddressFamily::INET6 => IpAddressFamily::Ipv6,
_ => return Err(ErrorCode::NotSupported.into()),
};
Ok(family)
}

// When `SO_DOMAIN` is not available, emulate it.
#[cfg(any(
windows,
target_os = "ios",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd"
))]
{
if let Ok(_) = sockopt::get_ipv6_unicast_hops(socket.udp_socket()) {
return Ok(IpAddressFamily::Ipv6);
}
if let Ok(_) = sockopt::get_ip_ttl(socket.udp_socket()) {
return Ok(IpAddressFamily::Ipv4);
UdpState::BindStarted => {
socket.udp_state = UdpState::Bound;
Ok(())
}
Err(ErrorCode::NotSupported.into())
_ => Err(ErrorCode::NotInProgress.into()),
}
}

Expand All @@ -127,70 +85,64 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
remote_address: IpSocketAddress,
) -> SocketResult<()> {
let table = self.table_mut();
let r = {
let socket = table.get_resource(&this)?;
match socket.udp_state {
UdpState::Default => {
let family = address_family(table, Resource::new_borrow(this.rep()))?;
let addr = match family {
IpAddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(),
IpAddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(),
};
start_bind(
table,
Resource::new_borrow(this.rep()),
Resource::new_borrow(network.rep()),
SocketAddr::new(addr, 0).into(),
)?;
finish_bind(table, Resource::new_borrow(this.rep()))?;
}
UdpState::BindStarted => {
finish_bind(table, Resource::new_borrow(this.rep()))?;
}
UdpState::Bound => {}
UdpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
_ => return Err(ErrorCode::NotInProgress.into()),
let socket = table.get_resource(&this)?;
match socket.udp_state {
UdpState::Default => {
let addr = match socket.family {
AddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(),
AddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(),
};
start_bind(
table,
Resource::new_borrow(this.rep()),
Resource::new_borrow(network.rep()),
SocketAddr::new(addr, 0).into(),
)?;
finish_bind(table, Resource::new_borrow(this.rep()))?;
}

let socket = table.get_resource(&this)?;
let network = table.get_resource(&network)?;
let connecter = network.pool.udp_connecter(remote_address)?;

// Do an OS `connect`. Our socket is non-blocking, so it'll either...
{
let view = &*socket
.udp_socket()
.as_socketlike_view::<cap_std::net::UdpSocket>();
let r = connecter.connect_existing_udp_socket(view);
r
UdpState::Bound => {}
UdpState::BindStarted | UdpState::Connecting | UdpState::ConnectReady => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
};
UdpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
}

match r {
let socket = table.get_resource(&this)?;
let network = table.get_resource(&network)?;
let connecter = network.pool.udp_connecter(remote_address)?;

// Do an OS `connect`. Our socket is non-blocking, so it'll either...
let res = connecter.connect_existing_udp_socket(
&*socket
.udp_socket()
.as_socketlike_view::<cap_std::net::UdpSocket>(),
);
match res {
// succeed immediately,
Ok(()) => {
let socket = table.get_resource_mut(&this)?;
socket.udp_state = UdpState::ConnectReady;
return Ok(());
Ok(())
}
// continue in progress,
Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => {}
Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => {
let socket = table.get_resource_mut(&this)?;
socket.udp_state = UdpState::Connecting;
Ok(())
}
// or fail immediately.
Err(err) => return Err(err.into()),
Err(err) => Err(err.into()),
}

let socket = table.get_resource_mut(&this)?;
socket.udp_state = UdpState::Connecting;

Ok(())
}

fn finish_connect(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
let table = self.table_mut();
let socket = table.get_resource_mut(&this)?;

match socket.udp_state {
UdpState::ConnectReady => {}
UdpState::ConnectReady => {
socket.udp_state = UdpState::Connected;
Ok(())
}
UdpState::Connecting => {
// Do a `poll` to test for completion, using a timeout of zero
// to avoid blocking.
Expand All @@ -202,21 +154,21 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
0,
) {
Ok(0) => return Err(ErrorCode::WouldBlock.into()),
Ok(_) => (),
Err(err) => Err(err).unwrap(),
Ok(_) => {}
Err(err) => return Err(err.into()),
}

// Check whether the connect succeeded.
match sockopt::get_socket_error(socket.udp_socket()) {
Ok(Ok(())) => {}
Err(err) | Ok(Err(err)) => return Err(err.into()),
Ok(Ok(())) => {
socket.udp_state = UdpState::Connected;
Ok(())
}
Err(err) | Ok(Err(err)) => Err(err.into()),
}
}
_ => return Err(ErrorCode::NotInProgress.into()),
};

socket.udp_state = UdpState::Connected;
Ok(())
_ => Err(ErrorCode::NotInProgress.into()),
}
}

fn receive(
Expand All @@ -232,7 +184,7 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
let socket = table.get_resource(&this)?;

let udp_socket = socket.udp_socket();
let mut datagrams = Vec::with_capacity(max_results.try_into().unwrap_or(usize::MAX));
let mut datagrams = vec![];
let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
match socket.udp_state {
UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::NotBound.into()),
Expand Down Expand Up @@ -352,8 +304,12 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
&mut self,
this: Resource<udp::UdpSocket>,
) -> Result<IpAddressFamily, anyhow::Error> {
let family = address_family(self.table(), this)?;
Ok(family)
let table = self.table();
let socket = table.get_resource(&this)?;
match socket.family {
AddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
AddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
}
}

fn ipv6_only(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<bool> {
Expand Down Expand Up @@ -477,12 +433,12 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
}
}

// On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`.
// <https://pubs.opengroup.org/onlinepubs/9699919799/functions/connect.html>
#[cfg(not(windows))]
const INPROGRESS: Errno = Errno::INPROGRESS;

// On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`.
// <https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect>
#[cfg(windows)]
const INPROGRESS: Errno = Errno::WOULDBLOCK;
const INPROGRESS: Errno = if cfg!(windows) {
// On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`.
// <https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect>
Errno::WOULDBLOCK
} else {
// On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`.
// <https://pubs.opengroup.org/onlinepubs/9699919799/functions/connect.html>
Errno::INPROGRESS
};
13 changes: 10 additions & 3 deletions crates/wasi/src/preview2/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ pub struct UdpSocket {

/// The current state in the bind/connect progression.
pub(crate) udp_state: UdpState,

/// Socket address family.
pub(crate) family: AddressFamily,
}

#[async_trait]
impl Subscribe for UdpSocket {
async fn ready(&mut self) {
// Some states are ready immediately.
match self.udp_state {
UdpState::BindStarted | UdpState::ConnectReady => return,
UdpState::BindStarted => return,
_ => {}
}

Expand All @@ -68,16 +71,20 @@ impl UdpSocket {
// Create a new host socket and set it to non-blocking, which is needed
// by our async implementation.
let udp_socket = cap_std::net::UdpSocket::new(family, Blocking::No)?;
Self::from_udp_socket(udp_socket)
Self::from_udp_socket(udp_socket, family)
}

pub fn from_udp_socket(udp_socket: cap_std::net::UdpSocket) -> io::Result<Self> {
pub fn from_udp_socket(
udp_socket: cap_std::net::UdpSocket,
family: AddressFamily,
) -> io::Result<Self> {
let fd = udp_socket.into_raw_socketlike();
let std_socket = unsafe { std::net::UdpSocket::from_raw_socketlike(fd) };
let socket = with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket))?;
Ok(Self {
inner: Arc::new(socket),
udp_state: UdpState::Default,
family,
})
}

Expand Down

0 comments on commit 25f9bc2

Please sign in to comment.