diff --git a/crates/test-programs/tests/wasi-sockets.rs b/crates/test-programs/tests/wasi-sockets.rs index c94243a2a866..806d21c6c5db 100644 --- a/crates/test-programs/tests/wasi-sockets.rs +++ b/crates/test-programs/tests/wasi-sockets.rs @@ -91,6 +91,11 @@ async fn tcp_sockopts() { run("tcp_sockopts").await.unwrap(); } +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn udp_sample_application() { + run("udp_sample_application").await.unwrap(); +} + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn ip_name_lookup() { run("ip_name_lookup").await.unwrap(); diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/udp_sample_application.rs b/crates/test-programs/wasi-sockets-tests/src/bin/udp_sample_application.rs new file mode 100644 index 000000000000..7dc64fa586a6 --- /dev/null +++ b/crates/test-programs/wasi-sockets-tests/src/bin/udp_sample_application.rs @@ -0,0 +1,89 @@ +use wasi::sockets::network::{ + IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, Network, +}; +use wasi::sockets::udp::{Datagram, UdpSocket}; +use wasi_sockets_tests::*; + +fn test_sample_application(family: IpAddressFamily, bind_address: IpSocketAddress) { + let first_message = &[]; + let second_message = b"Hello, world!"; + let third_message = b"Greetings, planet!"; + + let net = Network::default(); + + let server = UdpSocket::new(family).unwrap(); + + server.blocking_bind(&net, bind_address).unwrap(); + let addr = server.local_address().unwrap(); + + let client_addr = { + let client = UdpSocket::new(family).unwrap(); + client.blocking_connect(&net, addr).unwrap(); + + let datagrams = [ + Datagram { + data: first_message.to_vec(), + remote_address: addr, + }, + Datagram { + data: second_message.to_vec(), + remote_address: addr, + }, + ]; + client.blocking_send(&datagrams).unwrap(); + + client.local_address().unwrap() + }; + + { + // Check that we've received our sent messages. + // Not guaranteed to work but should work in practice. + let datagrams = server.blocking_receive(2..100).unwrap(); + assert_eq!(datagrams.len(), 2); + + assert_eq!(datagrams[0].data, first_message); + assert_eq!(datagrams[0].remote_address, client_addr); + + assert_eq!(datagrams[1].data, second_message); + assert_eq!(datagrams[1].remote_address, client_addr); + } + + // Another client + { + let client = UdpSocket::new(family).unwrap(); + client.blocking_connect(&net, addr).unwrap(); + + let datagrams = [Datagram { + data: third_message.to_vec(), + remote_address: addr, + }]; + client.blocking_send(&datagrams).unwrap(); + } + + { + // Check that we sent and received our message! + let datagrams = server.blocking_receive(1..100).unwrap(); + assert_eq!(datagrams.len(), 1); + + assert_eq!(datagrams[0].data, third_message); // Not guaranteed to work but should work in practice. + } +} + +fn main() { + test_sample_application( + IpAddressFamily::Ipv4, + IpSocketAddress::Ipv4(Ipv4SocketAddress { + port: 0, // use any free port + address: (127, 0, 0, 1), // localhost + }), + ); + test_sample_application( + IpAddressFamily::Ipv6, + IpSocketAddress::Ipv6(Ipv6SocketAddress { + port: 0, // use any free port + address: (0, 0, 0, 0, 0, 0, 0, 1), // localhost + flow_info: 0, + scope_id: 0, + }), + ); +} diff --git a/crates/test-programs/wasi-sockets-tests/src/lib.rs b/crates/test-programs/wasi-sockets-tests/src/lib.rs index 4d90e914ddda..258099e5a05c 100644 --- a/crates/test-programs/wasi-sockets-tests/src/lib.rs +++ b/crates/test-programs/wasi-sockets-tests/src/lib.rs @@ -1,5 +1,7 @@ wit_bindgen::generate!("test-command-with-sockets" in "../../wasi/wit"); +use std::ops::Range; +use wasi::clocks::monotonic_clock; use wasi::io::poll::{self, Pollable}; use wasi::io::streams::{InputStream, OutputStream, StreamError}; use wasi::sockets::instance_network; @@ -8,12 +10,25 @@ use wasi::sockets::network::{ Network, }; use wasi::sockets::tcp::TcpSocket; -use wasi::sockets::tcp_create_socket; +use wasi::sockets::udp::{Datagram, UdpSocket}; +use wasi::sockets::{tcp_create_socket, udp_create_socket}; + +const TIMEOUT_NS: u64 = 1_000_000_000; impl Pollable { pub fn wait(&self) { poll::poll_one(self); } + + pub fn wait_until(&self, timeout: &Pollable) -> Result<(), ErrorCode> { + let ready = poll::poll_list(&[self, timeout]); + assert!(ready.len() > 0); + match ready[0] { + 0 => Ok(()), + 1 => Err(ErrorCode::Timeout), + _ => unreachable!(), + } + } } impl OutputStream { @@ -108,6 +123,89 @@ impl TcpSocket { } } +impl UdpSocket { + pub fn new(address_family: IpAddressFamily) -> Result { + udp_create_socket::create_udp_socket(address_family) + } + + pub fn blocking_bind( + &self, + network: &Network, + local_address: IpSocketAddress, + ) -> Result<(), ErrorCode> { + let sub = self.subscribe(); + + self.start_bind(&network, local_address)?; + + loop { + match self.finish_bind() { + Err(ErrorCode::WouldBlock) => sub.wait(), + result => return result, + } + } + } + + pub fn blocking_connect( + &self, + network: &Network, + remote_address: IpSocketAddress, + ) -> Result<(), ErrorCode> { + let sub = self.subscribe(); + + self.start_connect(&network, remote_address)?; + + loop { + match self.finish_connect() { + Err(ErrorCode::WouldBlock) => sub.wait(), + result => return result, + } + } + } + + pub fn blocking_send(&self, mut datagrams: &[Datagram]) -> Result<(), ErrorCode> { + let timeout = monotonic_clock::subscribe(TIMEOUT_NS, false); + let pollable = self.subscribe(); + + while !datagrams.is_empty() { + match self.send(datagrams) { + Ok(packets_sent) => { + datagrams = &datagrams[(packets_sent as usize)..]; + } + Err(ErrorCode::WouldBlock) => pollable.wait_until(&timeout)?, + Err(err) => return Err(err), + } + } + + Ok(()) + } + + pub fn blocking_receive(&self, count: Range) -> Result, ErrorCode> { + let timeout = monotonic_clock::subscribe(TIMEOUT_NS, false); + let pollable = self.subscribe(); + let mut datagrams = vec![]; + + loop { + match self.receive(count.end - datagrams.len() as u64) { + Ok(mut chunk) => { + datagrams.append(&mut chunk); + + if datagrams.len() >= count.start as usize { + return Ok(datagrams); + } + } + Err(ErrorCode::WouldBlock) => { + if datagrams.len() >= count.start as usize { + return Ok(datagrams); + } else { + pollable.wait_until(&timeout)?; + } + } + Err(err) => return Err(err), + } + } + } +} + impl IpAddress { pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255)); @@ -189,3 +287,28 @@ impl IpSocketAddress { } } } + +impl PartialEq for Ipv4SocketAddress { + fn eq(&self, other: &Self) -> bool { + self.port == other.port && self.address == other.address + } +} + +impl PartialEq for Ipv6SocketAddress { + fn eq(&self, other: &Self) -> bool { + self.port == other.port + && self.flow_info == other.flow_info + && self.address == other.address + && self.scope_id == other.scope_id + } +} + +impl PartialEq for IpSocketAddress { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Ipv4(l0), Self::Ipv4(r0)) => l0 == r0, + (Self::Ipv6(l0), Self::Ipv6(r0)) => l0 == r0, + _ => false, + } + } +} diff --git a/crates/wasi-http/wit/test.wit b/crates/wasi-http/wit/test.wit index fc9c357522bf..a0d1d07a6c64 100644 --- a/crates/wasi-http/wit/test.wit +++ b/crates/wasi-http/wit/test.wit @@ -37,6 +37,8 @@ world test-command-with-sockets { import wasi:cli/stderr; import wasi:sockets/tcp; import wasi:sockets/tcp-create-socket; + import wasi:sockets/udp; + import wasi:sockets/udp-create-socket; import wasi:sockets/network; import wasi:sockets/instance-network; import wasi:sockets/ip-name-lookup; diff --git a/crates/wasi/src/preview2/command.rs b/crates/wasi/src/preview2/command.rs index 811e3cf18e2c..898311157354 100644 --- a/crates/wasi/src/preview2/command.rs +++ b/crates/wasi/src/preview2/command.rs @@ -48,6 +48,8 @@ pub fn add_to_linker(l: &mut wasmtime::component::Linker) -> any crate::preview2::bindings::cli::terminal_stderr::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp_create_socket::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp_create_socket::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::instance_network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::ip_name_lookup::add_to_linker(l, |t| t)?; @@ -65,6 +67,7 @@ pub mod sync { "wasi:filesystem/types": crate::preview2::bindings::sync_io::filesystem::types, "wasi:filesystem/preopens": crate::preview2::bindings::filesystem::preopens, "wasi:sockets/tcp": crate::preview2::bindings::sockets::tcp, + "wasi:sockets/udp": crate::preview2::bindings::sockets::udp, "wasi:clocks/monotonic_clock": crate::preview2::bindings::clocks::monotonic_clock, "wasi:io/poll": crate::preview2::bindings::sync_io::io::poll, "wasi:io/streams": crate::preview2::bindings::sync_io::io::streams, @@ -107,6 +110,8 @@ pub mod sync { crate::preview2::bindings::cli::terminal_stderr::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp_create_socket::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp_create_socket::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::instance_network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::ip_name_lookup::add_to_linker(l, |t| t)?; diff --git a/crates/wasi/src/preview2/host/mod.rs b/crates/wasi/src/preview2/host/mod.rs index 138166731565..651d2cd38e0c 100644 --- a/crates/wasi/src/preview2/host/mod.rs +++ b/crates/wasi/src/preview2/host/mod.rs @@ -8,3 +8,5 @@ mod network; mod random; mod tcp; mod tcp_create_socket; +mod udp; +mod udp_create_socket; diff --git a/crates/wasi/src/preview2/host/tcp.rs b/crates/wasi/src/preview2/host/tcp.rs index 82e824b3a141..05a203a13f80 100644 --- a/crates/wasi/src/preview2/host/tcp.rs +++ b/crates/wasi/src/preview2/host/tcp.rs @@ -603,27 +603,6 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { // As in the filesystem implementation, we assume closing a socket // doesn't block. let dropped = table.delete_resource(this)?; - - // If we might have an `event::poll` waiting on the socket, wake it up. - #[cfg(not(unix))] - { - match dropped.tcp_state { - TcpState::Default - | TcpState::BindStarted - | TcpState::Bound - | TcpState::ListenStarted - | TcpState::ConnectFailed - | TcpState::ConnectReady => {} - - TcpState::Listening | TcpState::Connecting | TcpState::Connected => { - match rustix::net::shutdown(&*dropped.inner, rustix::net::Shutdown::ReadWrite) { - Ok(()) | Err(Errno::NOTCONN) => {} - Err(err) => Err(err).unwrap(), - } - } - } - } - drop(dropped); Ok(()) diff --git a/crates/wasi/src/preview2/host/udp.rs b/crates/wasi/src/preview2/host/udp.rs new file mode 100644 index 000000000000..f05000ae676f --- /dev/null +++ b/crates/wasi/src/preview2/host/udp.rs @@ -0,0 +1,356 @@ +use std::net::SocketAddr; + +use crate::preview2::{ + bindings::{ + sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network}, + sockets::udp, + }, + udp::UdpState, +}; +use crate::preview2::{Pollable, SocketResult, WasiView}; +use cap_net_ext::{AddressFamily, PoolExt}; +use io_lifetimes::AsSocketlike; +use rustix::io::Errno; +use rustix::net::sockopt; +use wasmtime::component::Resource; + +/// Theoretical maximum byte size of a UDP datagram, the real limit is lower, +/// but we do not account for e.g. the transport layer here for simplicity. +/// In practice, datagrams are typically less than 1500 bytes. +const MAX_UDP_DATAGRAM_SIZE: usize = 65535; + +impl udp::Host for T {} + +impl crate::preview2::host::udp::udp::HostUdpSocket for T { + fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: IpSocketAddress, + ) -> SocketResult<()> { + let table = self.table_mut(); + let socket = table.get_resource(&this)?; + + match socket.udp_state { + UdpState::Default => {} + UdpState::BindStarted | UdpState::Connecting(..) => { + return Err(ErrorCode::ConcurrencyConflict.into()) + } + UdpState::Bound | UdpState::Connected(..) => return Err(ErrorCode::InvalidState.into()), + } + + let network = table.get_resource(&network)?; + let binder = network.pool.udp_binder(local_address)?; + + // Perform the OS bind call. + binder.bind_existing_udp_socket( + &*socket + .udp_socket() + .as_socketlike_view::(), + )?; + + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::BindStarted; + + Ok(()) + } + + fn finish_bind(&mut self, this: Resource) -> SocketResult<()> { + let table = self.table_mut(); + let socket = table.get_resource_mut(&this)?; + + match socket.udp_state { + UdpState::BindStarted => { + socket.udp_state = UdpState::Bound; + Ok(()) + } + _ => Err(ErrorCode::NotInProgress.into()), + } + } + + fn start_connect( + &mut self, + this: Resource, + network: Resource, + remote_address: IpSocketAddress, + ) -> SocketResult<()> { + let table = self.table_mut(); + let socket = table.get_resource(&this)?; + let network = table.get_resource(&network)?; + + match socket.udp_state { + UdpState::Default | UdpState::Bound => {} + UdpState::BindStarted | UdpState::Connecting(..) => { + return Err(ErrorCode::ConcurrencyConflict.into()) + } + UdpState::Connected(..) => return Err(ErrorCode::InvalidState.into()), + } + + let connecter = network.pool.udp_connecter(remote_address)?; + + // Do an OS `connect`. + connecter.connect_existing_udp_socket( + &*socket + .udp_socket() + .as_socketlike_view::(), + )?; + + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::Connecting(remote_address); + Ok(()) + } + + fn finish_connect(&mut self, this: Resource) -> SocketResult<()> { + let table = self.table_mut(); + let socket = table.get_resource_mut(&this)?; + + match socket.udp_state { + UdpState::Connecting(addr) => { + socket.udp_state = UdpState::Connected(addr); + Ok(()) + } + _ => Err(ErrorCode::NotInProgress.into()), + } + } + + fn receive( + &mut self, + this: Resource, + max_results: u64, + ) -> SocketResult> { + if max_results == 0 { + return Ok(vec![]); + } + + let table = self.table(); + let socket = table.get_resource(&this)?; + + let udp_socket = socket.udp_socket(); + let mut datagrams = vec![]; + let mut buf = [0; MAX_UDP_DATAGRAM_SIZE]; + match socket.udp_state { + UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::InvalidState.into()), + UdpState::Bound | UdpState::Connecting(..) => { + for i in 0..max_results { + match udp_socket.try_recv_from(&mut buf) { + Ok((size, remote_address)) => datagrams.push(udp::Datagram { + data: buf[..size].into(), + remote_address: remote_address.into(), + }), + Err(_e) if i > 0 => { + return Ok(datagrams); + } + Err(e) => return Err(e.into()), + } + } + } + UdpState::Connected(remote_address) => { + for i in 0..max_results { + match udp_socket.try_recv(&mut buf) { + Ok(size) => datagrams.push(udp::Datagram { + data: buf[..size].into(), + remote_address, + }), + Err(_e) if i > 0 => { + return Ok(datagrams); + } + Err(e) => return Err(e.into()), + } + } + } + } + Ok(datagrams) + } + + fn send( + &mut self, + this: Resource, + datagrams: Vec, + ) -> SocketResult { + if datagrams.is_empty() { + return Ok(0); + }; + let table = self.table(); + let socket = table.get_resource(&this)?; + + let udp_socket = socket.udp_socket(); + let mut count = 0; + match socket.udp_state { + UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::InvalidState.into()), + UdpState::Bound | UdpState::Connecting(..) => { + for udp::Datagram { + data, + remote_address, + } in datagrams + { + match udp_socket.try_send_to(&data, remote_address.into()) { + Ok(_size) => count += 1, + Err(_e) if count > 0 => { + return Ok(count); + } + Err(e) => return Err(e.into()), + } + } + } + UdpState::Connected(addr) => { + let addr = SocketAddr::from(addr); + for udp::Datagram { + data, + remote_address, + } in datagrams + { + if SocketAddr::from(remote_address) != addr { + // From WIT documentation: + // If at least one datagram has been sent successfully, this function never returns an error. + if count == 0 { + return Err(ErrorCode::InvalidArgument.into()); + } else { + return Ok(count); + } + } + match udp_socket.try_send(&data) { + Ok(_size) => count += 1, + Err(_e) if count > 0 => { + return Ok(count); + } + Err(e) => return Err(e.into()), + } + } + } + } + Ok(count) + } + + fn local_address(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + let addr = socket + .udp_socket() + .as_socketlike_view::() + .local_addr()?; + Ok(addr.into()) + } + + fn remote_address(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + let addr = socket + .udp_socket() + .as_socketlike_view::() + .peer_addr()?; + Ok(addr.into()) + } + + fn address_family( + &mut self, + this: Resource, + ) -> Result { + let table = self.table(); + let socket = table.get_resource(&this)?; + match socket.family { + AddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4), + AddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6), + } + } + + fn ipv6_only(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::get_ipv6_v6only(socket.udp_socket())?) + } + + fn set_ipv6_only(&mut self, this: Resource, value: bool) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::set_ipv6_v6only(socket.udp_socket(), value)?) + } + + fn unicast_hop_limit(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + + // We don't track whether the socket is IPv4 or IPv6 so try one and + // fall back to the other. + match sockopt::get_ipv6_unicast_hops(socket.udp_socket()) { + Ok(value) => Ok(value), + Err(Errno::NOPROTOOPT) => { + let value = sockopt::get_ip_ttl(socket.udp_socket())?; + let value = value.try_into().unwrap(); + Ok(value) + } + Err(err) => Err(err.into()), + } + } + + fn set_unicast_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + + // We don't track whether the socket is IPv4 or IPv6 so try one and + // fall back to the other. + match sockopt::set_ipv6_unicast_hops(socket.udp_socket(), Some(value)) { + Ok(()) => Ok(()), + Err(Errno::NOPROTOOPT) => Ok(sockopt::set_ip_ttl(socket.udp_socket(), value.into())?), + Err(err) => Err(err.into()), + } + } + + fn receive_buffer_size(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::get_socket_recv_buffer_size(socket.udp_socket())? as u64) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + let value = value.try_into().map_err(|_| ErrorCode::OutOfMemory)?; + Ok(sockopt::set_socket_recv_buffer_size( + socket.udp_socket(), + value, + )?) + } + + fn send_buffer_size(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::get_socket_send_buffer_size(socket.udp_socket())? as u64) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + let value = value.try_into().map_err(|_| ErrorCode::OutOfMemory)?; + Ok(sockopt::set_socket_send_buffer_size( + socket.udp_socket(), + value, + )?) + } + + fn subscribe(&mut self, this: Resource) -> anyhow::Result> { + crate::preview2::poll::subscribe(self.table_mut(), this) + } + + fn drop(&mut self, this: Resource) -> Result<(), anyhow::Error> { + let table = self.table_mut(); + + // As in the filesystem implementation, we assume closing a socket + // doesn't block. + let dropped = table.delete_resource(this)?; + drop(dropped); + + Ok(()) + } +} diff --git a/crates/wasi/src/preview2/host/udp_create_socket.rs b/crates/wasi/src/preview2/host/udp_create_socket.rs new file mode 100644 index 000000000000..7e57e19d5297 --- /dev/null +++ b/crates/wasi/src/preview2/host/udp_create_socket.rs @@ -0,0 +1,15 @@ +use crate::preview2::bindings::{sockets::network::IpAddressFamily, sockets::udp_create_socket}; +use crate::preview2::udp::UdpSocket; +use crate::preview2::{SocketResult, WasiView}; +use wasmtime::component::Resource; + +impl udp_create_socket::Host for T { + fn create_udp_socket( + &mut self, + address_family: IpAddressFamily, + ) -> SocketResult> { + let socket = UdpSocket::new(address_family.into())?; + let socket = self.table_mut().push_resource(socket)?; + Ok(socket) + } +} diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index 8b188e686873..02e6ed405686 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -36,6 +36,7 @@ mod stdio; mod stream; mod table; mod tcp; +mod udp; mod write_stream; pub use self::clocks::{HostMonotonicClock, HostWallClock}; @@ -157,6 +158,7 @@ pub mod bindings { with: { "wasi:sockets/network/network": super::network::Network, "wasi:sockets/tcp/tcp-socket": super::tcp::TcpSocket, + "wasi:sockets/udp/udp-socket": super::udp::UdpSocket, "wasi:sockets/ip-name-lookup/resolve-address-stream": super::ip_name_lookup::ResolveAddressStream, "wasi:filesystem/types/directory-entry-stream": super::filesystem::ReaddirIterator, "wasi:filesystem/types/descriptor": super::filesystem::Descriptor, diff --git a/crates/wasi/src/preview2/udp.rs b/crates/wasi/src/preview2/udp.rs new file mode 100644 index 000000000000..2c879f99a6c9 --- /dev/null +++ b/crates/wasi/src/preview2/udp.rs @@ -0,0 +1,92 @@ +use crate::preview2::bindings::sockets::network::IpSocketAddress; +use crate::preview2::poll::Subscribe; +use crate::preview2::with_ambient_tokio_runtime; +use async_trait::async_trait; +use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt}; +use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; +use std::io; +use std::sync::Arc; +use tokio::io::Interest; + +/// The state of a UDP socket. +/// +/// This represents the various states a socket can be in during the +/// activities of binding, and connecting. +pub(crate) enum UdpState { + /// The initial state for a newly-created socket. + Default, + + /// Binding started via `start_bind`. + BindStarted, + + /// Binding finished via `finish_bind`. The socket has an address but + /// is not yet listening for connections. + Bound, + + /// A connect call is in progress. + Connecting(IpSocketAddress), + + /// The socket is "connected" to a peer address. + Connected(IpSocketAddress), +} + +/// A host UDP socket, plus associated bookkeeping. +/// +/// The inner state is wrapped in an Arc because the same underlying socket is +/// used for implementing the stream types. +pub struct UdpSocket { + /// The part of a `UdpSocket` which is reference-counted so that we + /// can pass it to async tasks. + pub(crate) inner: Arc, + + /// The current state in the bind/connect progression. + pub(crate) udp_state: UdpState, + + /// Socket address family. + pub(crate) family: AddressFamily, +} + +#[async_trait] +impl Subscribe for UdpSocket { + async fn ready(&mut self) { + // Some states are ready immediately. + match self.udp_state { + UdpState::BindStarted => return, + _ => {} + } + + // FIXME: Add `Interest::ERROR` when we update to tokio 1.32. + self.inner + .ready(Interest::READABLE | Interest::WRITABLE) + .await + .expect("failed to await UDP socket readiness"); + } +} + +impl UdpSocket { + /// Create a new socket in the given family. + pub fn new(family: AddressFamily) -> io::Result { + // Create a new host socket and set it to non-blocking, which is needed + // by our async implementation. + let udp_socket = cap_std::net::UdpSocket::new(family, Blocking::No)?; + Self::from_udp_socket(udp_socket, family) + } + + pub fn from_udp_socket( + udp_socket: cap_std::net::UdpSocket, + family: AddressFamily, + ) -> io::Result { + let fd = udp_socket.into_raw_socketlike(); + let std_socket = unsafe { std::net::UdpSocket::from_raw_socketlike(fd) }; + let socket = with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket))?; + Ok(Self { + inner: Arc::new(socket), + udp_state: UdpState::Default, + family, + }) + } + + pub fn udp_socket(&self) -> &tokio::net::UdpSocket { + &self.inner + } +} diff --git a/crates/wasi/wit/test.wit b/crates/wasi/wit/test.wit index fc9c357522bf..a0d1d07a6c64 100644 --- a/crates/wasi/wit/test.wit +++ b/crates/wasi/wit/test.wit @@ -37,6 +37,8 @@ world test-command-with-sockets { import wasi:cli/stderr; import wasi:sockets/tcp; import wasi:sockets/tcp-create-socket; + import wasi:sockets/udp; + import wasi:sockets/udp-create-socket; import wasi:sockets/network; import wasi:sockets/instance-network; import wasi:sockets/ip-name-lookup;