diff --git a/Cargo.toml b/Cargo.toml index 02a44d46..5e44ed7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "compio-tls", "compio-log", "compio-process", + "compio-quic", ] resolver = "2" @@ -36,7 +37,9 @@ compio-dispatcher = { path = "./compio-dispatcher", version = "0.3.0" } compio-log = { path = "./compio-log", version = "0.1.0" } compio-tls = { path = "./compio-tls", version = "0.2.0", default-features = false } compio-process = { path = "./compio-process", version = "0.1.0" } +compio-quic = { path = "./compio-quic", version = "0.1.0" } +bytes = "1.7.1" flume = "0.11.0" cfg-if = "1.0.0" criterion = "0.5.1" @@ -49,10 +52,13 @@ nix = "0.29.0" once_cell = "1.18.0" os_pipe = "1.1.4" paste = "1.0.14" +rand = "0.8.5" +rustls = { version = "0.23.1", default-features = false } slab = "0.4.9" socket2 = "0.5.6" tempfile = "3.8.1" tokio = "1.33.0" +tracing-subscriber = "0.3.18" widestring = "1.0.2" windows-sys = "0.52.0" diff --git a/compio-buf/Cargo.toml b/compio-buf/Cargo.toml index 5644c243..16c34ed4 100644 --- a/compio-buf/Cargo.toml +++ b/compio-buf/Cargo.toml @@ -17,7 +17,7 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] bumpalo = { version = "3.14.0", optional = true } arrayvec = { version = "0.7.4", optional = true } -bytes = { version = "1.5.0", optional = true } +bytes = { workspace = true, optional = true } [target.'cfg(unix)'.dependencies] libc = { workspace = true } diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index 8369b234..672316ae 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -781,12 +781,11 @@ static WSA_RECVMSG: OnceLock = OnceLock::new(); /// Receive data and source address with ancillary data into vectored buffer. pub struct RecvMsg { + msg: WSAMSG, addr: SOCKADDR_STORAGE, - addr_len: socklen_t, fd: SharedFd, buffer: T, control: C, - control_len: u32, _p: PhantomPinned, } @@ -802,12 +801,11 @@ impl RecvMsg { "misaligned control message buffer" ); Self { + msg: unsafe { std::mem::zeroed() }, addr: unsafe { std::mem::zeroed() }, - addr_len: std::mem::size_of::() as _, fd, buffer, control, - control_len: 0, _p: PhantomPinned, } } @@ -820,8 +818,8 @@ impl IntoInner for RecvMsg { ( (self.buffer, self.control), self.addr, - self.addr_len, - self.control_len as _, + self.msg.namelen, + self.msg.Control.len as _, ) } } @@ -835,26 +833,23 @@ impl OpCode for RecvMsg { })?; let this = self.get_unchecked_mut(); + let mut slices = this.buffer.io_slices_mut(); - let mut msg = WSAMSG { - name: &mut this.addr as *mut _ as _, - namelen: this.addr_len, - lpBuffers: slices.as_mut_ptr() as _, - dwBufferCount: slices.len() as _, - Control: std::mem::transmute::(this.control.as_io_slice_mut()), - dwFlags: 0, - }; - this.control_len = 0; + this.msg.name = &mut this.addr as *mut _ as _; + this.msg.namelen = std::mem::size_of::() as _; + this.msg.lpBuffers = slices.as_mut_ptr() as _; + this.msg.dwBufferCount = slices.len() as _; + this.msg.Control = + std::mem::transmute::(this.control.as_io_slice_mut()); let mut received = 0; let res = recvmsg_fn( this.fd.as_raw_fd() as _, - &mut msg, + &mut this.msg, &mut received, optr, None, ); - this.control_len = msg.Control.len; winsock_result(res, received) } diff --git a/compio-log/Cargo.toml b/compio-log/Cargo.toml index bc07ac8b..eb1e26b8 100644 --- a/compio-log/Cargo.toml +++ b/compio-log/Cargo.toml @@ -13,7 +13,7 @@ repository = { workspace = true } tracing = { version = "0.1", default-features = false } [dev-dependencies] -tracing-subscriber = "0.3" +tracing-subscriber = { workspace = true } [features] enable_log = [] diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 36c7f53f..1d985e63 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -1,4 +1,8 @@ -use std::{future::Future, io, mem::ManuallyDrop}; +use std::{ + future::Future, + io, + mem::{ManuallyDrop, MaybeUninit}, +}; use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; #[cfg(unix)] @@ -325,7 +329,51 @@ impl Socket { } #[cfg(unix)] - pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + let mut value: MaybeUninit = MaybeUninit::uninit(); + let mut len = size_of::() as libc::socklen_t; + syscall!(libc::getsockopt( + self.socket.as_raw_fd(), + level, + name, + value.as_mut_ptr() as _, + &mut len + )) + .map(|_| { + debug_assert_eq!(len as usize, size_of::()); + // SAFETY: The value is initialized by `getsockopt`. + value.assume_init() + }) + } + + #[cfg(windows)] + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + let mut value: MaybeUninit = MaybeUninit::uninit(); + let mut len = size_of::() as i32; + syscall!( + SOCKET, + windows_sys::Win32::Networking::WinSock::getsockopt( + self.socket.as_raw_fd() as _, + level, + name, + value.as_mut_ptr() as _, + &mut len + ) + ) + .map(|_| { + debug_assert_eq!(len as usize, size_of::()); + // SAFETY: The value is initialized by `getsockopt`. + value.assume_init() + }) + } + + #[cfg(unix)] + pub unsafe fn set_socket_option( + &self, + level: i32, + name: i32, + value: &T, + ) -> io::Result<()> { syscall!(libc::setsockopt( self.socket.as_raw_fd(), level, @@ -337,7 +385,12 @@ impl Socket { } #[cfg(windows)] - pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + pub unsafe fn set_socket_option( + &self, + level: i32, + name: i32, + value: &T, + ) -> io::Result<()> { syscall!( SOCKET, windows_sys::Win32::Networking::WinSock::setsockopt( diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index 13e59d73..33f39c2d 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -316,8 +316,26 @@ impl UdpSocket { .await } + /// Gets a socket option. + /// + /// # Safety + /// + /// The caller must ensure `T` is the correct type for `level` and `name`. + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + self.inner.get_socket_option(level, name) + } + /// Sets a socket option. - pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + /// + /// # Safety + /// + /// The caller must ensure `T` is the correct type for `level` and `name`. + pub unsafe fn set_socket_option( + &self, + level: i32, + name: i32, + value: &T, + ) -> io::Result<()> { self.inner.set_socket_option(level, name, value) } } diff --git a/compio-net/tests/udp.rs b/compio-net/tests/udp.rs index d813290e..699dec25 100644 --- a/compio-net/tests/udp.rs +++ b/compio-net/tests/udp.rs @@ -1,4 +1,4 @@ -use compio_net::{CMsgBuilder, CMsgIter, UdpSocket}; +use compio_net::UdpSocket; #[compio_macros::test] async fn connect() { @@ -64,57 +64,3 @@ async fn send_to() { active_addr ); } - -#[compio_macros::test] -async fn send_msg_with_ipv6_ecn() { - #[cfg(unix)] - use libc::{IPPROTO_IPV6, IPV6_RECVTCLASS, IPV6_TCLASS}; - #[cfg(windows)] - use windows_sys::Win32::Networking::WinSock::{ - IPPROTO_IPV6, IPV6_ECN, IPV6_RECVTCLASS, IPV6_TCLASS, - }; - - const MSG: &str = "foo bar baz"; - - let passive = UdpSocket::bind("[::1]:0").await.unwrap(); - let passive_addr = passive.local_addr().unwrap(); - - passive - .set_socket_option(IPPROTO_IPV6, IPV6_RECVTCLASS, &1) - .unwrap(); - - let active = UdpSocket::bind("[::1]:0").await.unwrap(); - let active_addr = active.local_addr().unwrap(); - - let mut control = vec![0u8; 32]; - let mut builder = CMsgBuilder::new(&mut control); - - const ECN_BITS: i32 = 0b10; - - #[cfg(unix)] - builder - .try_push(IPPROTO_IPV6, IPV6_TCLASS, ECN_BITS) - .unwrap(); - #[cfg(windows)] - builder.try_push(IPPROTO_IPV6, IPV6_ECN, ECN_BITS).unwrap(); - - let len = builder.finish(); - control.truncate(len); - - active.send_msg(MSG, control, passive_addr).await.unwrap(); - - let ((_, _, addr), (buffer, control)) = passive - .recv_msg(Vec::with_capacity(20), Vec::with_capacity(32)) - .await - .unwrap(); - assert_eq!(addr, active_addr); - assert_eq!(buffer, MSG.as_bytes()); - unsafe { - let mut iter = CMsgIter::new(&control); - let cmsg = iter.next().unwrap(); - assert_eq!(cmsg.level(), IPPROTO_IPV6); - assert_eq!(cmsg.ty(), IPV6_TCLASS); - assert_eq!(cmsg.data::(), &ECN_BITS); - assert!(iter.next().is_none()); - } -} diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml new file mode 100644 index 00000000..00575023 --- /dev/null +++ b/compio-quic/Cargo.toml @@ -0,0 +1,82 @@ +[package] +name = "compio-quic" +version = "0.1.0" +description = "QUIC for compio" +categories = ["asynchronous", "network-programming"] +keywords = ["async", "net", "quic"] +edition = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +repository = { workspace = true } + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +# Workspace dependencies +compio-io = { workspace = true } +compio-buf = { workspace = true } +compio-log = { workspace = true } +compio-net = { workspace = true } +compio-runtime = { workspace = true, features = ["time"] } + +quinn-proto = "0.11.3" +rustls = { workspace = true } +rustls-platform-verifier = { version = "0.3.3", optional = true } +rustls-native-certs = { version = "0.7.1", optional = true } +webpki-roots = { version = "0.26.3", optional = true } +h3 = { version = "0.0.6", optional = true } + +# Utils +bytes = { workspace = true } +flume = { workspace = true } +futures-util = { workspace = true } +rustc-hash = "2.0.0" +thiserror = "1.0.63" + +# Windows specific dependencies +[target.'cfg(windows)'.dependencies] +windows-sys = { workspace = true, features = ["Win32_Networking_WinSock"] } + +[target.'cfg(unix)'.dependencies] +libc = { workspace = true } + +[dev-dependencies] +compio-buf = { workspace = true, features = ["bytes"] } +compio-dispatcher = { workspace = true } +compio-driver = { workspace = true } +compio-fs = { workspace = true } +compio-macros = { workspace = true } +compio-runtime = { workspace = true, features = ["criterion"] } + +criterion = { workspace = true, features = ["async_tokio"] } +http = "1.1.0" +quinn = "0.11.3" +rand = { workspace = true } +rcgen = "0.13.1" +socket2 = { workspace = true, features = ["all"] } +tokio = { workspace = true, features = ["rt", "macros"] } +tracing-subscriber = { workspace = true, features = ["env-filter"] } + +[features] +default = [] +io-compat = ["futures-util/io"] +platform-verifier = ["dep:rustls-platform-verifier"] +native-certs = ["dep:rustls-native-certs"] +webpki-roots = ["dep:webpki-roots"] +h3 = ["dep:h3"] +# FIXME: see https://github.com/quinn-rs/quinn/pull/1962 + +[[example]] +name = "http3-client" +required-features = ["h3"] + +[[example]] +name = "http3-server" +required-features = ["h3"] + +[[bench]] +name = "quic" +harness = false diff --git a/compio-quic/benches/quic.rs b/compio-quic/benches/quic.rs new file mode 100644 index 00000000..66694da5 --- /dev/null +++ b/compio-quic/benches/quic.rs @@ -0,0 +1,193 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::Instant, +}; + +use bytes::Bytes; +use criterion::{criterion_group, criterion_main, Bencher, BenchmarkId, Criterion, Throughput}; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use rand::{thread_rng, RngCore}; + +criterion_group!(quic, echo); +criterion_main!(quic); + +fn gen_cert() -> ( + rustls::pki_types::CertificateDer<'static>, + rustls::pki_types::PrivateKeyDer<'static>, +) { + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + (cert, key_der) +} + +macro_rules! echo_impl { + ($send:ident, $recv:ident) => { + loop { + // These are 32 buffers, for reading approximately 32kB at once + let mut bufs: [Bytes; 32] = std::array::from_fn(|_| Bytes::new()); + + match $recv.read_chunks(&mut bufs).await.unwrap() { + Some(n) => { + $send.write_all_chunks(&mut bufs[..n]).await.unwrap(); + } + None => break, + } + } + + let _ = $send.finish(); + }; +} + +fn echo_compio_quic(b: &mut Bencher, content: &[u8], streams: usize) { + use compio_quic::{ClientBuilder, ServerBuilder}; + + let runtime = compio_runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_custom(|iter| async move { + let (cert, key_der) = gen_cert(); + let server = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) + .unwrap() + .bind("127.0.0.1:0") + .await + .unwrap(); + let client = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert) + .unwrap() + .with_no_crls() + .bind("127.0.0.1:0") + .await + .unwrap(); + let addr = server.local_addr().unwrap(); + + let (client_conn, server_conn) = futures_util::join!( + async move { + client + .connect(addr, "localhost", None) + .unwrap() + .await + .unwrap() + }, + async move { server.wait_incoming().await.unwrap().await.unwrap() } + ); + + let start = Instant::now(); + let handle = compio_runtime::spawn(async move { + while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { + compio_runtime::spawn(async move { + echo_impl!(send, recv); + }) + .detach(); + } + }); + for _i in 0..iter { + let mut futures = (0..streams) + .map(|_| async { + let (mut send, mut recv) = client_conn.open_bi_wait().await.unwrap(); + futures_util::join!( + async { + send.write_all(content).await.unwrap(); + send.finish().unwrap(); + }, + async { + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + } + ); + }) + .collect::>(); + while futures.next().await.is_some() {} + } + drop(handle); + start.elapsed() + }) +} + +fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { + use quinn::{ClientConfig, Endpoint, ServerConfig}; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + b.to_async(&runtime).iter_custom(|iter| async move { + let (cert, key_der) = gen_cert(); + let server_config = ServerConfig::with_single_cert(vec![cert.clone()], key_der).unwrap(); + let mut roots = rustls::RootCertStore::empty(); + roots.add(cert).unwrap(); + let client_config = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap(); + let server = Endpoint::server( + server_config, + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + ) + .unwrap(); + let mut client = + Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(); + client.set_default_client_config(client_config); + let addr = server.local_addr().unwrap(); + + let (client_conn, server_conn) = tokio::join!( + async move { client.connect(addr, "localhost").unwrap().await.unwrap() }, + async move { server.accept().await.unwrap().await.unwrap() } + ); + + let start = Instant::now(); + let handle = tokio::spawn(async move { + while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { + tokio::spawn(async move { + echo_impl!(send, recv); + }); + } + }); + for _i in 0..iter { + let mut futures = (0..streams) + .map(|_| async { + let (mut send, mut recv) = client_conn.open_bi().await.unwrap(); + tokio::join!( + async { + send.write_all(content).await.unwrap(); + send.finish().unwrap(); + }, + async { + recv.read_to_end(usize::MAX).await.unwrap(); + } + ); + }) + .collect::>(); + while futures.next().await.is_some() {} + } + handle.abort(); + start.elapsed() + }); +} + +const DATA_SIZES: &[usize] = &[1, 10, 1024, 1200, 1024 * 16, 1024 * 128]; +const STREAMS: &[usize] = &[1, 10, 100]; + +fn echo(c: &mut Criterion) { + let mut rng = thread_rng(); + + let mut data = vec![0u8; *DATA_SIZES.last().unwrap()]; + rng.fill_bytes(&mut data); + + let mut group = c.benchmark_group("echo"); + for &size in DATA_SIZES { + let data = &data[..size]; + for &streams in STREAMS { + group.throughput(Throughput::Bytes((data.len() * streams * 2) as u64)); + + group.bench_with_input( + BenchmarkId::new("compio-quic", format!("{}-streams-{}-bytes", streams, size)), + &(), + |b, _| echo_compio_quic(b, data, streams), + ); + group.bench_with_input( + BenchmarkId::new("quinn", format!("{}-streams-{}-bytes", streams, size)), + &(), + |b, _| echo_quinn(b, data, streams), + ); + } + } + group.finish(); +} diff --git a/compio-quic/examples/http3-client.rs b/compio-quic/examples/http3-client.rs new file mode 100644 index 00000000..1ddc6936 --- /dev/null +++ b/compio-quic/examples/http3-client.rs @@ -0,0 +1,84 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + path::PathBuf, + str::FromStr, +}; + +use bytes::Buf; +use compio_io::AsyncWriteAtExt; +use compio_net::ToSocketAddrsAsync; +use compio_quic::ClientBuilder; +use http::{Request, Uri}; +use tracing_subscriber::EnvFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + let args = std::env::args().collect::>(); + if args.len() != 3 { + eprintln!("Usage: {} ", args[0]); + std::process::exit(1); + } + + let uri = Uri::from_str(&args[1]).unwrap(); + let outpath = PathBuf::from(&args[2]); + + let host = uri.host().unwrap(); + let remote = (host, uri.port_u16().unwrap_or(443)) + .to_socket_addrs_async() + .await + .unwrap() + .next() + .unwrap(); + + let endpoint = ClientBuilder::new_with_no_server_verification() + .with_key_log() + .with_alpn_protocols(&["h3"]) + .bind(SocketAddr::new( + if remote.is_ipv6() { + IpAddr::V6(Ipv6Addr::UNSPECIFIED) + } else { + IpAddr::V4(Ipv4Addr::UNSPECIFIED) + }, + 0, + )) + .await + .unwrap(); + + { + println!("Connecting to {} at {}", host, remote); + let conn = endpoint.connect(remote, host, None).unwrap().await.unwrap(); + + let (mut conn, mut send_req) = compio_quic::h3::client::new(conn).await.unwrap(); + let handle = compio_runtime::spawn(async move { conn.wait_idle().await }); + + let req = Request::get(uri).body(()).unwrap(); + let mut stream = send_req.send_request(req).await.unwrap(); + stream.finish().await.unwrap(); + + let resp = stream.recv_response().await.unwrap(); + println!("{:?}", resp); + + let mut out = compio_fs::File::create(outpath).await.unwrap(); + let mut pos = 0; + while let Some(mut chunk) = stream.recv_data().await.unwrap() { + let len = chunk.remaining(); + out.write_all_at(chunk.copy_to_bytes(len), pos) + .await + .unwrap(); + pos += len as u64; + } + if let Some(headers) = stream.recv_trailers().await.unwrap() { + println!("{:?}", headers); + } + + drop(send_req); + + handle.await.unwrap().unwrap(); + } + + endpoint.shutdown().await.unwrap(); +} diff --git a/compio-quic/examples/http3-server.rs b/compio-quic/examples/http3-server.rs new file mode 100644 index 00000000..96450910 --- /dev/null +++ b/compio-quic/examples/http3-server.rs @@ -0,0 +1,57 @@ +use bytes::Bytes; +use compio_quic::ServerBuilder; +use http::{HeaderMap, Response}; +use tracing_subscriber::EnvFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + + let endpoint = ServerBuilder::new_with_single_cert(vec![cert], key_der) + .unwrap() + .with_key_log() + .with_alpn_protocols(&["h3"]) + .bind("[::1]:4433") + .await + .unwrap(); + + while let Some(incoming) = endpoint.wait_incoming().await { + compio_runtime::spawn(async move { + let conn = incoming.await.unwrap(); + println!("Accepted connection from {}", conn.remote_address()); + + let mut conn = compio_quic::h3::server::builder() + .build::<_, Bytes>(conn) + .await + .unwrap(); + + while let Ok(Some((req, mut stream))) = conn.accept().await { + println!("Received request: {:?}", req); + stream + .send_response( + Response::builder() + .header("server", "compio-quic") + .body(()) + .unwrap(), + ) + .await + .unwrap(); + stream + .send_data("hello from compio-quic".into()) + .await + .unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("msg", "byebye".parse().unwrap()); + stream.send_trailers(headers).await.unwrap(); + } + }) + .detach(); + } +} diff --git a/compio-quic/examples/quic-client.rs b/compio-quic/examples/quic-client.rs new file mode 100644 index 00000000..85fb23be --- /dev/null +++ b/compio-quic/examples/quic-client.rs @@ -0,0 +1,41 @@ +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; + +use compio_quic::ClientBuilder; +use tracing_subscriber::EnvFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + let endpoint = ClientBuilder::new_with_no_server_verification() + .with_key_log() + .bind("[::1]:0") + .await + .unwrap(); + + { + let conn = endpoint + .connect( + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 4433), + "localhost", + None, + ) + .unwrap() + .await + .unwrap(); + + let (mut send, mut recv) = conn.open_bi().unwrap(); + send.write(&[1, 2, 3]).await.unwrap(); + send.finish().unwrap(); + + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + println!("{:?}", buf); + + conn.close(1u32.into(), b"bye"); + } + + endpoint.shutdown().await.unwrap(); +} diff --git a/compio-quic/examples/quic-dispatcher.rs b/compio-quic/examples/quic-dispatcher.rs new file mode 100644 index 00000000..851debcf --- /dev/null +++ b/compio-quic/examples/quic-dispatcher.rs @@ -0,0 +1,75 @@ +use std::num::NonZeroUsize; + +use compio_dispatcher::Dispatcher; +use compio_quic::{ClientBuilder, Endpoint, ServerBuilder}; +use compio_runtime::spawn; +use futures_util::{stream::FuturesUnordered, StreamExt}; + +#[compio_macros::main] +async fn main() { + const THREAD_NUM: usize = 5; + const CLIENT_NUM: usize = 10; + + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + + let server_config = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) + .unwrap() + .build(); + let client_config = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert) + .unwrap() + .with_no_crls() + .build(); + let mut endpoint = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + endpoint.default_client_config = Some(client_config); + + spawn({ + let endpoint = endpoint.clone(); + async move { + let mut futures = FuturesUnordered::from_iter((0..CLIENT_NUM).map(|i| { + let endpoint = &endpoint; + async move { + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap(); + let mut send = conn.open_uni().unwrap(); + send.write_all(format!("Hello world {}!", i).as_bytes()) + .await + .unwrap(); + send.finish().unwrap(); + send.stopped().await.unwrap(); + } + })); + while let Some(()) = futures.next().await {} + } + }) + .detach(); + + let dispatcher = Dispatcher::builder() + .worker_threads(NonZeroUsize::new(THREAD_NUM).unwrap()) + .build() + .unwrap(); + let mut handles = FuturesUnordered::new(); + for _i in 0..CLIENT_NUM { + let incoming = endpoint.wait_incoming().await.unwrap(); + let handle = dispatcher + .dispatch(move || async move { + let conn = incoming.await.unwrap(); + let mut recv = conn.accept_uni().await.unwrap(); + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + println!("{}", std::str::from_utf8(&buf).unwrap()); + }) + .unwrap(); + handles.push(handle); + } + while handles.next().await.is_some() {} + dispatcher.join().await.unwrap(); +} diff --git a/compio-quic/examples/quic-server.rs b/compio-quic/examples/quic-server.rs new file mode 100644 index 00000000..20b6c01b --- /dev/null +++ b/compio-quic/examples/quic-server.rs @@ -0,0 +1,39 @@ +use compio_quic::ServerBuilder; +use tracing_subscriber::EnvFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + + let endpoint = ServerBuilder::new_with_single_cert(vec![cert], key_der) + .unwrap() + .with_key_log() + .bind("[::1]:4433") + .await + .unwrap(); + + if let Some(incoming) = endpoint.wait_incoming().await { + let conn = incoming.await.unwrap(); + + let (mut send, mut recv) = conn.accept_bi().await.unwrap(); + + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + println!("{:?}", buf); + + send.write(&[4, 5, 6]).await.unwrap(); + send.finish().unwrap(); + + conn.closed().await; + } + + endpoint.close(0u32.into(), b""); + endpoint.shutdown().await.unwrap(); +} diff --git a/compio-quic/src/builder.rs b/compio-quic/src/builder.rs new file mode 100644 index 00000000..85cd4c14 --- /dev/null +++ b/compio-quic/src/builder.rs @@ -0,0 +1,273 @@ +use std::{io, sync::Arc}; + +use compio_net::ToSocketAddrsAsync; +use quinn_proto::{ + crypto::rustls::{QuicClientConfig, QuicServerConfig}, + ClientConfig, ServerConfig, +}; + +use crate::Endpoint; + +/// Helper to construct an [`Endpoint`] for use with outgoing connections only. +/// +/// To get one, call `new_with_xxx` methods. +/// +/// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html +#[derive(Debug)] +pub struct ClientBuilder(T); + +impl ClientBuilder { + /// Create a builder with an empty [`rustls::RootCertStore`]. + pub fn new_with_empty_roots() -> Self { + ClientBuilder(rustls::RootCertStore::empty()) + } + + /// Create a builder with [`rustls_native_certs`]. + #[cfg(feature = "native-certs")] + pub fn new_with_native_certs() -> io::Result { + let mut roots = rustls::RootCertStore::empty(); + roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); + Ok(ClientBuilder(roots)) + } + + /// Create a builder with [`webpki_roots`]. + #[cfg(feature = "webpki-roots")] + pub fn new_with_webpki_roots() -> Self { + let roots = + rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + ClientBuilder(roots) + } + + /// Add a custom certificate. + pub fn with_custom_certificate( + mut self, + der: rustls::pki_types::CertificateDer, + ) -> Result { + self.0.add(der)?; + Ok(self) + } + + /// Don't configure revocation. + pub fn with_no_crls(self) -> ClientBuilder { + ClientBuilder::new_with_root_certificates(self.0) + } + + /// Verify the revocation state of presented client certificates against the + /// provided certificate revocation lists (CRLs). + pub fn with_crls( + self, + crls: impl IntoIterator>, + ) -> Result, rustls::client::VerifierBuilderError> { + let verifier = rustls::client::WebPkiServerVerifier::builder(Arc::new(self.0)) + .with_crls(crls) + .build()?; + Ok(ClientBuilder::new_with_webpki_verifier(verifier)) + } +} + +impl ClientBuilder { + /// Create a builder with the provided [`rustls::ClientConfig`]. + pub fn new_with_rustls_client_config( + client_config: rustls::ClientConfig, + ) -> ClientBuilder { + ClientBuilder(client_config) + } + + /// Do not verify the server's certificate. It is vulnerable to MITM + /// attacks, but convenient for testing. + pub fn new_with_no_server_verification() -> ClientBuilder { + ClientBuilder( + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier::SkipServerVerification::new())) + .with_no_client_auth(), + ) + } + + /// Create a builder with [`rustls_platform_verifier`]. + #[cfg(feature = "platform-verifier")] + pub fn new_with_platform_verifier() -> ClientBuilder { + ClientBuilder( + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new( + rustls_platform_verifier::Verifier::new(), + )) + .with_no_client_auth(), + ) + } + + /// Create a builder with the provided [`rustls::RootCertStore`]. + pub fn new_with_root_certificates( + roots: rustls::RootCertStore, + ) -> ClientBuilder { + ClientBuilder( + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(roots) + .with_no_client_auth(), + ) + } + + /// Create a builder with a custom [`rustls::client::WebPkiServerVerifier`]. + pub fn new_with_webpki_verifier( + verifier: Arc, + ) -> ClientBuilder { + ClientBuilder( + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_webpki_verifier(verifier) + .with_no_client_auth(), + ) + } + + /// Set the ALPN protocols to use. + pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self { + self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); + self + } + + /// Logging key material to a file for debugging. The file's name is given + /// by the `SSLKEYLOGFILE` environment variable. + /// + /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot + /// be written, this does nothing. + pub fn with_key_log(mut self) -> Self { + self.0.key_log = Arc::new(rustls::KeyLogFile::new()); + self + } + + /// Build a [`ClientConfig`]. + pub fn build(mut self) -> ClientConfig { + self.0.enable_early_data = true; + ClientConfig::new(Arc::new( + QuicClientConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"), + )) + } + + /// Create a new [`Endpoint`]. + /// + /// See [`Endpoint::client`] for more information. + pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { + let mut endpoint = Endpoint::client(addr).await?; + endpoint.default_client_config = Some(self.build()); + Ok(endpoint) + } +} + +/// Helper to construct an [`Endpoint`] for use with incoming connections. +/// +/// To get one, call `new_with_xxx` methods. +/// +/// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html +#[derive(Debug)] +pub struct ServerBuilder(T); + +impl ServerBuilder { + /// Create a builder with the provided [`rustls::ServerConfig`]. + pub fn new_with_rustls_server_config(server_config: rustls::ServerConfig) -> Self { + Self(server_config) + } + + /// Create a builder with a single certificate chain and matching private + /// key. Using this method gets the same result as calling + /// [`ServerConfig::with_single_cert`]. + pub fn new_with_single_cert( + cert_chain: Vec>, + key_der: rustls::pki_types::PrivateKeyDer<'static>, + ) -> Result { + let server_config = + rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert(cert_chain, key_der)?; + Ok(Self::new_with_rustls_server_config(server_config)) + } + + /// Set the ALPN protocols to use. + pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self { + self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); + self + } + + /// Logging key material to a file for debugging. The file's name is given + /// by the `SSLKEYLOGFILE` environment variable. + /// + /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot + /// be written, this does nothing. + pub fn with_key_log(mut self) -> Self { + self.0.key_log = Arc::new(rustls::KeyLogFile::new()); + self + } + + /// Build a [`ServerConfig`]. + pub fn build(mut self) -> ServerConfig { + self.0.max_early_data_size = u32::MAX; + ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"), + )) + } + + /// Create a new [`Endpoint`]. + /// + /// See [`Endpoint::server`] for more information. + pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { + Endpoint::server(addr, self.build()).await + } +} + +mod verifier { + use rustls::{ + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + crypto::WebPkiSupportedAlgorithms, + pki_types::{CertificateDer, ServerName, UnixTime}, + DigitallySignedStruct, Error, SignatureScheme, + }; + + #[derive(Debug)] + pub struct SkipServerVerification(WebPkiSupportedAlgorithms); + + impl SkipServerVerification { + pub fn new() -> Self { + Self( + rustls::crypto::CryptoProvider::get_default() + .map(|provider| provider.signature_verification_algorithms) + .unwrap_or_else(|| { + rustls::crypto::ring::default_provider().signature_verification_algorithms + }), + ) + } + } + + impl ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.supported_schemes() + } + } +} diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs new file mode 100644 index 00000000..181aac18 --- /dev/null +++ b/compio-quic/src/connection.rs @@ -0,0 +1,1235 @@ +use std::{ + collections::VecDeque, + io, + net::{IpAddr, SocketAddr}, + pin::{pin, Pin}, + sync::{Arc, Mutex, MutexGuard}, + task::{Context, Poll, Waker}, + time::{Duration, Instant}, +}; + +use bytes::Bytes; +use compio_buf::BufResult; +use compio_log::{error, Instrument}; +use compio_runtime::JoinHandle; +use flume::{Receiver, Sender}; +use futures_util::{ + future::{self, Fuse, FusedFuture, LocalBoxFuture}, + select, stream, Future, FutureExt, StreamExt, +}; +use quinn_proto::{ + congestion::Controller, crypto::rustls::HandshakeData, ConnectionHandle, ConnectionStats, Dir, + EndpointEvent, StreamEvent, StreamId, VarInt, +}; +use rustc_hash::FxHashMap as HashMap; +use thiserror::Error; + +use crate::{RecvStream, SendStream, Socket}; + +#[derive(Debug)] +pub(crate) enum ConnectionEvent { + Close(VarInt, Bytes), + Proto(quinn_proto::ConnectionEvent), +} + +#[derive(Debug)] +pub(crate) struct ConnectionState { + pub(crate) conn: quinn_proto::Connection, + pub(crate) error: Option, + connected: bool, + worker: Option>, + poller: Option, + on_connected: Option, + on_handshake_data: Option, + datagram_received: VecDeque, + datagrams_unblocked: VecDeque, + stream_opened: [VecDeque; 2], + stream_available: [VecDeque; 2], + pub(crate) writable: HashMap, + pub(crate) readable: HashMap, + pub(crate) stopped: HashMap, +} + +impl ConnectionState { + fn terminate(&mut self, reason: ConnectionError) { + self.error = Some(reason); + self.connected = false; + + if let Some(waker) = self.on_handshake_data.take() { + waker.wake() + } + if let Some(waker) = self.on_connected.take() { + waker.wake() + } + self.datagram_received.drain(..).for_each(Waker::wake); + self.datagrams_unblocked.drain(..).for_each(Waker::wake); + for e in &mut self.stream_opened { + e.drain(..).for_each(Waker::wake); + } + for e in &mut self.stream_available { + e.drain(..).for_each(Waker::wake); + } + wake_all_streams(&mut self.writable); + wake_all_streams(&mut self.readable); + wake_all_streams(&mut self.stopped); + } + + fn close(&mut self, error_code: VarInt, reason: Bytes) { + self.conn.close(Instant::now(), error_code, reason); + self.terminate(ConnectionError::LocallyClosed); + self.wake(); + } + + pub(crate) fn wake(&mut self) { + if let Some(waker) = self.poller.take() { + waker.wake() + } + } + + fn handshake_data(&self) -> Option> { + self.conn + .crypto_session() + .handshake_data() + .map(|data| data.downcast::().unwrap()) + } + + pub(crate) fn check_0rtt(&self) -> bool { + self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt() + } +} + +fn wake_stream(stream: StreamId, wakers: &mut HashMap) { + if let Some(waker) = wakers.remove(&stream) { + waker.wake(); + } +} + +fn wake_all_streams(wakers: &mut HashMap) { + wakers.drain().for_each(|(_, waker)| waker.wake()) +} + +#[derive(Debug)] +pub(crate) struct ConnectionInner { + state: Mutex, + handle: ConnectionHandle, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + events_rx: Receiver, +} + +fn implicit_close(this: &Arc) { + if Arc::strong_count(this) == 2 { + this.state().close(0u32.into(), Bytes::new()) + } +} + +impl ConnectionInner { + fn new( + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + events_rx: Receiver, + ) -> Self { + Self { + state: Mutex::new(ConnectionState { + conn, + connected: false, + error: None, + worker: None, + poller: None, + on_connected: None, + on_handshake_data: None, + datagram_received: VecDeque::new(), + datagrams_unblocked: VecDeque::new(), + stream_opened: [VecDeque::new(), VecDeque::new()], + stream_available: [VecDeque::new(), VecDeque::new()], + writable: HashMap::default(), + readable: HashMap::default(), + stopped: HashMap::default(), + }), + handle, + socket, + events_tx, + events_rx, + } + } + + #[inline] + pub(crate) fn state(&self) -> MutexGuard { + self.state.lock().unwrap() + } + + #[inline] + pub(crate) fn try_state(&self) -> Result, ConnectionError> { + let state = self.state(); + if let Some(error) = &state.error { + Err(error.clone()) + } else { + Ok(state) + } + } + + async fn run(self: &Arc) -> io::Result<()> { + let mut poller = stream::poll_fn(|cx| { + let mut state = self.state(); + let ready = state.poller.is_none(); + match &state.poller { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.poller = Some(cx.waker().clone()), + }; + if ready { + Poll::Ready(Some(())) + } else { + Poll::Pending + } + }) + .fuse(); + + let mut timer = Timer::new(); + let mut event_stream = self.events_rx.stream().ready_chunks(100); + let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize)); + let mut transmit_fut = pin!(Fuse::terminated()); + + loop { + let mut state = select! { + _ = poller.select_next_some() => self.state(), + _ = timer => { + timer.reset(None); + let mut state = self.state(); + state.conn.handle_timeout(Instant::now()); + state + } + events = event_stream.select_next_some() => { + let mut state = self.state(); + for event in events { + match event { + ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason), + ConnectionEvent::Proto(event) => state.conn.handle_event(event), + } + } + state + }, + BufResult::<(), Vec>(res, mut buf) = transmit_fut => match res { + Ok(()) => { + buf.clear(); + send_buf = Some(buf); + self.state() + }, + Err(e) => break Err(e), + }, + }; + + if let Some(mut buf) = send_buf.take() { + if let Some(transmit) = state.conn.poll_transmit( + Instant::now(), + self.socket.max_gso_segments(), + &mut buf, + ) { + transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse()) + } else { + send_buf = Some(buf); + } + } + + timer.reset(state.conn.poll_timeout()); + + while let Some(event) = state.conn.poll_endpoint_events() { + let _ = self.events_tx.send((self.handle, event)); + } + + while let Some(event) = state.conn.poll() { + use quinn_proto::Event::*; + match event { + HandshakeDataReady => { + if let Some(waker) = state.on_handshake_data.take() { + waker.wake() + } + } + Connected => { + state.connected = true; + if let Some(waker) = state.on_connected.take() { + waker.wake() + } + if state.conn.side().is_client() && !state.conn.accepted_0rtt() { + // Wake up rejected 0-RTT streams so they can fail immediately with + // `ZeroRttRejected` errors. + wake_all_streams(&mut state.writable); + wake_all_streams(&mut state.readable); + wake_all_streams(&mut state.stopped); + } + } + ConnectionLost { reason } => state.terminate(reason.into()), + Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable), + Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable), + Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped), + Stream(StreamEvent::Stopped { id, .. }) => { + wake_stream(id, &mut state.stopped); + wake_stream(id, &mut state.writable); + } + Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize] + .drain(..) + .for_each(Waker::wake), + Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize] + .drain(..) + .for_each(Waker::wake), + DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake), + DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake), + } + } + + if state.conn.is_drained() { + break Ok(()); + } + } + } +} + +macro_rules! conn_fn { + () => { + /// The local IP address which was used when the peer established + /// the connection. + /// + /// This can be different from the address the endpoint is bound to, in case + /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`. + /// + /// This will return `None` for clients, or when the platform does not + /// expose this information. + pub fn local_ip(&self) -> Option { + self.0.state().conn.local_ip() + } + + /// The peer's UDP address. + /// + /// Will panic if called after `poll` has returned `Ready`. + pub fn remote_address(&self) -> SocketAddr { + self.0.state().conn.remote_address() + } + + /// Current best estimate of this connection's latency (round-trip-time). + pub fn rtt(&self) -> Duration { + self.0.state().conn.rtt() + } + + /// Connection statistics. + pub fn stats(&self) -> ConnectionStats { + self.0.state().conn.stats() + } + + /// Current state of the congestion control algorithm. (For debugging + /// purposes) + pub fn congestion_state(&self) -> Box { + self.0.state().conn.congestion_state().clone_box() + } + + /// Cryptographic identity of the peer. + pub fn peer_identity( + &self, + ) -> Option>>> { + self.0 + .state() + .conn + .crypto_session() + .peer_identity() + .map(|v| v.downcast().unwrap()) + } + + /// Derive keying material from this connection's TLS session secrets. + /// + /// When both peers call this method with the same `label` and `context` + /// arguments and `output` buffers of equal length, they will get the + /// same sequence of bytes in `output`. These bytes are cryptographically + /// strong and pseudorandom, and are suitable for use as keying material. + /// + /// This function fails if called with an empty `output` or called prior to + /// the handshake completing. + /// + /// See [RFC5705](https://tools.ietf.org/html/rfc5705) for more information. + pub fn export_keying_material( + &self, + output: &mut [u8], + label: &[u8], + context: &[u8], + ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> { + self.0 + .state() + .conn + .crypto_session() + .export_keying_material(output, label, context) + } + }; +} + +/// In-progress connection attempt future +#[derive(Debug)] +#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] +pub struct Connecting(Arc); + +impl Connecting { + conn_fn!(); + + pub(crate) fn new( + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + events_rx: Receiver, + ) -> Self { + let inner = Arc::new(ConnectionInner::new( + handle, conn, socket, events_tx, events_rx, + )); + let worker = compio_runtime::spawn({ + let inner = inner.clone(); + async move { + #[allow(unused)] + if let Err(e) = inner.run().await { + error!("I/O error: {}", e); + } + } + .in_current_span() + }); + inner.state().worker = Some(worker); + Self(inner) + } + + /// Parameters negotiated during the handshake. + pub async fn handshake_data(&mut self) -> Result, ConnectionError> { + future::poll_fn(|cx| { + let mut state = self.0.try_state()?; + if let Some(data) = state.handshake_data() { + return Poll::Ready(Ok(data)); + } + + match &state.on_handshake_data { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.on_handshake_data = Some(cx.waker().clone()), + } + + Poll::Pending + }) + .await + } + + /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened + /// security. + /// + /// Returns `Ok` immediately if the local endpoint is able to attempt + /// sending 0/0.5-RTT data. If so, the returned [`Connection`] can be used + /// to send application data without waiting for the rest of the handshake + /// to complete, at the cost of weakened cryptographic security guarantees. + /// The [`Connection::accepted_0rtt`] method resolves when the handshake + /// does complete, at which point subsequently opened streams and written + /// data will have full cryptographic protection. + /// + /// ## Outgoing + /// + /// For outgoing connections, the initial attempt to convert to a + /// [`Connection`] which sends 0-RTT data will proceed if the + /// [`crypto::ClientConfig`][crate::crypto::ClientConfig] attempts to resume + /// a previous TLS session. However, **the remote endpoint may not actually + /// _accept_ the 0-RTT data**--yet still accept the connection attempt in + /// general. This possibility is conveyed through the + /// [`Connection::accepted_0rtt`] method--when the handshake completes, it + /// resolves to true if the 0-RTT data was accepted and false if it was + /// rejected. If it was rejected, the existence of streams opened and other + /// application data sent prior to the handshake completing will not be + /// conveyed to the remote application, and local operations on them will + /// return `ZeroRttRejected` errors. + /// + /// A server may reject 0-RTT data at its discretion, but accepting 0-RTT + /// data requires the relevant resumption state to be stored in the server, + /// which servers may limit or lose for various reasons including not + /// persisting resumption state across server restarts. + /// + /// ## Incoming + /// + /// For incoming connections, conversion to 0.5-RTT will always fully + /// succeed. `into_0rtt` will always return `Ok` and + /// [`Connection::accepted_0rtt`] will always resolve to true. + /// + /// ## Security + /// + /// On outgoing connections, this enables transmission of 0-RTT data, which + /// is vulnerable to replay attacks, and should therefore never invoke + /// non-idempotent operations. + /// + /// On incoming connections, this enables transmission of 0.5-RTT data, + /// which may be sent before TLS client authentication has occurred, and + /// should therefore not be used to send data for which client + /// authentication is being used. + pub fn into_0rtt(self) -> Result { + let is_ok = { + let state = self.0.state(); + state.conn.has_0rtt() || state.conn.side().is_server() + }; + if is_ok { + Ok(Connection(self.0.clone())) + } else { + Err(self) + } + } +} + +impl Future for Connecting { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut state = self.0.try_state()?; + + if state.connected { + return Poll::Ready(Ok(Connection(self.0.clone()))); + } + + match &state.on_connected { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.on_connected = Some(cx.waker().clone()), + } + + Poll::Pending + } +} + +impl Drop for Connecting { + fn drop(&mut self) { + implicit_close(&self.0) + } +} + +/// A QUIC connection. +#[derive(Debug, Clone)] +pub struct Connection(Arc); + +impl Connection { + conn_fn!(); + + /// Parameters negotiated during the handshake. + pub fn handshake_data(&mut self) -> Result, ConnectionError> { + Ok(self.0.try_state()?.handshake_data().unwrap()) + } + + /// Compute the maximum size of datagrams that may be passed to + /// [`send_datagram()`](Self::send_datagram). + /// + /// Returns `None` if datagrams are unsupported by the peer or disabled + /// locally. + /// + /// This may change over the lifetime of a connection according to variation + /// in the path MTU estimate. The peer can also enforce an arbitrarily small + /// fixed limit, but if the peer's limit is large this is guaranteed to be a + /// little over a kilobyte at minimum. + /// + /// Not necessarily the maximum size of received datagrams. + pub fn max_datagram_size(&self) -> Option { + self.0.state().conn.datagrams().max_size() + } + + /// Bytes available in the outgoing datagram buffer. + /// + /// When greater than zero, calling [`send_datagram()`](Self::send_datagram) + /// with a datagram of at most this size is guaranteed not to cause older + /// datagrams to be dropped. + pub fn datagram_send_buffer_space(&self) -> usize { + self.0.state().conn.datagrams().send_buffer_space() + } + + /// Modify the number of remotely initiated unidirectional streams that may + /// be concurrently open. + /// + /// No streams may be opened by the peer unless fewer than `count` are + /// already open. Large `count`s increase both minimum and worst-case + /// memory consumption. + pub fn set_max_concurrent_uni_streams(&self, count: VarInt) { + let mut state = self.0.state(); + state.conn.set_max_concurrent_streams(Dir::Uni, count); + // May need to send MAX_STREAMS to make progress + state.wake(); + } + + /// See [`quinn_proto::TransportConfig::receive_window()`] + pub fn set_receive_window(&self, receive_window: VarInt) { + let mut state = self.0.state(); + state.conn.set_receive_window(receive_window); + state.wake(); + } + + /// Modify the number of remotely initiated bidirectional streams that may + /// be concurrently open. + /// + /// No streams may be opened by the peer unless fewer than `count` are + /// already open. Large `count`s increase both minimum and worst-case + /// memory consumption. + pub fn set_max_concurrent_bi_streams(&self, count: VarInt) { + let mut state = self.0.state(); + state.conn.set_max_concurrent_streams(Dir::Bi, count); + // May need to send MAX_STREAMS to make progress + state.wake(); + } + + /// Close the connection immediately. + /// + /// Pending operations will fail immediately with + /// [`ConnectionError::LocallyClosed`]. No more data is sent to the peer + /// and the peer may drop buffered data upon receiving + /// the CONNECTION_CLOSE frame. + /// + /// `error_code` and `reason` are not interpreted, and are provided directly + /// to the peer. + /// + /// `reason` will be truncated to fit in a single packet with overhead; to + /// improve odds that it is preserved in full, it should be kept under + /// 1KiB. + /// + /// # Gracefully closing a connection + /// + /// Only the peer last receiving application data can be certain that all + /// data is delivered. The only reliable action it can then take is to + /// close the connection, potentially with a custom error code. The + /// delivery of the final CONNECTION_CLOSE frame is very likely if both + /// endpoints stay online long enough, and [`Endpoint::shutdown()`] can + /// be used to provide sufficient time. Otherwise, the remote peer will + /// time out the connection, provided that the idle timeout is not + /// disabled. + /// + /// The sending side can not guarantee all stream data is delivered to the + /// remote application. It only knows the data is delivered to the QUIC + /// stack of the remote endpoint. Once the local side sends a + /// CONNECTION_CLOSE frame in response to calling [`close()`] the remote + /// endpoint may drop any data it received but is as yet undelivered to + /// the application, including data that was acknowledged as received to + /// the local endpoint. + /// + /// [`ConnectionError::LocallyClosed`]: ConnectionError::LocallyClosed + /// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown + /// [`close()`]: Connection::close + pub fn close(&self, error_code: VarInt, reason: &[u8]) { + self.0 + .state() + .close(error_code, Bytes::copy_from_slice(reason)); + } + + /// Wait for the connection to be closed for any reason. + pub async fn closed(&self) -> ConnectionError { + let worker = self.0.state().worker.take(); + if let Some(worker) = worker { + let _ = worker.await; + } + + self.0.try_state().unwrap_err() + } + + /// If the connection is closed, the reason why. + /// + /// Returns `None` if the connection is still open. + pub fn close_reason(&self) -> Option { + self.0.try_state().err() + } + + fn poll_recv_datagram(&self, cx: &mut Context) -> Poll> { + let mut state = self.0.try_state()?; + if let Some(bytes) = state.conn.datagrams().recv() { + return Poll::Ready(Ok(bytes)); + } + state.datagram_received.push_back(cx.waker().clone()); + Poll::Pending + } + + /// Receive an application datagram. + pub async fn recv_datagram(&self) -> Result { + future::poll_fn(|cx| self.poll_recv_datagram(cx)).await + } + + fn try_send_datagram( + &self, + cx: Option<&mut Context>, + data: Bytes, + ) -> Result<(), Result> { + use quinn_proto::SendDatagramError::*; + let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?; + state + .conn + .datagrams() + .send(data, cx.is_none()) + .map_err(|err| match err { + UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer), + Disabled => Ok(SendDatagramError::Disabled), + TooLarge => Ok(SendDatagramError::TooLarge), + Blocked(data) => { + state + .datagrams_unblocked + .push_back(cx.unwrap().waker().clone()); + Err(data) + } + })?; + state.wake(); + Ok(()) + } + + /// Transmit `data` as an unreliable, unordered application datagram. + /// + /// Application datagrams are a low-level primitive. They may be lost or + /// delivered out of order, and `data` must both fit inside a single + /// QUIC packet and be smaller than the maximum dictated by the peer. + pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> { + self.try_send_datagram(None, data).map_err(Result::unwrap) + } + + /// Transmit `data` as an unreliable, unordered application datagram. + /// + /// Unlike [`send_datagram()`], this method will wait for buffer space + /// during congestion conditions, which effectively prioritizes old + /// datagrams over new datagrams. + /// + /// See [`send_datagram()`] for details. + /// + /// [`send_datagram()`]: Connection::send_datagram + pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> { + let mut data = Some(data); + future::poll_fn( + |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) { + Ok(()) => Poll::Ready(Ok(())), + Err(Ok(e)) => Poll::Ready(Err(e)), + Err(Err(b)) => { + data.replace(b); + Poll::Pending + } + }, + ) + .await + } + + fn poll_open_stream( + &self, + cx: Option<&mut Context>, + dir: Dir, + ) -> Poll> { + let mut state = self.0.try_state()?; + if let Some(stream) = state.conn.streams().open(dir) { + Poll::Ready(Ok(( + stream, + state.conn.side().is_client() && state.conn.is_handshaking(), + ))) + } else { + if let Some(cx) = cx { + state.stream_available[dir as usize].push_back(cx.waker().clone()); + } + Poll::Pending + } + } + + /// Initiate a new outgoing unidirectional stream. + /// + /// Streams are cheap and instantaneous to open. As a consequence, the peer + /// won't be notified that a stream has been opened until the stream is + /// actually used. + pub fn open_uni(&self) -> Result { + if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? { + Ok(SendStream::new(self.0.clone(), stream, is_0rtt)) + } else { + Err(OpenStreamError::StreamsExhausted) + } + } + + /// Initiate a new outgoing unidirectional stream. + /// + /// Unlike [`open_uni()`], this method will wait for the connection to allow + /// a new stream to be opened. + /// + /// See [`open_uni()`] for details. + /// + /// [`open_uni()`]: crate::Connection::open_uni + pub async fn open_uni_wait(&self) -> Result { + let (stream, is_0rtt) = + future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?; + Ok(SendStream::new(self.0.clone(), stream, is_0rtt)) + } + + /// Initiate a new outgoing bidirectional stream. + /// + /// Streams are cheap and instantaneous to open. As a consequence, the peer + /// won't be notified that a stream has been opened until the stream is + /// actually used. + pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> { + if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? { + Ok(( + SendStream::new(self.0.clone(), stream, is_0rtt), + RecvStream::new(self.0.clone(), stream, is_0rtt), + )) + } else { + Err(OpenStreamError::StreamsExhausted) + } + } + + /// Initiate a new outgoing bidirectional stream. + /// + /// Unlike [`open_bi()`], this method will wait for the connection to allow + /// a new stream to be opened. + /// + /// See [`open_bi()`] for details. + /// + /// [`open_bi()`]: crate::Connection::open_bi + pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> { + let (stream, is_0rtt) = + future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?; + Ok(( + SendStream::new(self.0.clone(), stream, is_0rtt), + RecvStream::new(self.0.clone(), stream, is_0rtt), + )) + } + + fn poll_accept_stream( + &self, + cx: &mut Context, + dir: Dir, + ) -> Poll> { + let mut state = self.0.try_state()?; + if let Some(stream) = state.conn.streams().accept(dir) { + state.wake(); + Poll::Ready(Ok((stream, state.conn.is_handshaking()))) + } else { + state.stream_opened[dir as usize].push_back(cx.waker().clone()); + Poll::Pending + } + } + + /// Accept the next incoming uni-directional stream + pub async fn accept_uni(&self) -> Result { + let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?; + Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)) + } + + /// Accept the next incoming bidirectional stream + /// + /// **Important Note**: The `Connection` that calls [`open_bi()`] must write + /// to its [`SendStream`] before the other `Connection` is able to + /// `accept_bi()`. Calling [`open_bi()`] then waiting on the [`RecvStream`] + /// without writing anything to [`SendStream`] will never succeed. + /// + /// [`accept_bi()`]: crate::Connection::accept_bi + /// [`open_bi()`]: crate::Connection::open_bi + /// [`SendStream`]: crate::SendStream + /// [`RecvStream`]: crate::RecvStream + pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { + let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?; + Ok(( + SendStream::new(self.0.clone(), stream, is_0rtt), + RecvStream::new(self.0.clone(), stream, is_0rtt), + )) + } + + /// Wait for the connection to be fully established. + /// + /// For clients, the resulting value indicates if 0-RTT was accepted. For + /// servers, the resulting value is meaningless. + pub async fn accepted_0rtt(&self) -> Result { + future::poll_fn(|cx| { + let mut state = self.0.try_state()?; + + if state.connected { + return Poll::Ready(Ok(state.conn.accepted_0rtt())); + } + + match &state.on_connected { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.on_connected = Some(cx.waker().clone()), + } + + Poll::Pending + }) + .await + } +} + +impl PartialEq for Connection { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Eq for Connection {} + +impl Drop for Connection { + fn drop(&mut self) { + implicit_close(&self.0) + } +} + +struct Timer { + deadline: Option, + fut: Fuse>, +} + +impl Timer { + fn new() -> Self { + Self { + deadline: None, + fut: Fuse::terminated(), + } + } + + fn reset(&mut self, deadline: Option) { + if let Some(deadline) = deadline { + if self.deadline.is_none() || self.deadline != Some(deadline) { + self.fut = compio_runtime::time::sleep_until(deadline) + .boxed_local() + .fuse(); + } + } else { + self.fut = Fuse::terminated(); + } + self.deadline = deadline; + } +} + +impl Future for Timer { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.fut.poll_unpin(cx) + } +} + +impl FusedFuture for Timer { + fn is_terminated(&self) -> bool { + self.fut.is_terminated() + } +} + +/// Reasons why a connection might be lost +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum ConnectionError { + /// The peer doesn't implement any supported version + #[error("peer doesn't implement any supported version")] + VersionMismatch, + /// The peer violated the QUIC specification as understood by this + /// implementation + #[error(transparent)] + TransportError(#[from] quinn_proto::TransportError), + /// The peer's QUIC stack aborted the connection automatically + #[error("aborted by peer: {0}")] + ConnectionClosed(quinn_proto::ConnectionClose), + /// The peer closed the connection + #[error("closed by peer: {0}")] + ApplicationClosed(quinn_proto::ApplicationClose), + /// The peer is unable to continue processing this connection, usually due + /// to having restarted + #[error("reset by peer")] + Reset, + /// Communication with the peer has lapsed for longer than the negotiated + /// idle timeout + /// + /// If neither side is sending keep-alives, a connection will time out after + /// a long enough idle period even if the peer is still reachable. See + /// also [`TransportConfig::max_idle_timeout()`] + /// and [`TransportConfig::keep_alive_interval()`]. + #[error("timed out")] + TimedOut, + /// The local application closed the connection + #[error("closed")] + LocallyClosed, + /// The connection could not be created because not enough of the CID space + /// is available + /// + /// Try using longer connection IDs. + #[error("CIDs exhausted")] + CidsExhausted, +} + +impl From for ConnectionError { + fn from(value: quinn_proto::ConnectionError) -> Self { + use quinn_proto::ConnectionError::*; + + match value { + VersionMismatch => ConnectionError::VersionMismatch, + TransportError(e) => ConnectionError::TransportError(e), + ConnectionClosed(e) => ConnectionError::ConnectionClosed(e), + ApplicationClosed(e) => ConnectionError::ApplicationClosed(e), + Reset => ConnectionError::Reset, + TimedOut => ConnectionError::TimedOut, + LocallyClosed => ConnectionError::LocallyClosed, + CidsExhausted => ConnectionError::CidsExhausted, + } + } +} + +/// Errors that can arise when sending a datagram +#[derive(Debug, Error, Clone, Eq, PartialEq)] +pub enum SendDatagramError { + /// The peer does not support receiving datagram frames + #[error("datagrams not supported by peer")] + UnsupportedByPeer, + /// Datagram support is disabled locally + #[error("datagram support disabled")] + Disabled, + /// The datagram is larger than the connection can currently accommodate + /// + /// Indicates that the path MTU minus overhead or the limit advertised by + /// the peer has been exceeded. + #[error("datagram too large")] + TooLarge, + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), +} + +/// Errors that can arise when trying to open a stream +#[derive(Debug, Error, Clone, Eq, PartialEq)] +pub enum OpenStreamError { + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + // The streams in the given direction are currently exhausted + #[error("streams exhausted")] + StreamsExhausted, +} + +#[cfg(feature = "h3")] +pub(crate) mod h3_impl { + use bytes::{Buf, BytesMut}; + use futures_util::ready; + use h3::{ + error::Code, + ext::Datagram, + quic::{self, Error, RecvDatagramExt, SendDatagramExt, WriteBuf}, + }; + + use super::*; + use crate::{send_stream::h3_impl::SendStream, ReadError, WriteError}; + + impl Error for ConnectionError { + fn is_timeout(&self) -> bool { + matches!(self, ConnectionError::TimedOut) + } + + fn err_code(&self) -> Option { + match &self { + ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { + error_code, + .. + }) => Some(error_code.into_inner()), + _ => None, + } + } + } + + impl Error for SendDatagramError { + fn is_timeout(&self) -> bool { + false + } + + fn err_code(&self) -> Option { + match self { + SendDatagramError::ConnectionLost(ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + )) => Some(error_code.into_inner()), + _ => None, + } + } + } + + impl SendDatagramExt for Connection + where + B: Buf, + { + type Error = SendDatagramError; + + fn send_datagram(&mut self, data: Datagram) -> Result<(), Self::Error> { + let mut buf = BytesMut::new(); + data.encode(&mut buf); + Connection::send_datagram(self, buf.freeze()) + } + } + + impl RecvDatagramExt for Connection { + type Buf = Bytes; + type Error = ConnectionError; + + fn poll_accept_datagram( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(Some(ready!(self.poll_recv_datagram(cx))?))) + } + } + + /// Bidirectional stream. + pub struct BidiStream { + send: SendStream, + recv: RecvStream, + } + + impl BidiStream { + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { + Self { + send: SendStream::new(conn.clone(), stream, is_0rtt), + recv: RecvStream::new(conn, stream, is_0rtt), + } + } + } + + impl quic::BidiStream for BidiStream + where + B: Buf, + { + type RecvStream = RecvStream; + type SendStream = SendStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + (self.send, self.recv) + } + } + + impl quic::RecvStream for BidiStream + where + B: Buf, + { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.recv.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.recv.stop_sending(error_code) + } + + fn recv_id(&self) -> quic::StreamId { + self.recv.recv_id() + } + } + + impl quic::SendStream for BidiStream + where + B: Buf, + { + type Error = WriteError; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.send.poll_ready(cx) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + self.send.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.send.reset(reset_code) + } + + fn send_id(&self) -> quic::StreamId { + self.send.send_id() + } + } + + impl quic::SendStreamUnframed for BidiStream + where + B: Buf, + { + fn poll_send( + &mut self, + cx: &mut Context<'_>, + buf: &mut D, + ) -> Poll> { + self.send.poll_send(cx, buf) + } + } + + /// Stream opener. + #[derive(Clone)] + pub struct OpenStreams(Connection); + + impl quic::OpenStreams for OpenStreams + where + B: Buf, + { + type BidiStream = BidiStream; + type OpenError = ConnectionError; + type SendStream = SendStream; + + fn poll_open_bidi( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Bi))?; + Poll::Ready(Ok(BidiStream::new(self.0.0.clone(), stream, is_0rtt))) + } + + fn poll_open_send( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Uni))?; + Poll::Ready(Ok(SendStream::new(self.0.0.clone(), stream, is_0rtt))) + } + + fn close(&mut self, code: Code, reason: &[u8]) { + self.0 + .close(code.value().try_into().expect("invalid code"), reason) + } + } + + impl quic::OpenStreams for Connection + where + B: Buf, + { + type BidiStream = BidiStream; + type OpenError = ConnectionError; + type SendStream = SendStream; + + fn poll_open_bidi( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?; + Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt))) + } + + fn poll_open_send( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?; + Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt))) + } + + fn close(&mut self, code: Code, reason: &[u8]) { + Connection::close(self, code.value().try_into().expect("invalid code"), reason) + } + } + + impl quic::Connection for Connection + where + B: Buf, + { + type AcceptError = ConnectionError; + type OpenStreams = OpenStreams; + type RecvStream = RecvStream; + + fn poll_accept_recv( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::AcceptError>> { + let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?; + Poll::Ready(Ok(Some(RecvStream::new(self.0.clone(), stream, is_0rtt)))) + } + + fn poll_accept_bidi( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::AcceptError>> { + let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?; + Poll::Ready(Ok(Some(BidiStream::new(self.0.clone(), stream, is_0rtt)))) + } + + fn opener(&self) -> Self::OpenStreams { + OpenStreams(self.clone()) + } + } +} diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs new file mode 100644 index 00000000..99d7400f --- /dev/null +++ b/compio-quic/src/endpoint.rs @@ -0,0 +1,518 @@ +use std::{ + collections::VecDeque, + io, + mem::ManuallyDrop, + net::{SocketAddr, SocketAddrV6}, + pin::pin, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, + time::Instant, +}; + +use bytes::Bytes; +use compio_buf::BufResult; +use compio_log::{error, Instrument}; +use compio_net::{ToSocketAddrsAsync, UdpSocket}; +use compio_runtime::JoinHandle; +use flume::{unbounded, Receiver, Sender}; +use futures_util::{ + future::{self}, + select, + task::AtomicWaker, + FutureExt, StreamExt, +}; +use quinn_proto::{ + ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig, + EndpointEvent, ServerConfig, Transmit, VarInt, +}; +use rustc_hash::FxHashMap as HashMap; + +use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket}; + +#[derive(Debug)] +struct EndpointState { + endpoint: quinn_proto::Endpoint, + worker: Option>, + connections: HashMap>, + close: Option<(VarInt, Bytes)>, + exit_on_idle: bool, + incoming: VecDeque, + incoming_wakers: VecDeque, +} + +impl EndpointState { + fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec, Transmit)) { + let now = Instant::now(); + for data in buf[..meta.len] + .chunks(meta.stride.min(meta.len)) + .map(Into::into) + { + let mut resp_buf = Vec::new(); + match self.endpoint.handle( + now, + meta.remote, + meta.local_ip, + meta.ecn, + data, + &mut resp_buf, + ) { + Some(DatagramEvent::NewConnection(incoming)) => { + if self.close.is_none() { + self.incoming.push_back(incoming); + } else { + let transmit = self.endpoint.refuse(incoming, &mut resp_buf); + respond_fn(resp_buf, transmit); + } + } + Some(DatagramEvent::ConnectionEvent(ch, event)) => { + let _ = self + .connections + .get(&ch) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit), + None => {} + } + } + } + + fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) { + if event.is_drained() { + self.connections.remove(&ch); + } + if let Some(event) = self.endpoint.handle_event(ch, event) { + let _ = self + .connections + .get(&ch) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + } + + fn is_idle(&self) -> bool { + self.connections.is_empty() + } + + fn poll_incoming(&mut self, cx: &mut Context) -> Poll> { + if self.close.is_none() { + if let Some(incoming) = self.incoming.pop_front() { + Poll::Ready(Some(incoming)) + } else { + self.incoming_wakers.push_back(cx.waker().clone()); + Poll::Pending + } + } else { + Poll::Ready(None) + } + } + + fn new_connection( + &mut self, + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + ) -> Connecting { + let (tx, rx) = unbounded(); + if let Some((error_code, reason)) = &self.close { + tx.send(ConnectionEvent::Close(*error_code, reason.clone())) + .unwrap(); + } + self.connections.insert(handle, tx); + Connecting::new(handle, conn, socket, events_tx, rx) + } +} + +type ChannelPair = (Sender, Receiver); + +#[derive(Debug)] +pub(crate) struct EndpointInner { + state: Mutex, + socket: Socket, + ipv6: bool, + events: ChannelPair<(ConnectionHandle, EndpointEvent)>, + done: AtomicWaker, +} + +impl EndpointInner { + fn new( + socket: UdpSocket, + config: EndpointConfig, + server_config: Option, + ) -> io::Result { + let socket = Socket::new(socket)?; + let ipv6 = socket.local_addr()?.is_ipv6(); + let allow_mtud = !socket.may_fragment(); + + Ok(Self { + state: Mutex::new(EndpointState { + endpoint: quinn_proto::Endpoint::new( + Arc::new(config), + server_config.map(Arc::new), + allow_mtud, + None, + ), + worker: None, + connections: HashMap::default(), + close: None, + exit_on_idle: false, + incoming: VecDeque::new(), + incoming_wakers: VecDeque::new(), + }), + socket, + ipv6, + events: unbounded(), + done: AtomicWaker::new(), + }) + } + + fn connect( + &self, + remote: SocketAddr, + server_name: &str, + config: ClientConfig, + ) -> Result { + let mut state = self.state.lock().unwrap(); + + if state.worker.is_none() { + return Err(ConnectError::EndpointStopping); + } + if remote.is_ipv6() && !self.ipv6 { + return Err(ConnectError::InvalidRemoteAddress(remote)); + } + let remote = if self.ipv6 { + SocketAddr::V6(match remote { + SocketAddr::V4(addr) => { + SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0) + } + SocketAddr::V6(addr) => addr, + }) + } else { + remote + }; + + let (handle, conn) = state + .endpoint + .connect(Instant::now(), config, remote, server_name)?; + + Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone())) + } + + fn respond(&self, buf: Vec, transmit: Transmit) { + let socket = self.socket.clone(); + compio_runtime::spawn(async move { + let _ = socket.send(buf, &transmit).await; + }) + .detach(); + } + + pub(crate) fn accept( + &self, + incoming: quinn_proto::Incoming, + server_config: Option, + ) -> Result { + let mut state = self.state.lock().unwrap(); + let mut resp_buf = Vec::new(); + let now = Instant::now(); + match state + .endpoint + .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new)) + { + Ok((handle, conn)) => { + Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone())) + } + Err(err) => { + if let Some(transmit) = err.response { + self.respond(resp_buf, transmit); + } + Err(err.cause) + } + } + } + + pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) { + let mut state = self.state.lock().unwrap(); + let mut resp_buf = Vec::new(); + let transmit = state.endpoint.refuse(incoming, &mut resp_buf); + self.respond(resp_buf, transmit); + } + + pub(crate) fn retry( + &self, + incoming: quinn_proto::Incoming, + ) -> Result<(), quinn_proto::RetryError> { + let mut state = self.state.lock().unwrap(); + let mut resp_buf = Vec::new(); + let transmit = state.endpoint.retry(incoming, &mut resp_buf)?; + self.respond(resp_buf, transmit); + Ok(()) + } + + pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) { + let mut state = self.state.lock().unwrap(); + state.endpoint.ignore(incoming); + } + + async fn run(&self) -> io::Result<()> { + let respond_fn = |buf: Vec, transmit: Transmit| self.respond(buf, transmit); + + let mut recv_fut = pin!( + self.socket + .recv(Vec::with_capacity( + self.state + .lock() + .unwrap() + .endpoint + .config() + .get_max_udp_payload_size() + .min(64 * 1024) as usize + * self.socket.max_gro_segments(), + )) + .fuse() + ); + + let mut event_stream = self.events.1.stream().ready_chunks(100); + + loop { + let mut state = select! { + BufResult(res, recv_buf) = recv_fut => { + let mut state = self.state.lock().unwrap(); + match res { + Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn), + Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} + #[cfg(windows)] + Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} + Err(e) => break Err(e), + } + recv_fut.set(self.socket.recv(recv_buf).fuse()); + state + }, + events = event_stream.select_next_some() => { + let mut state = self.state.lock().unwrap(); + for (ch, event) in events { + state.handle_event(ch, event); + } + state + }, + }; + + if state.exit_on_idle && state.is_idle() { + break Ok(()); + } + if !state.incoming.is_empty() { + let n = state.incoming.len().min(state.incoming_wakers.len()); + state.incoming_wakers.drain(..n).for_each(Waker::wake); + } + } + } +} + +/// A QUIC endpoint. +#[derive(Debug, Clone)] +pub struct Endpoint { + inner: Arc, + /// The client configuration used by `connect` + pub default_client_config: Option, +} + +impl Endpoint { + /// Create a QUIC endpoint. + pub fn new( + socket: UdpSocket, + config: EndpointConfig, + server_config: Option, + default_client_config: Option, + ) -> io::Result { + let inner = Arc::new(EndpointInner::new(socket, config, server_config)?); + let worker = compio_runtime::spawn({ + let inner = inner.clone(); + async move { + #[allow(unused)] + if let Err(e) = inner.run().await { + error!("I/O error: {}", e); + } + } + .in_current_span() + }); + inner.state.lock().unwrap().worker = Some(worker); + Ok(Self { + inner, + default_client_config, + }) + } + + /// Helper to construct an endpoint for use with outgoing connections only. + /// + /// Note that `addr` is the *local* address to bind to, which should usually + /// be a wildcard address like `0.0.0.0:0` or `[::]:0`, which allow + /// communication with any reachable IPv4 or IPv6 address respectively + /// from an OS-assigned port. + /// + /// If an IPv6 address is provided, the socket may dual-stack depending on + /// the platform, so as to allow communication with both IPv4 and IPv6 + /// addresses. As such, calling this method with the address `[::]:0` is a + /// reasonable default to maximize the ability to connect to other + /// address. + /// + /// IPv4 client is never dual-stack. + pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result { + // TODO: try to enable dual-stack on all platforms, notably Windows + let socket = UdpSocket::bind(addr).await?; + Self::new(socket, EndpointConfig::default(), None, None) + } + + /// Helper to construct an endpoint for use with both incoming and outgoing + /// connections + /// + /// Platform defaults for dual-stack sockets vary. For example, any socket + /// bound to a wildcard IPv6 address on Windows will not by default be + /// able to communicate with IPv4 addresses. Portable applications + /// should bind an address that matches the family they wish to + /// communicate within. + pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result { + let socket = UdpSocket::bind(addr).await?; + Self::new(socket, EndpointConfig::default(), Some(config), None) + } + + /// Connect to a remote endpoint. + pub fn connect( + &self, + remote: SocketAddr, + server_name: &str, + config: Option, + ) -> Result { + let config = config + .or_else(|| self.default_client_config.clone()) + .ok_or(ConnectError::NoDefaultClientConfig)?; + + self.inner.connect(remote, server_name, config) + } + + /// Wait for the next incoming connection attempt from a client. + /// + /// Yields [`Incoming`]s, or `None` if the endpoint is + /// [`close`](Self::close)d. [`Incoming`] can be `await`ed to obtain the + /// final [`Connection`](crate::Connection), or used to e.g. filter + /// connection attempts or force address validation, or converted into an + /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT + /// data. + pub async fn wait_incoming(&self) -> Option { + future::poll_fn(|cx| self.inner.state.lock().unwrap().poll_incoming(cx)) + .await + .map(|incoming| Incoming::new(incoming, self.inner.clone())) + } + + /// Replace the server configuration, affecting new incoming connections + /// only. + /// + /// Useful for e.g. refreshing TLS certificates without disrupting existing + /// connections. + pub fn set_server_config(&self, server_config: Option) { + self.inner + .state + .lock() + .unwrap() + .endpoint + .set_server_config(server_config.map(Arc::new)) + } + + /// Get the local `SocketAddr` the underlying socket is bound to. + pub fn local_addr(&self) -> io::Result { + self.inner.socket.local_addr() + } + + /// Get the number of connections that are currently open. + pub fn open_connections(&self) -> usize { + self.inner.state.lock().unwrap().endpoint.open_connections() + } + + /// Close all of this endpoint's connections immediately and cease accepting + /// new connections. + /// + /// See [`Connection::close()`] for details. + /// + /// [`Connection::close()`]: crate::Connection::close + pub fn close(&self, error_code: VarInt, reason: &[u8]) { + let reason = Bytes::copy_from_slice(reason); + let mut state = self.inner.state.lock().unwrap(); + if state.close.is_some() { + return; + } + state.close = Some((error_code, reason.clone())); + for conn in state.connections.values() { + let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone())); + } + state.incoming_wakers.drain(..).for_each(Waker::wake); + } + + // Modified from [`SharedFd::try_unwrap_inner`], see notes there. + unsafe fn try_unwrap_inner(this: &ManuallyDrop) -> Option { + let ptr = ManuallyDrop::new(std::ptr::read(&this.inner)); + match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) { + Ok(inner) => Some(inner), + Err(ptr) => { + std::mem::forget(ptr); + None + } + } + } + + /// Gracefully shutdown the endpoint. + /// + /// Wait for all connections on the endpoint to be cleanly shut down and + /// close the underlying socket. This will wait for all clones of the + /// endpoint, all connections and all streams to be dropped before + /// closing the socket. + /// + /// Waiting for this condition before exiting ensures that a good-faith + /// effort is made to notify peers of recent connection closes, whereas + /// exiting immediately could force them to wait out the idle timeout + /// period. + /// + /// Does not proactively close existing connections. Consider calling + /// [`close()`] if that is desired. + /// + /// [`close()`]: Endpoint::close + pub async fn shutdown(self) -> io::Result<()> { + let worker = self.inner.state.lock().unwrap().worker.take(); + if let Some(worker) = worker { + if self.inner.state.lock().unwrap().is_idle() { + worker.cancel().await; + } else { + self.inner.state.lock().unwrap().exit_on_idle = true; + let _ = worker.await; + } + } + + let this = ManuallyDrop::new(self); + let inner = future::poll_fn(move |cx| { + if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } { + return Poll::Ready(inner); + } + + this.inner.done.register(cx.waker()); + + if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } { + Poll::Ready(inner) + } else { + Poll::Pending + } + }) + .await; + + inner.socket.close().await + } +} + +impl Drop for Endpoint { + fn drop(&mut self) { + if Arc::strong_count(&self.inner) == 2 { + // There are actually two cases: + // 1. User is trying to shutdown the socket. + self.inner.done.wake(); + // 2. User dropped the endpoint but the worker is still running. + self.inner.state.lock().unwrap().exit_on_idle = true; + } + } +} diff --git a/compio-quic/src/incoming.rs b/compio-quic/src/incoming.rs new file mode 100644 index 00000000..d0eec213 --- /dev/null +++ b/compio-quic/src/incoming.rs @@ -0,0 +1,141 @@ +use std::{ + future::{Future, IntoFuture}, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures_util::FutureExt; +use quinn_proto::ServerConfig; +use thiserror::Error; + +use crate::{Connecting, Connection, ConnectionError, EndpointInner}; + +#[derive(Debug)] +pub(crate) struct IncomingInner { + pub(crate) incoming: quinn_proto::Incoming, + pub(crate) endpoint: Arc, +} + +/// An incoming connection for which the server has not yet begun its part +/// of the handshake. +#[derive(Debug)] +pub struct Incoming(Option); + +impl Incoming { + pub(crate) fn new(incoming: quinn_proto::Incoming, endpoint: Arc) -> Self { + Self(Some(IncomingInner { incoming, endpoint })) + } + + /// Attempt to accept this incoming connection (an error may still + /// occur). + pub fn accept(mut self) -> Result { + let inner = self.0.take().unwrap(); + Ok(inner.endpoint.accept(inner.incoming, None)?) + } + + /// Accept this incoming connection using a custom configuration. + /// + /// See [`accept()`] for more details. + /// + /// [`accept()`]: Incoming::accept + pub fn accept_with( + mut self, + server_config: ServerConfig, + ) -> Result { + let inner = self.0.take().unwrap(); + Ok(inner.endpoint.accept(inner.incoming, Some(server_config))?) + } + + /// Reject this incoming connection attempt. + pub fn refuse(mut self) { + let inner = self.0.take().unwrap(); + inner.endpoint.refuse(inner.incoming); + } + + /// Respond with a retry packet, requiring the client to retry with + /// address validation. + /// + /// Errors if `remote_address_validated()` is true. + pub fn retry(mut self) -> Result<(), RetryError> { + let inner = self.0.take().unwrap(); + inner + .endpoint + .retry(inner.incoming) + .map_err(|e| RetryError(Self::new(e.into_incoming(), inner.endpoint))) + } + + /// Ignore this incoming connection attempt, not sending any packet in + /// response. + pub fn ignore(mut self) { + let inner = self.0.take().unwrap(); + inner.endpoint.ignore(inner.incoming); + } + + /// The local IP address which was used when the peer established + /// the connection. + pub fn local_ip(&self) -> Option { + self.0.as_ref().unwrap().incoming.local_ip() + } + + /// The peer's UDP address. + pub fn remote_address(&self) -> SocketAddr { + self.0.as_ref().unwrap().incoming.remote_address() + } + + /// Whether the socket address that is initiating this connection has + /// been validated. + /// + /// This means that the sender of the initial packet has proved that + /// they can receive traffic sent to `self.remote_address()`. + pub fn remote_address_validated(&self) -> bool { + self.0.as_ref().unwrap().incoming.remote_address_validated() + } +} + +impl Drop for Incoming { + fn drop(&mut self) { + // Implicit reject, similar to Connection's implicit close + if let Some(inner) = self.0.take() { + inner.endpoint.refuse(inner.incoming); + } + } +} + +/// Error for attempting to retry an [`Incoming`] which already bears an +/// address validation token from a previous retry. +#[derive(Debug, Error)] +#[error("retry() with validated Incoming")] +pub struct RetryError(Incoming); + +impl RetryError { + /// Get the [`Incoming`] + pub fn into_incoming(self) -> Incoming { + self.0 + } +} + +/// Basic adapter to let [`Incoming`] be `await`-ed like a [`Connecting`]. +#[derive(Debug)] +pub struct IncomingFuture(Result); + +impl Future for IncomingFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match &mut self.0 { + Ok(connecting) => connecting.poll_unpin(cx), + Err(e) => Poll::Ready(Err(e.clone())), + } + } +} + +impl IntoFuture for Incoming { + type IntoFuture = IncomingFuture; + type Output = Result; + + fn into_future(self) -> Self::IntoFuture { + IncomingFuture(self.accept()) + } +} diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs new file mode 100644 index 00000000..3e3e36ff --- /dev/null +++ b/compio-quic/src/lib.rs @@ -0,0 +1,73 @@ +//! QUIC implementation for compio +//! +//! Ported from [`quinn`]. +//! +//! [`quinn`]: https://docs.rs/quinn + +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![warn(missing_docs)] + +pub use quinn_proto::{ + congestion, crypto, AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosedStream, + ConfigError, ConnectError, ConnectionClose, ConnectionStats, EndpointConfig, IdleTimeout, + MtuDiscoveryConfig, ServerConfig, StreamId, Transmit, TransportConfig, VarInt, +}; + +mod builder; +mod connection; +mod endpoint; +mod incoming; +mod recv_stream; +mod send_stream; +mod socket; + +pub use builder::{ClientBuilder, ServerBuilder}; +pub use connection::{Connecting, Connection, ConnectionError}; +pub use endpoint::Endpoint; +pub use incoming::{Incoming, IncomingFuture}; +pub use recv_stream::{ReadError, ReadExactError, RecvStream}; +pub use send_stream::{SendStream, WriteError}; + +pub(crate) use crate::{ + connection::{ConnectionEvent, ConnectionInner}, + endpoint::EndpointInner, + socket::*, +}; + +/// Errors from [`SendStream::stopped`] and [`RecvStream::stopped`]. +#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)] +pub enum StoppedError { + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + /// This was a 0-RTT stream and the server rejected it + /// + /// Can only occur on clients for 0-RTT streams, which can be opened using + /// [`Connecting::into_0rtt()`]. + /// + /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt() + #[error("0-RTT rejected")] + ZeroRttRejected, +} + +impl From for std::io::Error { + fn from(x: StoppedError) -> Self { + use StoppedError::*; + let kind = match x { + ZeroRttRejected => std::io::ErrorKind::ConnectionReset, + ConnectionLost(_) => std::io::ErrorKind::NotConnected, + }; + Self::new(kind, x) + } +} + +/// HTTP/3 support via [`h3`]. +#[cfg(feature = "h3")] +pub mod h3 { + pub use h3::*; + + pub use crate::{ + connection::h3_impl::{BidiStream, OpenStreams}, + send_stream::h3_impl::SendStream, + }; +} diff --git a/compio-quic/src/recv_stream.rs b/compio-quic/src/recv_stream.rs new file mode 100644 index 00000000..93cff0c2 --- /dev/null +++ b/compio-quic/src/recv_stream.rs @@ -0,0 +1,573 @@ +use std::{ + collections::BTreeMap, + io, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::{BufMut, Bytes}; +use compio_buf::{BufResult, IoBufMut}; +use compio_io::AsyncRead; +use futures_util::{future::poll_fn, ready}; +use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt}; +use thiserror::Error; + +use crate::{ConnectionError, ConnectionInner, StoppedError}; + +/// A stream that can only be used to receive data +/// +/// `stop(0)` is implicitly called on drop unless: +/// - A variant of [`ReadError`] has been yielded by a read call +/// - [`stop()`] was called explicitly +/// +/// # Cancellation +/// +/// A `read` method is said to be *cancel-safe* when dropping its future before +/// the future becomes ready cannot lead to loss of stream data. This is true of +/// methods which succeed immediately when any progress is made, and is not true +/// of methods which might need to perform multiple reads internally before +/// succeeding. Each `read` method documents whether it is cancel-safe. +/// +/// # Common issues +/// +/// ## Data never received on a locally-opened stream +/// +/// Peers are not notified of streams until they or a later-numbered stream are +/// used to send data. If a bidirectional stream is locally opened but never +/// used to send, then the peer may never see it. Application protocols should +/// always arrange for the endpoint which will first transmit on a stream to be +/// the endpoint responsible for opening it. +/// +/// ## Data never received on a remotely-opened stream +/// +/// Verify that the stream you are receiving is the same one that the server is +/// sending on, e.g. by logging the [`id`] of each. Streams are always accepted +/// in the same order as they are created, i.e. ascending order by [`StreamId`]. +/// For example, even if a sender first transmits on bidirectional stream 1, the +/// first stream yielded by [`Connection::accept_bi`] on the receiver +/// will be bidirectional stream 0. +/// +/// [`stop()`]: RecvStream::stop +/// [`id`]: RecvStream::id +/// [`Connection::accept_bi`]: crate::Connection::accept_bi +#[derive(Debug)] +pub struct RecvStream { + conn: Arc, + stream: StreamId, + is_0rtt: bool, + all_data_read: bool, + reset: Option, +} + +impl RecvStream { + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { + Self { + conn, + stream, + is_0rtt, + all_data_read: false, + reset: None, + } + } + + /// Get the identity of this stream + pub fn id(&self) -> StreamId { + self.stream + } + + /// Check if this stream has been opened during 0-RTT. + /// + /// In which case any non-idempotent request should be considered dangerous + /// at the application level. Because read data is subject to replay + /// attacks. + pub fn is_0rtt(&self) -> bool { + self.is_0rtt + } + + /// Stop accepting data + /// + /// Discards unread data and notifies the peer to stop transmitting. Once + /// stopped, further attempts to operate on a stream will yield + /// `ClosedStream` errors. + pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { + let mut state = self.conn.state(); + if self.is_0rtt && !state.check_0rtt() { + return Ok(()); + } + state.conn.recv_stream(self.stream).stop(error_code)?; + state.wake(); + self.all_data_read = true; + Ok(()) + } + + /// Completes when the stream has been reset by the peer or otherwise + /// closed. + /// + /// Yields `Some` with the reset error code when the stream is reset by the + /// peer. Yields `None` when the stream was previously + /// [`stop()`](Self::stop)ed, or when the stream was + /// [`finish()`](crate::SendStream::finish)ed by the peer and all data has + /// been received, after which it is no longer meaningful for the stream to + /// be reset. + /// + /// This operation is cancel-safe. + pub async fn stopped(&mut self) -> Result, StoppedError> { + poll_fn(|cx| { + let mut state = self.conn.state(); + + if self.is_0rtt && !state.check_0rtt() { + return Poll::Ready(Err(StoppedError::ZeroRttRejected)); + } + if let Some(code) = self.reset { + return Poll::Ready(Ok(Some(code))); + } + + match state.conn.recv_stream(self.stream).received_reset() { + Err(_) => Poll::Ready(Ok(None)), + Ok(Some(error_code)) => { + // Stream state has just now been freed, so the connection may need to issue new + // stream ID flow control credit + state.wake(); + Poll::Ready(Ok(Some(error_code))) + } + Ok(None) => { + if let Some(e) = &state.error { + return Poll::Ready(Err(e.clone().into())); + } + // Resets always notify readers, since a reset is an immediate read error. We + // could introduce a dedicated channel to reduce the risk of spurious wakeups, + // but that increased complexity is probably not justified, as an application + // that is expecting a reset is not likely to receive large amounts of data. + state.readable.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + } + }) + .await + } + + /// Handle common logic related to reading out of a receive stream. + /// + /// This takes an `FnMut` closure that takes care of the actual reading + /// process, matching the detailed read semantics for the calling + /// function with a particular return type. The closure can read from + /// the passed `&mut Chunks` and has to return the status after reading: + /// the amount of data read, and the status after the final read call. + fn execute_poll_read( + &mut self, + cx: &mut Context, + ordered: bool, + mut read_fn: F, + ) -> Poll, ReadError>> + where + F: FnMut(&mut Chunks) -> ReadStatus, + { + use quinn_proto::ReadError::*; + + if self.all_data_read { + return Poll::Ready(Ok(None)); + } + + let mut state = self.conn.state(); + if self.is_0rtt && !state.check_0rtt() { + return Poll::Ready(Err(ReadError::ZeroRttRejected)); + } + + // If we stored an error during a previous call, return it now. This can happen + // if a `read_fn` both wants to return data and also returns an error in + // its final stream status. + let status = match self.reset { + Some(code) => ReadStatus::Failed(None, Reset(code)), + None => { + let mut recv = state.conn.recv_stream(self.stream); + let mut chunks = recv.read(ordered)?; + let status = read_fn(&mut chunks); + if chunks.finalize().should_transmit() { + state.wake(); + } + status + } + }; + + match status { + ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))), + ReadStatus::Finished(read) => { + self.all_data_read = true; + Poll::Ready(Ok(read)) + } + ReadStatus::Failed(read, Blocked) => match read { + Some(val) => Poll::Ready(Ok(Some(val))), + None => { + if let Some(error) = &state.error { + return Poll::Ready(Err(error.clone().into())); + } + state.readable.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + }, + ReadStatus::Failed(read, Reset(error_code)) => match read { + None => { + self.all_data_read = true; + self.reset = Some(error_code); + Poll::Ready(Err(ReadError::Reset(error_code))) + } + done => { + self.reset = Some(error_code); + Poll::Ready(Ok(done)) + } + }, + } + } + + fn poll_read( + &mut self, + cx: &mut Context, + mut buf: impl BufMut, + ) -> Poll, ReadError>> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(Some(0))); + } + + self.execute_poll_read(cx, true, |chunks| { + let mut read = 0; + loop { + if !buf.has_remaining_mut() { + // We know `read` is `true` because `buf.remaining()` was not 0 before + return ReadStatus::Readable(read); + } + + match chunks.next(buf.remaining_mut()) { + Ok(Some(chunk)) => { + read += chunk.bytes.len(); + buf.put(chunk.bytes); + } + res => { + return (if read == 0 { None } else { Some(read) }, res.err()).into(); + } + } + } + }) + } + + /// Read data contiguously from the stream. + /// + /// Yields the number of bytes read into `buf` on success, or `None` if the + /// stream was finished. + /// + /// This operation is cancel-safe. + pub async fn read(&mut self, mut buf: impl BufMut) -> Result, ReadError> { + poll_fn(|cx| self.poll_read(cx, &mut buf)).await + } + + /// Read an exact number of bytes contiguously from the stream. + /// + /// See [`read()`] for details. This operation is *not* cancel-safe. + /// + /// [`read()`]: RecvStream::read + pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> { + poll_fn(|cx| { + while buf.has_remaining_mut() { + if ready!(self.poll_read(cx, &mut buf))?.is_none() { + return Poll::Ready(Err(ReadExactError::FinishedEarly(buf.remaining_mut()))); + } + } + Poll::Ready(Ok(())) + }) + .await + } + + /// Read the next segment of data. + /// + /// Yields `None` if the stream was finished. Otherwise, yields a segment of + /// data and its offset in the stream. If `ordered` is `true`, the chunk's + /// offset will be immediately after the last data yielded by + /// [`read()`](Self::read) or [`read_chunk()`](Self::read_chunk). If + /// `ordered` is `false`, segments may be received in any order, and the + /// `Chunk`'s `offset` field can be used to determine ordering in the + /// caller. Unordered reads are less prone to head-of-line blocking within a + /// stream, but require the application to manage reassembling the original + /// data. + /// + /// Slightly more efficient than `read` due to not copying. Chunk boundaries + /// do not correspond to peer writes, and hence cannot be used as framing. + /// + /// This operation is cancel-safe. + pub async fn read_chunk( + &mut self, + max_length: usize, + ordered: bool, + ) -> Result, ReadError> { + poll_fn(|cx| { + self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) { + Ok(Some(chunk)) => ReadStatus::Readable(chunk), + res => (None, res.err()).into(), + }) + }) + .await + } + + /// Read the next segments of data. + /// + /// Fills `bufs` with the segments of data beginning immediately after the + /// last data yielded by `read` or `read_chunk`, or `None` if the stream was + /// finished. + /// + /// Slightly more efficient than `read` due to not copying. Chunk boundaries + /// do not correspond to peer writes, and hence cannot be used as framing. + /// + /// This operation is cancel-safe. + pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result, ReadError> { + if bufs.is_empty() { + return Ok(Some(0)); + } + + poll_fn(|cx| { + self.execute_poll_read(cx, true, |chunks| { + let mut read = 0; + loop { + if read >= bufs.len() { + // We know `read > 0` because `bufs` cannot be empty here + return ReadStatus::Readable(read); + } + + match chunks.next(usize::MAX) { + Ok(Some(chunk)) => { + bufs[read] = chunk.bytes; + read += 1; + } + res => { + return (if read == 0 { None } else { Some(read) }, res.err()).into(); + } + } + } + }) + }) + .await + } + + /// Convenience method to read all remaining data into a buffer. + /// + /// Uses unordered reads to be more efficient than using [`AsyncRead`]. If + /// unordered reads have already been made, the resulting buffer may have + /// gaps containing zero. + /// + /// Depending on [`BufMut`] implementation, this method may fail with + /// [`ReadError::BufferTooShort`] if the buffer is not large enough to + /// hold the entire stream. For example when using a `&mut [u8]` it will + /// never receive bytes more than the length of the slice, but when using a + /// `&mut Vec` it will allocate more memory as needed. + /// + /// This operation is *not* cancel-safe. + pub async fn read_to_end(&mut self, mut buf: impl BufMut) -> Result { + let mut start = u64::MAX; + let mut end = 0; + let mut chunks = BTreeMap::new(); + loop { + let Some(chunk) = self.read_chunk(usize::MAX, false).await? else { + break; + }; + start = start.min(chunk.offset); + end = end.max(chunk.offset + chunk.bytes.len() as u64); + if end - start > buf.remaining_mut() as u64 { + return Err(ReadError::BufferTooShort); + } + chunks.insert(chunk.offset, chunk.bytes); + } + let mut last = 0; + for (offset, bytes) in chunks { + let offset = (offset - start) as usize; + if offset > last { + buf.put_bytes(0, offset - last); + } + last = offset + bytes.len(); + buf.put(bytes); + } + Ok((end - start) as usize) + } +} + +impl Drop for RecvStream { + fn drop(&mut self) { + let mut state = self.conn.state(); + + // clean up any previously registered wakers + state.readable.remove(&self.stream); + + if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) { + return; + } + if !self.all_data_read { + // Ignore ClosedStream errors + let _ = state.conn.recv_stream(self.stream).stop(0u32.into()); + state.wake(); + } + } +} + +enum ReadStatus { + Readable(T), + Finished(Option), + Failed(Option, quinn_proto::ReadError), +} + +impl From<(Option, Option)> for ReadStatus { + fn from(status: (Option, Option)) -> Self { + match status { + (read, None) => Self::Finished(read), + (read, Some(e)) => Self::Failed(read, e), + } + } +} + +/// Errors that arise from reading from a stream. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum ReadError { + /// The peer abandoned transmitting data on this stream. + /// + /// Carries an application-defined error code. + #[error("stream reset by peer: error {0}")] + Reset(VarInt), + /// The connection was lost. + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + /// The stream has already been stopped, finished, or reset. + #[error("closed stream")] + ClosedStream, + /// Attempted an ordered read following an unordered read. + /// + /// Performing an unordered read allows discontinuities to arise in the + /// receive buffer of a stream which cannot be recovered, making further + /// ordered reads impossible. + #[error("ordered read after unordered read")] + IllegalOrderedRead, + /// This was a 0-RTT stream and the server rejected it. + /// + /// Can only occur on clients for 0-RTT streams, which can be opened using + /// [`Connecting::into_0rtt()`]. + /// + /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt() + #[error("0-RTT rejected")] + ZeroRttRejected, + /// The stream is larger than the user-supplied buffer capacity. + /// + /// Can only occur when using [`read_to_end()`](RecvStream::read_to_end). + #[error("buffer too short")] + BufferTooShort, +} + +impl From for ReadError { + fn from(e: ReadableError) -> Self { + match e { + ReadableError::ClosedStream => Self::ClosedStream, + ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead, + } + } +} + +impl From for ReadError { + fn from(e: StoppedError) -> Self { + match e { + StoppedError::ConnectionLost(e) => Self::ConnectionLost(e), + StoppedError::ZeroRttRejected => Self::ZeroRttRejected, + } + } +} + +impl From for io::Error { + fn from(x: ReadError) -> Self { + use self::ReadError::*; + let kind = match x { + Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset, + ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected, + IllegalOrderedRead | BufferTooShort => io::ErrorKind::InvalidInput, + }; + Self::new(kind, x) + } +} + +/// Errors that arise from reading from a stream. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum ReadExactError { + /// The stream finished before all bytes were read + #[error("stream finished early (expected {0} bytes more)")] + FinishedEarly(usize), + /// A read error occurred + #[error(transparent)] + ReadError(#[from] ReadError), +} + +impl AsyncRead for RecvStream { + async fn read(&mut self, mut buf: B) -> BufResult { + let res = self + .read(buf.as_mut_slice()) + .await + .map(|n| { + let n = n.unwrap_or_default(); + unsafe { buf.set_buf_init(n) } + n + }) + .map_err(Into::into); + BufResult(res, buf) + } +} + +#[cfg(feature = "io-compat")] +impl futures_util::AsyncRead for RecvStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.get_mut() + .poll_read(cx, buf) + .map_ok(Option::unwrap_or_default) + .map_err(Into::into) + } +} + +#[cfg(feature = "h3")] +pub(crate) mod h3_impl { + use h3::quic::{self, Error}; + + use super::*; + + impl Error for ReadError { + fn is_timeout(&self) -> bool { + matches!(self, Self::ConnectionLost(ConnectionError::TimedOut)) + } + + fn err_code(&self) -> Option { + match self { + Self::ConnectionLost(ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + )) + | Self::Reset(error_code) => Some(error_code.into_inner()), + _ => None, + } + } + } + + impl quic::RecvStream for RecvStream { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.execute_poll_read(cx, true, |chunks| match chunks.next(usize::MAX) { + Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes), + res => (None, res.err()).into(), + }) + } + + fn stop_sending(&mut self, error_code: u64) { + self.stop(error_code.try_into().expect("invalid error_code")) + .ok(); + } + + fn recv_id(&self) -> quic::StreamId { + self.stream.0.try_into().unwrap() + } + } +} diff --git a/compio-quic/src/send_stream.rs b/compio-quic/src/send_stream.rs new file mode 100644 index 00000000..3e0a1a2a --- /dev/null +++ b/compio-quic/src/send_stream.rs @@ -0,0 +1,474 @@ +use std::{ + io, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use compio_buf::{BufResult, IoBuf}; +use compio_io::AsyncWrite; +use futures_util::{future::poll_fn, ready}; +use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written}; +use thiserror::Error; + +use crate::{ConnectionError, ConnectionInner, StoppedError}; + +/// A stream that can only be used to send data. +/// +/// If dropped, streams that haven't been explicitly [`reset()`] will be +/// implicitly [`finish()`]ed, continuing to (re)transmit previously written +/// data until it has been fully acknowledged or the connection is closed. +/// +/// # Cancellation +/// +/// A `write` method is said to be *cancel-safe* when dropping its future before +/// the future becomes ready will always result in no data being written to the +/// stream. This is true of methods which succeed immediately when any progress +/// is made, and is not true of methods which might need to perform multiple +/// writes internally before succeeding. Each `write` method documents whether +/// it is cancel-safe. +/// +/// [`reset()`]: SendStream::reset +/// [`finish()`]: SendStream::finish +#[derive(Debug)] +pub struct SendStream { + conn: Arc, + stream: StreamId, + is_0rtt: bool, +} + +impl SendStream { + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { + Self { + conn, + stream, + is_0rtt, + } + } + + /// Get the identity of this stream + pub fn id(&self) -> StreamId { + self.stream + } + + /// Notify the peer that no more data will ever be written to this stream. + /// + /// It is an error to write to a stream after `finish()`ing it. [`reset()`] + /// may still be called after `finish` to abandon transmission of any stream + /// data that might still be buffered. + /// + /// To wait for the peer to receive all buffered stream data, see + /// [`stopped()`]. + /// + /// May fail if [`finish()`] or [`reset()`] was previously called.This + /// error is harmless and serves only to indicate that the caller may have + /// incorrect assumptions about the stream's state. + /// + /// [`reset()`]: Self::reset + /// [`stopped()`]: Self::stopped + /// [`finish()`]: Self::finish + pub fn finish(&mut self) -> Result<(), ClosedStream> { + let mut state = self.conn.state(); + match state.conn.send_stream(self.stream).finish() { + Ok(()) => { + state.wake(); + Ok(()) + } + Err(FinishError::ClosedStream) => Err(ClosedStream::new()), + // Harmless. If the application needs to know about stopped streams at this point, + // it should call `stopped`. + Err(FinishError::Stopped(_)) => Ok(()), + } + } + + /// Close the stream immediately. + /// + /// No new data can be written after calling this method. Locally buffered + /// data is dropped, and previously transmitted data will no longer be + /// retransmitted if lost. If an attempt has already been made to finish + /// the stream, the peer may still receive all written data. + /// + /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was + /// previously called. This error is harmless and serves only to + /// indicate that the caller may have incorrect assumptions about the + /// stream's state. + pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { + let mut state = self.conn.state(); + if self.is_0rtt && !state.check_0rtt() { + return Ok(()); + } + state.conn.send_stream(self.stream).reset(error_code)?; + state.wake(); + Ok(()) + } + + /// Set the priority of the stream. + /// + /// Every stream has an initial priority of 0. Locally buffered data + /// from streams with higher priority will be transmitted before data + /// from streams with lower priority. Changing the priority of a stream + /// with pending data may only take effect after that data has been + /// transmitted. Using many different priority levels per connection may + /// have a negative impact on performance. + pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> { + self.conn + .state() + .conn + .send_stream(self.stream) + .set_priority(priority) + } + + /// Get the priority of the stream + pub fn priority(&self) -> Result { + self.conn.state().conn.send_stream(self.stream).priority() + } + + /// Completes when the peer stops the stream or reads the stream to + /// completion. + /// + /// Yields `Some` with the stop error code if the peer stops the stream. + /// Yields `None` if the local side [`finish()`](Self::finish)es the stream + /// and then the peer acknowledges receipt of all stream data (although not + /// necessarily the processing of it), after which the peer closing the + /// stream is no longer meaningful. + /// + /// For a variety of reasons, the peer may not send acknowledgements + /// immediately upon receiving data. As such, relying on `stopped` to + /// know when the peer has read a stream to completion may introduce + /// more latency than using an application-level response of some sort. + pub async fn stopped(&mut self) -> Result, StoppedError> { + poll_fn(|cx| { + let mut state = self.conn.state(); + if self.is_0rtt && !state.check_0rtt() { + return Poll::Ready(Err(StoppedError::ZeroRttRejected)); + } + match state.conn.send_stream(self.stream).stopped() { + Err(_) => Poll::Ready(Ok(None)), + Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))), + Ok(None) => { + if let Some(e) = &state.error { + return Poll::Ready(Err(e.clone().into())); + } + state.stopped.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + } + }) + .await + } + + fn execute_poll_write(&mut self, cx: &mut Context, f: F) -> Poll> + where + F: FnOnce(quinn_proto::SendStream) -> Result, + { + let mut state = self.conn.try_state()?; + if self.is_0rtt && !state.check_0rtt() { + return Poll::Ready(Err(WriteError::ZeroRttRejected)); + } + match f(state.conn.send_stream(self.stream)) { + Ok(r) => { + state.wake(); + Poll::Ready(Ok(r)) + } + Err(e) => match e.try_into() { + Ok(e) => Poll::Ready(Err(e)), + Err(()) => { + state.writable.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + }, + } + } + + /// Write bytes to the stream. + /// + /// Yields the number of bytes written on success. Congestion and flow + /// control may cause this to be shorter than `buf.len()`, indicating + /// that only a prefix of `buf` was written. + /// + /// This operation is cancel-safe. + pub async fn write(&mut self, buf: &[u8]) -> Result { + poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf))).await + } + + /// Convenience method to write an entire buffer to the stream. + /// + /// This operation is *not* cancel-safe. + pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> { + let mut count = 0; + poll_fn(|cx| { + loop { + if count == buf.len() { + return Poll::Ready(Ok(())); + } + let n = + ready!(self.execute_poll_write(cx, |mut stream| stream.write(&buf[count..])))?; + count += n; + } + }) + .await + } + + /// Write chunks to the stream. + /// + /// Yields the number of bytes and chunks written on success. + /// Congestion and flow control may cause this to be shorter than + /// `buf.len()`, indicating that only a prefix of `bufs` was written. + /// + /// This operation is cancel-safe. + pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result { + poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await + } + + /// Convenience method to write an entire list of chunks to the stream. + /// + /// This operation is *not* cancel-safe. + pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> { + let mut chunks = 0; + poll_fn(|cx| { + loop { + if chunks == bufs.len() { + return Poll::Ready(Ok(())); + } + let written = ready!(self.execute_poll_write(cx, |mut stream| { + stream.write_chunks(&mut bufs[chunks..]) + }))?; + chunks += written.chunks; + } + }) + .await + } +} + +impl Drop for SendStream { + fn drop(&mut self) { + let mut state = self.conn.state(); + + // clean up any previously registered wakers + state.stopped.remove(&self.stream); + state.writable.remove(&self.stream); + + if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) { + return; + } + match state.conn.send_stream(self.stream).finish() { + Ok(()) => state.wake(), + Err(FinishError::Stopped(reason)) => { + if state.conn.send_stream(self.stream).reset(reason).is_ok() { + state.wake(); + } + } + // Already finished or reset, which is fine. + Err(FinishError::ClosedStream) => {} + } + } +} + +/// Errors that arise from writing to a stream +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum WriteError { + /// The peer is no longer accepting data on this stream + /// + /// Carries an application-defined error code. + #[error("sending stopped by peer: error {0}")] + Stopped(VarInt), + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + /// The stream has already been finished or reset + #[error("closed stream")] + ClosedStream, + /// This was a 0-RTT stream and the server rejected it + /// + /// Can only occur on clients for 0-RTT streams, which can be opened using + /// [`Connecting::into_0rtt()`]. + /// + /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt() + #[error("0-RTT rejected")] + ZeroRttRejected, + /// Error when the stream is not ready, because it is still sending + /// data from a previous call + #[cfg(feature = "h3")] + #[error("stream not ready")] + NotReady, +} + +impl TryFrom for WriteError { + type Error = (); + + fn try_from(value: quinn_proto::WriteError) -> Result { + use quinn_proto::WriteError::*; + match value { + Stopped(e) => Ok(Self::Stopped(e)), + ClosedStream => Ok(Self::ClosedStream), + Blocked => Err(()), + } + } +} + +impl From for WriteError { + fn from(x: StoppedError) -> Self { + match x { + StoppedError::ConnectionLost(e) => Self::ConnectionLost(e), + StoppedError::ZeroRttRejected => Self::ZeroRttRejected, + } + } +} + +impl From for io::Error { + fn from(x: WriteError) -> Self { + use WriteError::*; + let kind = match x { + Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset, + ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected, + #[cfg(feature = "h3")] + NotReady => io::ErrorKind::Other, + }; + Self::new(kind, x) + } +} + +impl AsyncWrite for SendStream { + async fn write(&mut self, buf: T) -> BufResult { + let res = self.write(buf.as_slice()).await.map_err(Into::into); + BufResult(res, buf) + } + + async fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + + async fn shutdown(&mut self) -> io::Result<()> { + self.finish()?; + Ok(()) + } +} + +#[cfg(feature = "io-compat")] +impl futures_util::AsyncWrite for SendStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.get_mut() + .execute_poll_write(cx, |mut stream| stream.write(buf)) + .map_err(Into::into) + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.get_mut().finish()?; + Poll::Ready(Ok(())) + } +} + +#[cfg(feature = "h3")] +pub(crate) mod h3_impl { + use bytes::Buf; + use h3::quic::{self, Error, WriteBuf}; + + use super::*; + + impl Error for WriteError { + fn is_timeout(&self) -> bool { + matches!(self, Self::ConnectionLost(ConnectionError::TimedOut)) + } + + fn err_code(&self) -> Option { + match self { + Self::ConnectionLost(ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + )) + | Self::Stopped(error_code) => Some(error_code.into_inner()), + _ => None, + } + } + } + + /// A wrapper around `SendStream` that implements `quic::SendStream` and + /// `quic::SendStreamUnframed`. + pub struct SendStream { + inner: super::SendStream, + buf: Option>, + } + + impl SendStream { + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { + Self { + inner: super::SendStream::new(conn, stream, is_0rtt), + buf: None, + } + } + } + + impl quic::SendStream for SendStream + where + B: Buf, + { + type Error = WriteError; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(data) = &mut self.buf { + while data.has_remaining() { + let n = ready!( + self.inner + .execute_poll_write(cx, |mut stream| stream.write(data.chunk())) + )?; + data.advance(n); + } + } + self.buf = None; + Poll::Ready(Ok(())) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + if self.buf.is_some() { + return Err(WriteError::NotReady); + } + self.buf = Some(data.into()); + Ok(()) + } + + fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.inner.finish().map_err(|_| WriteError::ClosedStream)) + } + + fn reset(&mut self, reset_code: u64) { + self.inner + .reset(reset_code.try_into().unwrap_or(VarInt::MAX)) + .ok(); + } + + fn send_id(&self) -> quic::StreamId { + self.inner.stream.0.try_into().unwrap() + } + } + + impl quic::SendStreamUnframed for SendStream + where + B: Buf, + { + fn poll_send( + &mut self, + cx: &mut Context<'_>, + buf: &mut D, + ) -> Poll> { + // This signifies a bug in implementation + debug_assert!( + self.buf.is_some(), + "poll_send called while send stream is not ready" + ); + + let n = ready!( + self.inner + .execute_poll_write(cx, |mut stream| stream.write(buf.chunk())) + )?; + buf.advance(n); + Poll::Ready(Ok(n)) + } + } +} diff --git a/compio-quic/src/socket.rs b/compio-quic/src/socket.rs new file mode 100644 index 00000000..4ea39e97 --- /dev/null +++ b/compio-quic/src/socket.rs @@ -0,0 +1,783 @@ +//! Simple wrapper around UDP socket with advanced features useful for QUIC, +//! ported from [`quinn-udp`] +//! +//! Differences from [`quinn-udp`]: +//! - [quinn-rs/quinn#1516] is not implemented +//! - `recvmmsg` is not available +//! +//! [`quinn-udp`]: https://docs.rs/quinn-udp +//! [quinn-rs/quinn#1516]: https://github.com/quinn-rs/quinn/pull/1516 + +use std::{ + future::Future, + io, + net::{IpAddr, SocketAddr}, + ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, Ordering}, +}; + +use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; +use compio_net::{CMsgBuilder, CMsgIter, UdpSocket}; +use quinn_proto::{EcnCodepoint, Transmit}; +#[cfg(windows)] +use windows_sys::Win32::Networking::WinSock; + +/// Metadata for a single buffer filled with bytes received from the network +/// +/// This associated buffer can contain one or more datagrams, see [`stride`]. +/// +/// [`stride`]: RecvMeta::stride +#[derive(Debug)] +pub(crate) struct RecvMeta { + /// The source address of the datagram(s) contained in the buffer + pub remote: SocketAddr, + /// The number of bytes the associated buffer has + pub len: usize, + /// The size of a single datagram in the associated buffer + /// + /// When GRO (Generic Receive Offload) is used this indicates the size of a + /// single datagram inside the buffer. If the buffer is larger, that is + /// if [`len`] is greater then this value, then the individual datagrams + /// contained have their boundaries at `stride` increments from the + /// start. The last datagram could be smaller than `stride`. + /// + /// [`len`]: RecvMeta::len + pub stride: usize, + /// The Explicit Congestion Notification bits for the datagram(s) in the + /// buffer + pub ecn: Option, + /// The destination IP address which was encoded in this datagram + /// + /// Populated on platforms: Windows, Linux, Android, FreeBSD, OpenBSD, + /// NetBSD, macOS, and iOS. + pub local_ip: Option, +} + +const CMSG_LEN: usize = 128; + +struct Ancillary { + inner: [u8; N], + len: usize, + #[cfg(unix)] + _align: [libc::cmsghdr; 0], + #[cfg(windows)] + _align: [WinSock::CMSGHDR; 0], +} + +impl Ancillary { + fn new() -> Self { + Self { + inner: [0u8; N], + len: N, + _align: [], + } + } +} + +unsafe impl IoBuf for Ancillary { + fn as_buf_ptr(&self) -> *const u8 { + self.inner.as_buf_ptr() + } + + fn buf_len(&self) -> usize { + self.len + } + + fn buf_capacity(&self) -> usize { + N + } +} + +impl SetBufInit for Ancillary { + unsafe fn set_buf_init(&mut self, len: usize) { + debug_assert!(len <= N); + self.len = len; + } +} + +unsafe impl IoBufMut for Ancillary { + fn as_buf_mut_ptr(&mut self) -> *mut u8 { + self.inner.as_buf_mut_ptr() + } +} + +impl Deref for Ancillary { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.inner[0..self.len] + } +} + +impl DerefMut for Ancillary { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner[0..self.len] + } +} + +#[cfg(target_os = "linux")] +#[inline] +fn max_gso_segments(socket: &UdpSocket) -> io::Result { + unsafe { + socket.get_socket_option::(libc::SOL_UDP, libc::UDP_SEGMENT)?; + } + Ok(64) +} +#[cfg(windows)] +#[inline] +fn max_gso_segments(socket: &UdpSocket) -> io::Result { + unsafe { + socket.get_socket_option::(WinSock::IPPROTO_UDP, WinSock::UDP_SEND_MSG_SIZE)?; + } + Ok(512) +} +#[cfg(not(any(target_os = "linux", windows)))] +#[inline] +fn max_gso_segments(_socket: &UdpSocket) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) +} + +macro_rules! set_socket_option { + ($socket:expr, $level:expr, $name:expr, $value:expr $(,)?) => { + match unsafe { $socket.set_socket_option($level, $name, $value) } { + Ok(()) => true, + Err(e) => { + compio_log::warn!( + level = stringify!($level), + name = stringify!($name), + "failed to set socket option: {}", + e + ); + if e.kind() == io::ErrorKind::InvalidInput { + true + } else if e.raw_os_error() + == Some( + #[cfg(unix)] + libc::ENOPROTOOPT, + #[cfg(windows)] + WinSock::WSAENOPROTOOPT, + ) + { + false + } else { + return Err(e); + } + } + } + }; +} + +#[derive(Debug)] +pub(crate) struct Socket { + inner: UdpSocket, + max_gro_segments: usize, + max_gso_segments: usize, + may_fragment: bool, + has_gso_error: AtomicBool, + #[cfg(target_os = "freebsd")] + encode_src_ip_v4: bool, +} + +impl Socket { + pub fn new(socket: UdpSocket) -> io::Result { + let is_ipv6 = socket.local_addr()?.is_ipv6(); + #[cfg(unix)] + let only_v6 = unsafe { + is_ipv6 + && socket.get_socket_option::(libc::IPPROTO_IPV6, libc::IPV6_V6ONLY)? + != 0 + }; + #[cfg(windows)] + let only_v6 = unsafe { + is_ipv6 + && socket.get_socket_option::(WinSock::IPPROTO_IPV6, WinSock::IPV6_V6ONLY)? != 0 + }; + let is_ipv4 = socket.local_addr()?.is_ipv4() || !only_v6; + + // ECN + if is_ipv4 { + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_RECVTOS, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IP, WinSock::IP_ECN, &1); + } + if is_ipv6 { + #[cfg(unix)] + set_socket_option!(socket, libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, &1); + } + + // pktinfo / destination address + if is_ipv4 { + #[cfg(any(target_os = "linux", target_os = "android"))] + set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_PKTINFO, &1); + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios" + ))] + set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_RECVDSTADDR, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, &1); + } + if is_ipv6 { + #[cfg(unix)] + set_socket_option!(socket, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, &1); + } + + // disable fragmentation + let mut may_fragment = false; + if is_ipv4 { + #[cfg(any(target_os = "linux", target_os = "android"))] + { + may_fragment |= set_socket_option!( + socket, + libc::IPPROTO_IP, + libc::IP_MTU_DISCOVER, + &libc::IP_PMTUDISC_PROBE, + ); + } + #[cfg(any( + target_os = "aix", + target_os = "freebsd", + target_os = "macos", + target_os = "ios" + ))] + { + may_fragment |= set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_DONTFRAG, &1); + } + #[cfg(windows)] + { + may_fragment |= + set_socket_option!(socket, WinSock::IPPROTO_IP, WinSock::IP_DONTFRAGMENT, &1); + } + } + if is_ipv6 { + #[cfg(any(target_os = "linux", target_os = "android"))] + { + may_fragment |= set_socket_option!( + socket, + libc::IPPROTO_IPV6, + libc::IPV6_MTU_DISCOVER, + &libc::IPV6_PMTUDISC_PROBE, + ); + } + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + { + may_fragment |= + set_socket_option!(socket, libc::IPPROTO_IPV6, libc::IPV6_DONTFRAG, &1); + } + #[cfg(any(target_os = "openbsd", target_os = "netbsd"))] + { + // FIXME: workaround until https://github.com/rust-lang/libc/pull/3716 is released (at least in 0.2.155) + may_fragment |= set_socket_option!(socket, libc::IPPROTO_IPV6, 62, &1); + } + #[cfg(windows)] + { + may_fragment |= + set_socket_option!(socket, WinSock::IPPROTO_IPV6, WinSock::IPV6_DONTFRAG, &1); + } + } + + // GRO + #[allow(unused_mut)] // only mutable on Linux and Windows + let mut max_gro_segments = 1; + #[cfg(target_os = "linux")] + if set_socket_option!(socket, libc::SOL_UDP, libc::UDP_GRO, &1) { + max_gro_segments = 64; + } + #[cfg(windows)] + if set_socket_option!( + socket, + WinSock::IPPROTO_UDP, + WinSock::UDP_RECV_MAX_COALESCED_SIZE, + &(u16::MAX as u32), + ) { + max_gro_segments = 64; + } + + // GSO + let max_gso_segments = max_gso_segments(&socket).unwrap_or(1); + + #[cfg(target_os = "freebsd")] + let encode_src_ip_v4 = + socket.local_addr().unwrap().ip() == IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); + + Ok(Self { + inner: socket, + max_gro_segments, + max_gso_segments, + may_fragment, + has_gso_error: AtomicBool::new(false), + #[cfg(target_os = "freebsd")] + encode_src_ip_v4, + }) + } + + #[inline] + pub fn local_addr(&self) -> io::Result { + self.inner.local_addr() + } + + #[inline] + pub fn may_fragment(&self) -> bool { + self.may_fragment + } + + #[inline] + pub fn max_gro_segments(&self) -> usize { + self.max_gro_segments + } + + #[inline] + pub fn max_gso_segments(&self) -> usize { + if self.has_gso_error.load(Ordering::Relaxed) { + 1 + } else { + self.max_gso_segments + } + } + + pub async fn recv(&self, buffer: T) -> BufResult { + let control = Ancillary::::new(); + + let BufResult(res, (buffer, control)) = self.inner.recv_msg(buffer, control).await; + let ((len, _, remote), buffer) = buf_try!(res, buffer); + + let mut ecn_bits = 0u8; + let mut local_ip = None; + #[allow(unused_mut)] // only mutable on Linux + let mut stride = len; + + // SAFETY: `control` contains valid data + unsafe { + for cmsg in CMsgIter::new(&control) { + #[cfg(windows)] + const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32; + + match (cmsg.level(), cmsg.ty()) { + // ECN + #[cfg(unix)] + (libc::IPPROTO_IP, libc::IP_TOS) => ecn_bits = *cmsg.data::(), + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + (libc::IPPROTO_IP, libc::IP_RECVTOS) => ecn_bits = *cmsg.data::(), + #[cfg(unix)] + (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => { + // NOTE: It's OK to use `c_int` instead of `u8` on Apple systems + ecn_bits = *cmsg.data::() as u8 + } + #[cfg(windows)] + (WinSock::IPPROTO_IP, WinSock::IP_ECN) + | (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => { + ecn_bits = *cmsg.data::() as u8 + } + + // pktinfo / destination address + #[cfg(any(target_os = "linux", target_os = "android"))] + (libc::IPPROTO_IP, libc::IP_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi_addr.s_addr.to_ne_bytes())); + } + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { + let in_addr = cmsg.data::(); + local_ip = Some(IpAddr::from(in_addr.s_addr.to_ne_bytes())); + } + #[cfg(windows)] + (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi_addr.S_un.S_addr.to_ne_bytes())); + } + #[cfg(unix)] + (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.s6_addr)); + } + #[cfg(windows)] + (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.u.Byte)); + } + + // GRO + #[cfg(target_os = "linux")] + (libc::SOL_UDP, libc::UDP_GRO) => stride = *cmsg.data::() as usize, + #[cfg(windows)] + (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => { + stride = *cmsg.data::() as usize + } + + _ => {} + } + } + } + + let meta = RecvMeta { + remote, + len, + stride, + ecn: EcnCodepoint::from_bits(ecn_bits), + local_ip, + }; + BufResult(Ok(meta), buffer) + } + + pub async fn send(&self, buffer: T, transmit: &Transmit) -> BufResult<(), T> { + let is_ipv4 = transmit.destination.ip().to_canonical().is_ipv4(); + let ecn = transmit.ecn.map_or(0, |x| x as u8); + + let mut control = Ancillary::::new(); + let mut builder = CMsgBuilder::new(&mut control); + + // ECN + if is_ipv4 { + #[cfg(all(unix, not(any(target_os = "freebsd", target_os = "netbsd"))))] + builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_int); + #[cfg(target_os = "freebsd")] + builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_uchar); + #[cfg(windows)] + builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn as i32); + } else { + #[cfg(unix)] + builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn as libc::c_int); + #[cfg(windows)] + builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn as i32); + } + + // pktinfo / destination address + match transmit.src_ip { + Some(IpAddr::V4(ip)) => { + let addr = u32::from_ne_bytes(ip.octets()); + #[cfg(any(target_os = "linux", target_os = "android"))] + { + let pktinfo = libc::in_pktinfo { + ipi_ifindex: 0, + ipi_spec_dst: libc::in_addr { s_addr: addr }, + ipi_addr: libc::in_addr { s_addr: 0 }, + }; + builder.try_push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo); + } + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + { + #[cfg(target_os = "freebsd")] + let encode_src_ip_v4 = self.encode_src_ip_v4; + #[cfg(any( + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + let encode_src_ip_v4 = true; + + if encode_src_ip_v4 { + let addr = libc::in_addr { s_addr: addr }; + builder.try_push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr); + } + } + #[cfg(windows)] + { + let pktinfo = WinSock::IN_PKTINFO { + ipi_addr: WinSock::IN_ADDR { + S_un: WinSock::IN_ADDR_0 { S_addr: addr }, + }, + ipi_ifindex: 0, + }; + builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo); + } + } + Some(IpAddr::V6(ip)) => { + #[cfg(unix)] + { + let pktinfo = libc::in6_pktinfo { + ipi6_ifindex: 0, + ipi6_addr: libc::in6_addr { + s6_addr: ip.octets(), + }, + }; + builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo); + } + #[cfg(windows)] + { + let pktinfo = WinSock::IN6_PKTINFO { + ipi6_addr: WinSock::IN6_ADDR { + u: WinSock::IN6_ADDR_0 { Byte: ip.octets() }, + }, + ipi6_ifindex: 0, + }; + builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo); + } + } + None => {} + } + + // GSO + if let Some(segment_size) = transmit.segment_size { + #[cfg(target_os = "linux")] + builder.try_push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size as u16); + #[cfg(windows)] + builder.try_push( + WinSock::IPPROTO_UDP, + WinSock::UDP_SEND_MSG_SIZE, + segment_size as u32, + ); + #[cfg(not(any(target_os = "linux", windows)))] + let _ = segment_size; + } + + let len = builder.finish(); + control.len = len; + + let buffer = buffer.slice(0..transmit.size); + let BufResult(res, (buffer, _)) = self + .inner + .send_msg(buffer, control, transmit.destination) + .await; + let buffer = buffer.into_inner(); + match res { + Ok(_) => BufResult(Ok(()), buffer), + Err(e) => { + #[cfg(target_os = "linux")] + if let Some(libc::EIO) | Some(libc::EINVAL) = e.raw_os_error() { + if self.max_gso_segments() > 1 { + self.has_gso_error.store(true, Ordering::Relaxed); + } + } + BufResult(Err(e), buffer) + } + } + } + + pub fn close(self) -> impl Future> { + self.inner.close() + } +} + +impl Clone for Socket { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + may_fragment: self.may_fragment, + max_gro_segments: self.max_gro_segments, + max_gso_segments: self.max_gso_segments, + has_gso_error: AtomicBool::new(self.has_gso_error.load(Ordering::Relaxed)), + #[cfg(target_os = "freebsd")] + encode_src_ip_v4: self.encode_src_ip_v4.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, Ipv6Addr}; + + use compio_driver::AsRawFd; + use socket2::{Domain, Protocol, Socket as Socket2, Type}; + + use super::*; + + async fn test_send_recv( + passive: Socket, + active: Socket, + content: T, + transmit: Transmit, + ) { + let passive_addr = passive.local_addr().unwrap(); + let active_addr = active.local_addr().unwrap(); + + let (_, content) = active.send(content, &transmit).await.unwrap(); + + let segment_size = transmit.segment_size.unwrap_or(transmit.size); + let expected_datagrams = transmit.size / segment_size; + let mut datagrams = 0; + while datagrams < expected_datagrams { + let (meta, buf) = passive + .recv(Vec::with_capacity(u16::MAX as usize)) + .await + .unwrap(); + let segments = meta.len / meta.stride; + for i in 0..segments { + assert_eq!( + &content.as_slice() + [(datagrams + i) * segment_size..(datagrams + i + 1) * segment_size], + &buf[(i * meta.stride)..((i + 1) * meta.stride)] + ); + } + datagrams += segments; + + assert_eq!(meta.ecn, transmit.ecn); + + assert_eq!(meta.remote.port(), active_addr.port()); + for addr in [meta.remote.ip(), meta.local_ip.unwrap()] { + match (active_addr.is_ipv6(), passive_addr.is_ipv6()) { + (_, false) => assert_eq!(addr, Ipv4Addr::LOCALHOST), + (false, true) => assert!( + addr == Ipv4Addr::LOCALHOST || addr == Ipv4Addr::LOCALHOST.to_ipv6_mapped() + ), + (true, true) => assert!( + addr == Ipv6Addr::LOCALHOST || addr == Ipv4Addr::LOCALHOST.to_ipv6_mapped() + ), + } + } + } + assert_eq!(datagrams, expected_datagrams); + } + + /// Helper function to create dualstack udp socket. + /// This is only used for testing. + fn bind_udp_dualstack() -> io::Result { + #[cfg(unix)] + use std::os::fd::{FromRawFd, IntoRawFd}; + #[cfg(windows)] + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let socket = Socket2::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(false)?; + socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).into())?; + + compio_runtime::Runtime::with_current(|r| r.attach(socket.as_raw_fd()))?; + #[cfg(unix)] + unsafe { + Ok(UdpSocket::from_raw_fd(socket.into_raw_fd())) + } + #[cfg(windows)] + unsafe { + Ok(UdpSocket::from_raw_socket(socket.into_raw_socket())) + } + } + + #[compio_macros::test] + async fn basic() { + let passive = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: None, + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive, active, content, transmit).await; + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_v4() { + let passive = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + async fn ecn_v6() { + let passive = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_dualstack() { + let passive = Socket::new(bind_udp_dualstack().unwrap()).unwrap(); + + let mut dst_v4 = passive.local_addr().unwrap(); + dst_v4.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); + let mut dst_v6 = dst_v4; + dst_v6.set_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)); + + for (src, dst) in [("[::1]:0", dst_v6), ("127.0.0.1:0", dst_v4)] { + let active = Socket::new(UdpSocket::bind(src).await.unwrap()).unwrap(); + + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: dst, + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_v4_mapped_v6() { + let passive = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let active = Socket::new(bind_udp_dualstack().unwrap()).unwrap(); + + let mut dst_addr = passive.local_addr().unwrap(); + dst_addr.set_ip(IpAddr::V6(Ipv4Addr::LOCALHOST.to_ipv6_mapped())); + + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: dst_addr, + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + #[cfg_attr(not(any(target_os = "linux", target_os = "windows")), ignore)] + async fn gso() { + let passive = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + + let max_segments = active.max_gso_segments(); + const SEGMENT_SIZE: usize = 128; + let content = vec![0xAB; SEGMENT_SIZE * max_segments]; + + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: None, + size: content.len(), + segment_size: Some(SEGMENT_SIZE), + src_ip: None, + }; + test_send_recv(passive, active, content, transmit).await; + } +} diff --git a/compio-quic/tests/basic.rs b/compio-quic/tests/basic.rs new file mode 100644 index 00000000..4f0ed775 --- /dev/null +++ b/compio-quic/tests/basic.rs @@ -0,0 +1,269 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::{Duration, Instant}, +}; + +use compio_quic::{ClientBuilder, ConnectionError, Endpoint, TransportConfig}; +use futures_util::join; + +mod common; +use common::{config_pair, subscribe}; + +#[compio_macros::test] +async fn handshake_timeout() { + let _guard = subscribe(); + + let endpoint = Endpoint::client("127.0.0.1:0").await.unwrap(); + + const IDLE_TIMEOUT: Duration = Duration::from_millis(100); + + let mut transport_config = TransportConfig::default(); + transport_config + .max_idle_timeout(Some(IDLE_TIMEOUT.try_into().unwrap())) + .initial_rtt(Duration::from_millis(10)); + let mut client_config = ClientBuilder::new_with_no_server_verification().build(); + client_config.transport_config(Arc::new(transport_config)); + + let start = Instant::now(); + match endpoint + .connect( + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1), + "localhost", + Some(client_config), + ) + .unwrap() + .await + { + Err(ConnectionError::TimedOut) => {} + Err(e) => panic!("unexpected error: {e:?}"), + Ok(_) => panic!("unexpected success"), + } + let dt = start.elapsed(); + assert!(dt > IDLE_TIMEOUT && dt < 2 * IDLE_TIMEOUT); +} + +#[compio_macros::test] +async fn close_endpoint() { + let _guard = subscribe(); + + let endpoint = ClientBuilder::new_with_no_server_verification() + .bind("127.0.0.1:0") + .await + .unwrap(); + + let conn = endpoint + .connect( + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1), + "localhost", + None, + ) + .unwrap(); + endpoint.close(0u32.into(), b""); + match conn.await { + Err(ConnectionError::LocallyClosed) => (), + Err(e) => panic!("unexpected error: {e}"), + Ok(_) => { + panic!("unexpected success"); + } + } +} + +async fn endpoint() -> Endpoint { + let (server_config, client_config) = config_pair(None); + let mut endpoint = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + endpoint.default_client_config = Some(client_config); + endpoint +} + +#[compio_macros::test] +async fn read_after_close() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + const MSG: &[u8] = b"goodbye!"; + + join!( + async { + let conn = endpoint.wait_incoming().await.unwrap().await.unwrap(); + let mut s = conn.open_uni().unwrap(); + s.write_all(MSG).await.unwrap(); + s.finish().unwrap(); + // Wait for the stream to be closed, one way or another. + let _ = s.stopped().await; + }, + async { + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap(); + let mut recv = conn.accept_uni().await.unwrap(); + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, MSG); + }, + ); +} + +#[compio_macros::test] +async fn export_keying_material() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + let (conn1, conn2) = join!( + async { + endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap() + }, + async { endpoint.wait_incoming().await.unwrap().await.unwrap() }, + ); + let mut buf1 = [0u8; 64]; + let mut buf2 = [0u8; 64]; + conn1 + .export_keying_material(&mut buf1, b"qaq", b"qwq") + .unwrap(); + conn2 + .export_keying_material(&mut buf2, b"qaq", b"qwq") + .unwrap(); + assert_eq!(buf1, buf2); +} + +#[compio_macros::test] +async fn zero_rtt() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + const MSG0: &[u8] = b"zero"; + const MSG1: &[u8] = b"one"; + + join!( + async { + for _ in 0..2 { + let conn = endpoint + .wait_incoming() + .await + .unwrap() + .accept() + .unwrap() + .into_0rtt() + .unwrap(); + join!( + async { + while let Ok(mut recv) = conn.accept_uni().await { + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, MSG0); + } + }, + async { + let mut send = conn.open_uni().unwrap(); + send.write_all(MSG0).await.unwrap(); + send.finish().unwrap(); + conn.accepted_0rtt().await.unwrap(); + let mut send = conn.open_uni().unwrap(); + send.write_all(MSG1).await.unwrap(); + send.finish().unwrap(); + // no need to wait for the stream to be closed due to + // the `while` loop above + }, + ); + } + }, + async { + { + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .into_0rtt() + .unwrap_err() + .await + .unwrap(); + + let mut buf = vec![]; + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG0); + + buf.clear(); + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG1); + } + + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .into_0rtt() + .unwrap(); + + let mut send = conn.open_uni().unwrap(); + send.write_all(MSG0).await.unwrap(); + send.finish().unwrap(); + + let mut buf = vec![]; + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG0); + + assert!(conn.accepted_0rtt().await.unwrap()); + + buf.clear(); + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG1); + }, + ); +} + +#[compio_macros::test] +async fn two_datagram_readers() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + const MSG1: &[u8] = b"one"; + const MSG2: &[u8] = b"two"; + + let (conn1, conn2) = join!( + async { + endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap() + }, + async { endpoint.wait_incoming().await.unwrap().await.unwrap() }, + ); + + let (tx, rx) = flume::bounded::<()>(1); + + let (a, b, _) = join!( + async { + let x = conn1.recv_datagram().await.unwrap(); + let _ = tx.try_send(()); + x + }, + async { + let x = conn1.recv_datagram().await.unwrap(); + let _ = tx.try_send(()); + x + }, + async { + conn2.send_datagram(MSG1.into()).unwrap(); + rx.recv_async().await.unwrap(); + conn2.send_datagram_wait(MSG2.into()).await.unwrap(); + } + ); + + assert!(a == MSG1 || b == MSG1); + assert!(a == MSG2 || b == MSG2); +} diff --git a/compio-quic/tests/common/mod.rs b/compio-quic/tests/common/mod.rs new file mode 100644 index 00000000..08745b3d --- /dev/null +++ b/compio-quic/tests/common/mod.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use compio_log::subscriber::DefaultGuard; +use compio_quic::{ClientBuilder, ClientConfig, ServerBuilder, ServerConfig, TransportConfig}; +use tracing_subscriber::{util::SubscriberInitExt, EnvFilter}; + +pub fn subscribe() -> DefaultGuard { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish() + .set_default() +} + +pub fn config_pair(transport: Option) -> (ServerConfig, ClientConfig) { + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + + let mut server_config = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) + .unwrap() + .build(); + let mut client_config = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert) + .unwrap() + .with_no_crls() + .build(); + if let Some(transport) = transport { + let transport = Arc::new(transport); + server_config.transport_config(transport.clone()); + client_config.transport_config(transport); + } + (server_config, client_config) +} diff --git a/compio-quic/tests/control.rs b/compio-quic/tests/control.rs new file mode 100644 index 00000000..6610feb9 --- /dev/null +++ b/compio-quic/tests/control.rs @@ -0,0 +1,91 @@ +use compio_quic::{ConnectionError, Endpoint, TransportConfig}; + +mod common; +use common::{config_pair, subscribe}; +use futures_util::join; + +#[compio_macros::test] +async fn ip_blocking() { + let _guard = subscribe(); + + let (server_config, client_config) = config_pair(None); + + let server = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + let server_addr = server.local_addr().unwrap(); + + let client1 = Endpoint::client("127.0.0.1:0").await.unwrap(); + let client1_addr = client1.local_addr().unwrap(); + let client2 = Endpoint::client("127.0.0.1:0").await.unwrap(); + + let srv = compio_runtime::spawn(async move { + loop { + let incoming = server.wait_incoming().await.unwrap(); + if incoming.remote_address() == client1_addr { + incoming.refuse(); + } else if incoming.remote_address_validated() { + incoming.await.unwrap(); + } else { + incoming.retry().unwrap(); + } + } + }); + + let e = client1 + .connect(server_addr, "localhost", Some(client_config.clone())) + .unwrap() + .await + .unwrap_err(); + assert!(matches!(e, ConnectionError::ConnectionClosed(_))); + client2 + .connect(server_addr, "localhost", Some(client_config)) + .unwrap() + .await + .unwrap(); + + let _ = srv.cancel().await; +} + +#[compio_macros::test] +async fn stream_id_flow_control() { + let _guard = subscribe(); + + let mut cfg = TransportConfig::default(); + cfg.max_concurrent_uni_streams(1u32.into()); + + let (server_config, client_config) = config_pair(Some(cfg)); + let mut endpoint = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + endpoint.default_client_config = Some(client_config); + + let (conn1, conn2) = join!( + async { + endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap() + }, + async { endpoint.wait_incoming().await.unwrap().await.unwrap() }, + ); + + // If `open_uni_wait` doesn't get unblocked when the previous stream is dropped, + // this will time out. + join!( + async { + conn1.open_uni_wait().await.unwrap(); + }, + async { + conn1.open_uni_wait().await.unwrap(); + }, + async { + conn1.open_uni_wait().await.unwrap(); + }, + async { + conn2.accept_uni().await.unwrap(); + conn2.accept_uni().await.unwrap(); + } + ); +} diff --git a/compio-quic/tests/echo.rs b/compio-quic/tests/echo.rs new file mode 100644 index 00000000..9ac40196 --- /dev/null +++ b/compio-quic/tests/echo.rs @@ -0,0 +1,191 @@ +use std::{ + array, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, +}; + +use bytes::Bytes; +use compio_quic::{Endpoint, RecvStream, SendStream, TransportConfig}; + +mod common; +use common::{config_pair, subscribe}; +use futures_util::join; +use rand::{rngs::StdRng, RngCore, SeedableRng}; + +struct EchoArgs { + client_addr: SocketAddr, + server_addr: SocketAddr, + nr_streams: usize, + stream_size: usize, + receive_window: Option, + stream_receive_window: Option, +} + +async fn echo((mut send, mut recv): (SendStream, RecvStream)) { + loop { + // These are 32 buffers, for reading approximately 32kB at once + let mut bufs: [Bytes; 32] = array::from_fn(|_| Bytes::new()); + + match recv.read_chunks(&mut bufs).await.unwrap() { + Some(n) => { + send.write_all_chunks(&mut bufs[..n]).await.unwrap(); + } + None => break, + } + } + + let _ = send.finish(); +} + +/// This is just an arbitrary number to generate deterministic test data +const SEED: u64 = 0x12345678; + +fn gen_data(size: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(SEED); + let mut buf = vec![0; size]; + rng.fill_bytes(&mut buf); + buf +} + +async fn run_echo(args: EchoArgs) { + // Use small receive windows + let mut transport_config = TransportConfig::default(); + if let Some(receive_window) = args.receive_window { + transport_config.receive_window(receive_window.try_into().unwrap()); + } + if let Some(stream_receive_window) = args.stream_receive_window { + transport_config.stream_receive_window(stream_receive_window.try_into().unwrap()); + } + transport_config.max_concurrent_bidi_streams(1_u8.into()); + transport_config.max_concurrent_uni_streams(1_u8.into()); + + let (server_config, client_config) = config_pair(Some(transport_config)); + + let server = Endpoint::server(args.server_addr, server_config) + .await + .unwrap(); + let client = Endpoint::client(args.client_addr).await.unwrap(); + + join!( + async { + let conn = server.wait_incoming().await.unwrap().await.unwrap(); + + while let Ok(stream) = conn.accept_bi().await { + compio_runtime::spawn(echo(stream)).detach(); + } + }, + async { + let conn = client + .connect( + server.local_addr().unwrap(), + "localhost", + Some(client_config), + ) + .unwrap() + .await + .unwrap(); + + for _ in 0..args.nr_streams { + let (mut send, mut recv) = conn.open_bi_wait().await.unwrap(); + let msg = gen_data(args.stream_size); + + let (_, data) = join!( + async { + send.write_all(&msg).await.unwrap(); + send.finish().unwrap(); + }, + async { + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + buf + } + ); + + assert_eq!(data, msg); + } + } + ); +} + +#[compio_macros::test] +async fn echo_v6() { + let _guard = subscribe(); + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), + nr_streams: 1, + stream_size: 10 * 1024, + receive_window: None, + stream_receive_window: None, + }) + .await; +} + +#[compio_macros::test] +async fn echo_v4() { + let _guard = subscribe(); + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 1, + stream_size: 10 * 1024, + receive_window: None, + stream_receive_window: None, + }) + .await; +} + +#[compio_macros::test] +#[cfg_attr(any(target_os = "openbsd", target_os = "netbsd", windows), ignore)] +async fn echo_dualstack() { + let _guard = subscribe(); + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 1, + stream_size: 10 * 1024, + receive_window: None, + stream_receive_window: None, + }) + .await; +} + +#[compio_macros::test] +async fn stress_receive_window() { + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 50, + stream_size: 25 * 1024 + 11, + receive_window: Some(37), + stream_receive_window: Some(100 * 1024 * 1024), + }) + .await; +} + +#[compio_macros::test] +async fn stress_stream_receive_window() { + // Note that there is no point in running this with too many streams, + // since the window is only active within a stream. + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 2, + stream_size: 250 * 1024 + 11, + receive_window: Some(100 * 1024 * 1024), + stream_receive_window: Some(37), + }) + .await; +} + +#[compio_macros::test] +async fn stress_both_windows() { + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 50, + stream_size: 25 * 1024 + 11, + receive_window: Some(37), + stream_receive_window: Some(37), + }) + .await; +} diff --git a/compio-tls/Cargo.toml b/compio-tls/Cargo.toml index 34bb7d90..ef70c561 100644 --- a/compio-tls/Cargo.toml +++ b/compio-tls/Cargo.toml @@ -19,7 +19,7 @@ compio-buf = { workspace = true } compio-io = { workspace = true, features = ["compat"] } native-tls = { version = "0.2.11", optional = true } -rustls = { version = "0.23.1", default-features = false, optional = true, features = [ +rustls = { workspace = true, default-features = false, optional = true, features = [ "logging", "std", "tls12", @@ -30,7 +30,7 @@ compio-net = { workspace = true } compio-runtime = { workspace = true } compio-macros = { workspace = true } -rustls = { version = "0.23.1", default-features = false, features = ["ring"] } +rustls = { workspace = true, default-features = false, features = ["ring"] } rustls-native-certs = "0.7.0" [features] diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 692d7ac1..921e1692 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -42,6 +42,7 @@ compio-dispatcher = { workspace = true, optional = true } compio-log = { workspace = true } compio-tls = { workspace = true, optional = true } compio-process = { workspace = true, optional = true } +compio-quic = { workspace = true, optional = true } # Shared dev dependencies for all platforms [dev-dependencies] @@ -52,7 +53,7 @@ compio-macros = { workspace = true } criterion = { workspace = true, features = ["async_tokio"] } futures-channel = { workspace = true } futures-util = { workspace = true } -rand = "0.8.5" +rand = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = [ "fs", @@ -83,7 +84,7 @@ io-uring = [ ] polling = ["compio-driver/polling"] io = ["dep:compio-io"] -io-compat = ["io", "compio-io/compat"] +io-compat = ["io", "compio-io/compat", "compio-quic?/io-compat"] runtime = ["dep:compio-runtime", "dep:compio-fs", "dep:compio-net", "io"] macros = ["dep:compio-macros", "runtime"] event = ["compio-runtime/event", "runtime"] @@ -94,6 +95,8 @@ tls = ["dep:compio-tls"] native-tls = ["tls", "compio-tls/native-tls"] rustls = ["tls", "compio-tls/rustls"] process = ["dep:compio-process"] +quic = ["dep:compio-quic"] +h3 = ["quic", "compio-quic/h3"] all = [ "time", "macros", @@ -102,6 +105,7 @@ all = [ "native-tls", "rustls", "process", + "quic", ] arrayvec = ["compio-buf/arrayvec"] diff --git a/compio/src/lib.rs b/compio/src/lib.rs index 244d8b37..8b6c5c09 100644 --- a/compio/src/lib.rs +++ b/compio/src/lib.rs @@ -41,6 +41,9 @@ pub use compio_macros::*; #[cfg(feature = "process")] #[doc(inline)] pub use compio_process as process; +#[cfg(feature = "quic")] +#[doc(inline)] +pub use compio_quic as quic; #[cfg(feature = "signal")] #[doc(inline)] pub use compio_signal as signal;