Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wasi-sockets): implement UDP #7148

Merged
merged 8 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/test-programs/tests/wasi-sockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ async fn tcp_sockopts() {
run("tcp_sockopts").await.unwrap();
}

#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn udp_sample_application() {
run("udp_sample_application").await.unwrap();
}

#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn ip_name_lookup() {
run("ip_name_lookup").await.unwrap();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use wasi::sockets::network::{
IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, Network,
};
use wasi::sockets::udp::{Datagram, UdpSocket};
use wasi_sockets_tests::*;

fn test_sample_application(family: IpAddressFamily, bind_address: IpSocketAddress) {
let first_message = &[];
let second_message = b"Hello, world!";
let third_message = b"Greetings, planet!";

let net = Network::default();

let server = UdpSocket::new(family).unwrap();

server.blocking_bind(&net, bind_address).unwrap();
let addr = server.local_address().unwrap();

let client_addr = {
let client = UdpSocket::new(family).unwrap();
client.blocking_connect(&net, addr).unwrap();

let datagrams = [
Datagram {
data: first_message.to_vec(),
remote_address: addr,
},
Datagram {
data: second_message.to_vec(),
remote_address: addr,
},
];
client.blocking_send(&datagrams).unwrap();

client.local_address().unwrap()
};

{
// Check that we've received our sent messages.
// Not guaranteed to work but should work in practice.
let datagrams = server.blocking_receive(2..100).unwrap();
assert_eq!(datagrams.len(), 2);

assert_eq!(datagrams[0].data, first_message);
assert_eq!(datagrams[0].remote_address, client_addr);

assert_eq!(datagrams[1].data, second_message);
assert_eq!(datagrams[1].remote_address, client_addr);
}

// Another client
{
let client = UdpSocket::new(family).unwrap();
client.blocking_connect(&net, addr).unwrap();

let datagrams = [Datagram {
data: third_message.to_vec(),
remote_address: addr,
}];
client.blocking_send(&datagrams).unwrap();
}

{
// Check that we sent and received our message!
let datagrams = server.blocking_receive(1..100).unwrap();
assert_eq!(datagrams.len(), 1);

assert_eq!(datagrams[0].data, third_message); // Not guaranteed to work but should work in practice.
}
}

fn main() {
test_sample_application(
IpAddressFamily::Ipv4,
IpSocketAddress::Ipv4(Ipv4SocketAddress {
port: 0, // use any free port
address: (127, 0, 0, 1), // localhost
}),
);
test_sample_application(
IpAddressFamily::Ipv6,
IpSocketAddress::Ipv6(Ipv6SocketAddress {
port: 0, // use any free port
address: (0, 0, 0, 0, 0, 0, 0, 1), // localhost
flow_info: 0,
scope_id: 0,
}),
);
}
125 changes: 124 additions & 1 deletion crates/test-programs/wasi-sockets-tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
wit_bindgen::generate!("test-command-with-sockets" in "../../wasi/wit");

use std::ops::Range;
use wasi::clocks::monotonic_clock;
use wasi::io::poll::{self, Pollable};
use wasi::io::streams::{InputStream, OutputStream, StreamError};
use wasi::sockets::instance_network;
Expand All @@ -8,12 +10,25 @@ use wasi::sockets::network::{
Network,
};
use wasi::sockets::tcp::TcpSocket;
use wasi::sockets::tcp_create_socket;
use wasi::sockets::udp::{Datagram, UdpSocket};
use wasi::sockets::{tcp_create_socket, udp_create_socket};

const TIMEOUT_NS: u64 = 1_000_000_000;

impl Pollable {
pub fn wait(&self) {
poll::poll_one(self);
}

pub fn wait_until(&self, timeout: &Pollable) -> Result<(), ErrorCode> {
let ready = poll::poll_list(&[self, timeout]);
assert!(ready.len() > 0);
match ready[0] {
0 => Ok(()),
1 => Err(ErrorCode::Timeout),
_ => unreachable!(),
}
}
}

impl OutputStream {
Expand Down Expand Up @@ -108,6 +123,89 @@ impl TcpSocket {
}
}

impl UdpSocket {
pub fn new(address_family: IpAddressFamily) -> Result<UdpSocket, ErrorCode> {
udp_create_socket::create_udp_socket(address_family)
}

pub fn blocking_bind(
&self,
network: &Network,
local_address: IpSocketAddress,
) -> Result<(), ErrorCode> {
let sub = self.subscribe();

self.start_bind(&network, local_address)?;

loop {
match self.finish_bind() {
Err(ErrorCode::WouldBlock) => sub.wait(),
result => return result,
}
}
}

pub fn blocking_connect(
&self,
network: &Network,
remote_address: IpSocketAddress,
) -> Result<(), ErrorCode> {
let sub = self.subscribe();

self.start_connect(&network, remote_address)?;

loop {
match self.finish_connect() {
Err(ErrorCode::WouldBlock) => sub.wait(),
result => return result,
}
}
}

pub fn blocking_send(&self, mut datagrams: &[Datagram]) -> Result<(), ErrorCode> {
let timeout = monotonic_clock::subscribe(TIMEOUT_NS, false);
let pollable = self.subscribe();

while !datagrams.is_empty() {
match self.send(datagrams) {
Ok(packets_sent) => {
datagrams = &datagrams[(packets_sent as usize)..];
}
Err(ErrorCode::WouldBlock) => pollable.wait_until(&timeout)?,
Err(err) => return Err(err),
}
}

Ok(())
}

pub fn blocking_receive(&self, count: Range<u64>) -> Result<Vec<Datagram>, ErrorCode> {
let timeout = monotonic_clock::subscribe(TIMEOUT_NS, false);
let pollable = self.subscribe();
let mut datagrams = vec![];

loop {
match self.receive(count.end - datagrams.len() as u64) {
Ok(mut chunk) => {
datagrams.append(&mut chunk);

if datagrams.len() >= count.start as usize {
return Ok(datagrams);
}
}
Err(ErrorCode::WouldBlock) => {
if datagrams.len() >= count.start as usize {
return Ok(datagrams);
} else {
pollable.wait_until(&timeout)?;
}
}
Err(err) => return Err(err),
}
}
}
}

impl IpAddress {
pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255));

Expand Down Expand Up @@ -189,3 +287,28 @@ impl IpSocketAddress {
}
}
}

impl PartialEq for Ipv4SocketAddress {
fn eq(&self, other: &Self) -> bool {
self.port == other.port && self.address == other.address
}
}

impl PartialEq for Ipv6SocketAddress {
fn eq(&self, other: &Self) -> bool {
self.port == other.port
&& self.flow_info == other.flow_info
&& self.address == other.address
&& self.scope_id == other.scope_id
}
}

impl PartialEq for IpSocketAddress {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Ipv4(l0), Self::Ipv4(r0)) => l0 == r0,
(Self::Ipv6(l0), Self::Ipv6(r0)) => l0 == r0,
_ => false,
}
}
}
2 changes: 2 additions & 0 deletions crates/wasi-http/wit/test.wit
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ world test-command-with-sockets {
import wasi:cli/stderr;
import wasi:sockets/tcp;
import wasi:sockets/tcp-create-socket;
import wasi:sockets/udp;
import wasi:sockets/udp-create-socket;
import wasi:sockets/network;
import wasi:sockets/instance-network;
import wasi:sockets/ip-name-lookup;
Expand Down
5 changes: 5 additions & 0 deletions crates/wasi/src/preview2/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub fn add_to_linker<T: WasiView>(l: &mut wasmtime::component::Linker<T>) -> any
crate::preview2::bindings::cli::terminal_stderr::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::tcp::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::tcp_create_socket::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::udp::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::udp_create_socket::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::instance_network::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::network::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::ip_name_lookup::add_to_linker(l, |t| t)?;
Expand All @@ -65,6 +67,7 @@ pub mod sync {
"wasi:filesystem/types": crate::preview2::bindings::sync_io::filesystem::types,
"wasi:filesystem/preopens": crate::preview2::bindings::filesystem::preopens,
"wasi:sockets/tcp": crate::preview2::bindings::sockets::tcp,
"wasi:sockets/udp": crate::preview2::bindings::sockets::udp,
"wasi:clocks/monotonic_clock": crate::preview2::bindings::clocks::monotonic_clock,
"wasi:io/poll": crate::preview2::bindings::sync_io::io::poll,
"wasi:io/streams": crate::preview2::bindings::sync_io::io::streams,
Expand Down Expand Up @@ -107,6 +110,8 @@ pub mod sync {
crate::preview2::bindings::cli::terminal_stderr::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::tcp::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::tcp_create_socket::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::udp::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::udp_create_socket::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::instance_network::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::network::add_to_linker(l, |t| t)?;
crate::preview2::bindings::sockets::ip_name_lookup::add_to_linker(l, |t| t)?;
Expand Down
2 changes: 2 additions & 0 deletions crates/wasi/src/preview2/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ mod network;
mod random;
mod tcp;
mod tcp_create_socket;
mod udp;
mod udp_create_socket;
21 changes: 0 additions & 21 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,27 +603,6 @@ impl<T: WasiView> crate::preview2::host::tcp::tcp::HostTcpSocket for T {
// As in the filesystem implementation, we assume closing a socket
// doesn't block.
let dropped = table.delete_resource(this)?;

// If we might have an `event::poll` waiting on the socket, wake it up.
#[cfg(not(unix))]
{
match dropped.tcp_state {
TcpState::Default
| TcpState::BindStarted
| TcpState::Bound
| TcpState::ListenStarted
| TcpState::ConnectFailed
| TcpState::ConnectReady => {}

TcpState::Listening | TcpState::Connecting | TcpState::Connected => {
match rustix::net::shutdown(&*dropped.inner, rustix::net::Shutdown::ReadWrite) {
Ok(()) | Err(Errno::NOTCONN) => {}
Err(err) => Err(err).unwrap(),
}
}
}
}

drop(dropped);

Ok(())
Expand Down
Loading