From 89ef43e60f8272696a58abd30676059ae5821ef5 Mon Sep 17 00:00:00 2001 From: Caio Date: Fri, 6 Oct 2023 15:05:04 -0300 Subject: [PATCH] Make trait return `Send` --- .scripts/internal-tests.sh | 1 - Cargo.lock | 280 +---------------- README.md | 3 +- wtx/Cargo.toml | 7 +- wtx/src/error.rs | 11 - wtx/src/lib.rs | 4 +- wtx/src/stream.rs | 459 +++++++++++++++++++++------- wtx/src/tests.rs | 28 ++ wtx/src/web_socket/handshake.rs | 12 +- wtx/src/web_socket/handshake/raw.rs | 181 +++++------ 10 files changed, 493 insertions(+), 493 deletions(-) create mode 100644 wtx/src/tests.rs diff --git a/.scripts/internal-tests.sh b/.scripts/internal-tests.sh index 9346e234..3631b397 100755 --- a/.scripts/internal-tests.sh +++ b/.scripts/internal-tests.sh @@ -19,7 +19,6 @@ $rt test-generic wtx $rt test-with-features wtx arbitrary $rt test-with-features wtx async-std $rt test-with-features wtx base64 -$rt test-with-features wtx embassy-net,_hack $rt test-with-features wtx flate2 $rt test-with-features wtx futures-lite $rt test-with-features wtx glommio diff --git a/Cargo.lock b/Cargo.lock index 181584bb..baac9427 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,27 +52,6 @@ dependencies = [ "derive_arbitrary", ] -[[package]] -name = "as-slice" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45403b49e3954a4b8428a0ac21a4b7afadccf92bfd96273f1a58cd4812496ae0" -dependencies = [ - "generic-array 0.12.4", - "generic-array 0.13.3", - "generic-array 0.14.7", - "stable_deref_trait", -] - -[[package]] -name = "as-slice" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516" -dependencies = [ - "stable_deref_trait", -] - [[package]] name = "async-channel" version = "1.9.0" @@ -234,33 +213,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9441c6b2fe128a7c2bf680a44c34d0df31ce09e5b7e401fcca3faa483dbc921" [[package]] -name = "atomic-polyfill" -version = "0.1.11" +name = "async-trait" +version = "0.1.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ff7eb3f316534d83a8a2c3d1674ace8a5a71198eba31e2e2b597833f699b28" +checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" dependencies = [ - "critical-section", -] - -[[package]] -name = "atomic-polyfill" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" -dependencies = [ - "critical-section", -] - -[[package]] -name = "atomic-pool" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58c5fc22e05ec2884db458bf307dc7b278c9428888d2b6e6fad9c0ae7804f5f6" -dependencies = [ - "as-slice 0.1.5", - "as-slice 0.2.1", - "atomic-polyfill 1.0.3", - "stable_deref_trait", + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -331,7 +291,7 @@ version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ - "generic-array 0.14.7", + "generic-array", ] [[package]] @@ -479,12 +439,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "critical-section" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216" - [[package]] name = "crossbeam" version = "0.8.2" @@ -558,7 +512,7 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ - "generic-array 0.14.7", + "generic-array", "typenum", ] @@ -583,75 +537,6 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "embassy-net" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d812646c1c50452d77293e05bf2eca13c23803447cdf9bbbcf820588f4c9879" -dependencies = [ - "as-slice 0.2.1", - "atomic-polyfill 1.0.3", - "atomic-pool", - "embassy-net-driver", - "embassy-sync", - "embassy-time", - "futures", - "generic-array 0.14.7", - "heapless", - "managed", - "smoltcp", - "stable_deref_trait", -] - -[[package]] -name = "embassy-net-driver" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6a4985a5dab4cb55d09703bfdd7d74f58c12c6e889fd3cbb40ea40462a976e" - -[[package]] -name = "embassy-sync" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0dad296a6f70bfdc32ef52442a31f98c28e1608893c1cecc9b6f419bab005a0" -dependencies = [ - "cfg-if", - "critical-section", - "embedded-io", - "futures-util", - "heapless", -] - -[[package]] -name = "embassy-time" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94a9532d29ec2949c49079d85fd77422e3f4cdc133a71674743cd3b8d5a56a20" -dependencies = [ - "atomic-polyfill 1.0.3", - "cfg-if", - "critical-section", - "embedded-hal", - "futures-util", - "heapless", -] - -[[package]] -name = "embedded-hal" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35949884794ad573cf46071e41c9b60efb0cb311e3ca01f7af807af1debc66ff" -dependencies = [ - "nb 0.1.3", - "void", -] - -[[package]] -name = "embedded-io" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" - [[package]] name = "enclose" version = "1.1.8" @@ -735,20 +620,6 @@ dependencies = [ "spin 0.9.8", ] -[[package]] -name = "futures" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.28" @@ -756,7 +627,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" dependencies = [ "futures-core", - "futures-sink", ] [[package]] @@ -786,43 +656,12 @@ dependencies = [ "waker-fn", ] -[[package]] -name = "futures-macro" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "futures-sink" version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" -[[package]] -name = "futures-task" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" - -[[package]] -name = "futures-util" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" -dependencies = [ - "futures-core", - "futures-macro", - "futures-sink", - "futures-task", - "pin-project-lite", - "pin-utils", -] - [[package]] name = "fxhash" version = "0.2.1" @@ -832,24 +671,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "generic-array" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffdf9f34f1447443d37393cc6c2b8313aebddcd96906caf34e54c68d8e57d7bd" -dependencies = [ - "typenum", -] - -[[package]] -name = "generic-array" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f797e67af32588215eaaab8327027ee8e71b9dd0b2b26996aedf20c030fce309" -dependencies = [ - "typenum", -] - [[package]] name = "generic-array" version = "0.14.7" @@ -926,28 +747,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "hash32" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" -dependencies = [ - "byteorder", -] - -[[package]] -name = "heapless" -version = "0.7.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db04bc24a18b9ea980628ecf00e6c0264f3c1426dac36c00cb49b6fbad8b0743" -dependencies = [ - "atomic-polyfill 0.1.11", - "hash32", - "rustc_version", - "spin 0.9.8", - "stable_deref_trait", -] - [[package]] name = "heck" version = "0.4.1" @@ -1115,12 +914,6 @@ dependencies = [ "value-bag", ] -[[package]] -name = "managed" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" - [[package]] name = "matchers" version = "0.1.0" @@ -1198,21 +991,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "nb" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "801d31da0513b6ec5214e9bf433a77966320625a37860f910be265be6e18d06f" -dependencies = [ - "nb 1.1.0", -] - -[[package]] -name = "nb" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d5439c4ad607c3c23abf66de8c8bf57ba8adcd1f129e699851a6e43935d339d" - [[package]] name = "nix" version = "0.23.2" @@ -1482,15 +1260,6 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - [[package]] name = "rustix" version = "0.37.23" @@ -1570,12 +1339,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "semver" -version = "1.0.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad977052201c6de01a8ef2aa3378c4bd23217a056337d1d6da40468d267a4fb0" - [[package]] name = "sha1" version = "0.10.6" @@ -1659,19 +1422,6 @@ dependencies = [ "futures-lite", ] -[[package]] -name = "smoltcp" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d2e3a36ac8fea7b94e666dfa3871063d6e0a5c9d5d4fec9a1a6b7b6760f0229" -dependencies = [ - "bitflags 1.3.2", - "byteorder", - "cfg-if", - "heapless", - "managed", -] - [[package]] name = "socket2" version = "0.4.9" @@ -1707,12 +1457,6 @@ dependencies = [ "lock_api", ] -[[package]] -name = "stable_deref_trait" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" - [[package]] name = "syn" version = "2.0.37" @@ -1881,12 +1625,6 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" -[[package]] -name = "void" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" - [[package]] name = "waker-fn" version = "1.1.1" @@ -2075,8 +1813,8 @@ version = "0.8.0" dependencies = [ "arbitrary", "async-std", + "async-trait", "base64", - "embassy-net", "flate2", "futures-lite", "glommio", diff --git a/README.md b/README.md index d7639d52..0ec76f12 100644 --- a/README.md +++ b/README.md @@ -64,5 +64,4 @@ There are mainly 2 things that impact performance, the chosen runtime and the nu If you disagree with any of the above numbers, feel free to checkout `wtx-bench` to point any misunderstandings or misconfigurations. A more insightful analysis is available at https://c410-f3r.github.io/thoughts/the-fastest-websocket-implementation/. -¹ `monoio` and `tokio-uring` are slower because owned vectors need to be created on each read/write operation.
-² `embassy` is supported but there isn't a `std` example for measurement purposes. \ No newline at end of file +¹ `monoio` and `tokio-uring` are slower because owned vectors need to be created on each read/write operation.
\ No newline at end of file diff --git a/wtx/Cargo.toml b/wtx/Cargo.toml index 11d712ab..6ce490cb 100644 --- a/wtx/Cargo.toml +++ b/wtx/Cargo.toml @@ -46,7 +46,6 @@ required-features = ["rustls-pemfile", "tokio-rustls", "web-socket-handshake"] arbitrary = { default-features = false, features = ["derive_arbitrary"], optional = true, version = "1.0" } async-std = { default-features = false, features = ["default"], optional = true, version = "1.0" } base64 = { default-features = false, features = ["alloc"], optional = true, version = "0.21" } -embassy-net = { default-features = false, features = ["tcp"], optional = true, version = "0.1" } flate2 = { default-features = false, features = ["zlib-ng"], optional = true, version = "1.0" } futures-lite = { default-features = false, optional = true, version = "1.0" } glommio = { default-features = false, optional = true, version = "0.8" } @@ -64,6 +63,7 @@ tracing = { default-features = false, features = ["attributes"], optional = true webpki-roots = { default-features = false, optional = true, version = "0.25" } [dev-dependencies] +async-trait = { default-features = false, version = "0.1" } tokio = { default-features = false, features = ["macros", "rt", "time"], version = "1.0" } tracing-subscriber = { default-features = false, features = ["env-filter", "fmt"], version = "0.3" } tracing-tree = { default-features = false, version = "0.2" } @@ -79,11 +79,6 @@ tokio = ["std", "dep:tokio"] tokio-rustls = ["tokio", "dep:tokio-rustls"] web-socket-handshake = ["base64", "httparse", "sha1"] -# Dependencies that don't support `no-default-features`. -# -# Used internally to avoid unnecessary poll of unused dependencies or unused features for downstream. -_hack = ["embassy-net/medium-ethernet", "embassy-net/proto-ipv4"] - [package] authors = ["Caio Fernandes "] categories = ["asynchronous", "data-structures", "network-programming", "no-std", "web-programming"] diff --git a/wtx/src/error.rs b/wtx/src/error.rs index b2d5ba59..97b0e71b 100644 --- a/wtx/src/error.rs +++ b/wtx/src/error.rs @@ -42,9 +42,6 @@ pub enum Error { // External // - #[cfg(feature = "embassy-net")] - /// See [embassy_net::tcp::Error]. - EmbassyNetTcp(embassy_net::tcp::Error), #[cfg(feature = "flate2")] /// See [flate2::CompressError]. Flate2CompressError(flate2::CompressError), @@ -81,14 +78,6 @@ impl Display for Error { #[cfg(feature = "std")] impl std::error::Error for Error {} -#[cfg(feature = "embassy-net")] -impl From for Error { - #[inline] - fn from(from: embassy_net::tcp::Error) -> Self { - Self::EmbassyNetTcp(from) - } -} - #[cfg(feature = "flate2")] impl From for Error { #[inline] diff --git a/wtx/src/lib.rs b/wtx/src/lib.rs index 1c529271..05c9e43d 100644 --- a/wtx/src/lib.rs +++ b/wtx/src/lib.rs @@ -1,6 +1,6 @@ #![cfg_attr(not(feature = "std"), no_std)] #![doc = include_str!("../README.md")] -#![feature(array_chunks, async_fn_in_trait, impl_trait_projections)] +#![feature(array_chunks, async_fn_in_trait, impl_trait_in_assoc_type)] extern crate alloc; @@ -18,6 +18,8 @@ pub mod rng; #[cfg(feature = "tracing")] mod role; mod stream; +#[cfg(test)] +mod tests; pub mod web_socket; pub use cache::Cache; diff --git a/wtx/src/stream.rs b/wtx/src/stream.rs index d60feeca..8e00b2f1 100644 --- a/wtx/src/stream.rs +++ b/wtx/src/stream.rs @@ -1,28 +1,94 @@ use alloc::vec::Vec; -use core::cmp::Ordering; +use core::{ + cmp::Ordering, + future::{ready, Future, Ready}, +}; /// A stream of values produced asynchronously. pub trait Stream { + /// Future of `read` method + type Read<'read>: Future> + 'read + where + Self: 'read; + /// Future of `write` method + type Write<'write>: Future> + 'write + where + Self: 'write; + /// Pulls some bytes from this source into the specified buffer, returning how many bytes /// were read. - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result; + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut; /// Attempts to write all elements of `bytes`. - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()>; + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut; +} + +impl Stream for () { + type Read<'read> = Ready> + where + Self: 'read; + type Write<'write> = Ready> + where + Self: 'write; + + #[inline] + fn read<'bytes, 'fut, 'this>(&'this mut self, _: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + ready(Ok(0)) + } + + #[inline] + fn write_all<'bytes, 'fut, 'this>(&'this mut self, _: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + ready(Ok(())) + } } impl Stream for &mut T where T: Stream, { + type Read<'read> = T::Read<'read> + where + Self: 'read; + type Write<'write> = T::Write<'write> + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - (*self).read(bytes).await + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + (*self).read(bytes) } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - (*self).write_all(bytes).await + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + (*self).write_all(bytes) } } @@ -43,45 +109,54 @@ impl BytesStream { } impl Stream for BytesStream { - #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - let working_buffer = self.buffer.get(self.idx..).unwrap_or_default(); - let working_buffer_len = working_buffer.len(); - Ok(match working_buffer_len.cmp(&bytes.len()) { - Ordering::Less => { - bytes.get_mut(..working_buffer_len).unwrap_or_default().copy_from_slice(working_buffer); - self.clear(); - working_buffer_len - } - Ordering::Equal => { - bytes.copy_from_slice(working_buffer); - self.clear(); - working_buffer_len - } - Ordering::Greater => { - bytes.copy_from_slice(working_buffer.get(..bytes.len()).unwrap_or_default()); - self.idx = self.idx.wrapping_add(bytes.len()); - bytes.len() - } - }) - } - - #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - self.buffer.extend_from_slice(bytes); - Ok(()) - } -} + type Read<'read> = Ready> + where + Self: 'read; + type Write<'write> = Ready> + where + Self: 'write; -impl Stream for () { #[inline] - async fn read(&mut self, _: &mut [u8]) -> crate::Result { - Ok(0) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + ready({ + let working_buffer = self.buffer.get(self.idx..).unwrap_or_default(); + let working_buffer_len = working_buffer.len(); + Ok(match working_buffer_len.cmp(&bytes.len()) { + Ordering::Less => { + bytes.get_mut(..working_buffer_len).unwrap_or_default().copy_from_slice(working_buffer); + self.clear(); + working_buffer_len + } + Ordering::Equal => { + bytes.copy_from_slice(working_buffer); + self.clear(); + working_buffer_len + } + Ordering::Greater => { + bytes.copy_from_slice(working_buffer.get(..bytes.len()).unwrap_or_default()); + self.idx = self.idx.wrapping_add(bytes.len()); + bytes.len() + } + }) + }) } #[inline] - async fn write_all(&mut self, _: &[u8]) -> crate::Result<()> { - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + ready({ + self.buffer.extend_from_slice(bytes); + Ok(()) + }) } } @@ -92,36 +167,37 @@ mod async_std { io::{ReadExt, WriteExt}, net::TcpStream, }; + use core::future::Future; impl Stream for TcpStream { - #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes).await?) - } + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes).await?; - Ok(()) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { Ok(::read(self, bytes).await?) } } - } -} - -#[cfg(feature = "embassy-net")] -mod embassy { - use crate::Stream; - use embassy_net::tcp::TcpSocket; - impl<'any> Stream for TcpSocket<'any> { #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok((*self).read(bytes).await?) - } - - #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - (*self).write(bytes).await?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + ::write_all(self, bytes).await?; + Ok(()) + } } } } @@ -129,19 +205,39 @@ mod embassy { #[cfg(feature = "glommio")] mod glommio { use crate::Stream; + use core::future::Future; use futures_lite::io::{AsyncReadExt, AsyncWriteExt}; use glommio::net::TcpStream; impl Stream for TcpStream { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes).await?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { Ok(::read(self, bytes).await?) } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes).await?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + ::write_all(self, bytes).await?; + Ok(()) + } } } } @@ -149,24 +245,46 @@ mod glommio { #[cfg(feature = "monoio")] mod monoio { use crate::Stream; + use core::future::Future; use monoio::{ io::{AsyncReadRent, AsyncWriteRentExt}, net::TcpStream, }; impl Stream for TcpStream { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - let (rslt, read) = AsyncReadRent::read(self, bytes.to_vec()).await; - bytes.get_mut(..read.len()).unwrap_or_default().copy_from_slice(&read); - Ok(rslt?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + let (rslt, read) = AsyncReadRent::read(self, bytes.to_vec()).await; + bytes.get_mut(..read.len()).unwrap_or_default().copy_from_slice(&read); + Ok(rslt?) + } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - let (rslt, _) = AsyncWriteRentExt::write_all(self, bytes.to_vec()).await; - rslt?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + let (rslt, _) = AsyncWriteRentExt::write_all(self, bytes.to_vec()).await; + rslt?; + Ok(()) + } } } } @@ -174,21 +292,41 @@ mod monoio { #[cfg(feature = "smol")] mod smol { use crate::Stream; + use core::future::Future; use smol::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, }; impl Stream for TcpStream { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes).await?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { Ok(::read(self, bytes).await?) } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes).await?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + ::write_all(self, bytes).await?; + Ok(()) + } } } } @@ -196,21 +334,41 @@ mod smol { #[cfg(feature = "std")] mod std { use crate::Stream; + use core::future::Future; use std::{ io::{Read, Write}, net::TcpStream, }; impl Stream for TcpStream { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes)?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { Ok(::read(self, bytes)?) } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes)?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + ::write_all(self, bytes)?; + Ok(()) + } } } } @@ -218,21 +376,41 @@ mod std { #[cfg(feature = "tokio")] mod tokio { use crate::Stream; + use core::future::Future; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, }; impl Stream for TcpStream { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes).await?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { Ok(::read(self, bytes).await?) } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes).await?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + ::write_all(self, bytes).await?; + Ok(()) + } } } } @@ -240,21 +418,43 @@ mod tokio { #[cfg(feature = "tokio-uring")] mod tokio_uring { use crate::Stream; + use core::future::Future; use tokio_uring::net::TcpStream; impl Stream for TcpStream { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - let (rslt, read) = TcpStream::read(self, bytes.to_vec()).await; - bytes.get_mut(..read.len()).unwrap_or_default().copy_from_slice(&read); - Ok(rslt?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + let (rslt, read) = TcpStream::read(self, bytes.to_vec()).await; + bytes.get_mut(..read.len()).unwrap_or_default().copy_from_slice(&read); + Ok(rslt?) + } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - let (rslt, _) = TcpStream::write_all(self, bytes.to_vec()).await; - rslt?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + let (rslt, _) = TcpStream::write_all(self, bytes.to_vec()).await; + rslt?; + Ok(()) + } } } } @@ -262,21 +462,41 @@ mod tokio_uring { #[cfg(feature = "tokio-rustls")] mod tokio_rustls { use crate::Stream; + use core::future::Future; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; impl Stream for tokio_rustls::client::TlsStream where T: AsyncRead + AsyncWrite + Unpin, { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes).await?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { Ok(::read(self, bytes).await?) } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes).await?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + ::write_all(self, bytes).await?; + Ok(()) + } } } @@ -284,15 +504,34 @@ mod tokio_rustls { where T: AsyncRead + AsyncWrite + Unpin, { + type Read<'read> = impl Future> + 'read + where + Self: 'read; + type Write<'write> = impl Future> + 'write + where + Self: 'write; + #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes).await?) + fn read<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes mut [u8]) -> Self::Read<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { Ok(::read(self, bytes).await?) } } #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes).await?; - Ok(()) + fn write_all<'bytes, 'fut, 'this>(&'this mut self, bytes: &'bytes [u8]) -> Self::Write<'fut> + where + 'bytes: 'fut, + 'this: 'fut, + Self: 'fut, + { + async { + ::write_all(self, bytes).await?; + Ok(()) + } } } } diff --git a/wtx/src/tests.rs b/wtx/src/tests.rs new file mode 100644 index 00000000..560978d8 --- /dev/null +++ b/wtx/src/tests.rs @@ -0,0 +1,28 @@ +use core::borrow::BorrowMut; + +use crate::{ + rng::Rng, + web_socket::{compression::NegotiatedCompression, FrameBufferVec, WebSocketClient}, + PartitionedBuffer, Stream, +}; + +#[async_trait::async_trait] +pub trait AsyncTrait { + async fn method(&mut self) -> crate::Result<()>; +} + +#[async_trait::async_trait] +impl AsyncTrait for (&mut FrameBufferVec, &mut WebSocketClient) +where + NC: NegotiatedCompression + Send + Sync, + PB: BorrowMut + Send + Sync, + RNG: Rng + Send + Sync, + S: Stream + Send + Sync, + for<'read> S::Read<'read>: Send + Sync, + for<'write> S::Write<'write>: Send + Sync, +{ + async fn method(&mut self) -> crate::Result<()> { + let _ = self.1.borrow_mut().read_frame(self.0).await?; + Ok(()) + } +} diff --git a/wtx/src/web_socket/handshake.rs b/wtx/src/web_socket/handshake.rs index 939c0514..5e00022b 100644 --- a/wtx/src/web_socket/handshake.rs +++ b/wtx/src/web_socket/handshake.rs @@ -12,24 +12,28 @@ pub use raw::{WebSocketAcceptRaw, WebSocketConnectRaw}; /// Reads external data to figure out if incoming requests can be accepted as WebSocket connections. pub trait WebSocketAccept { + /// Future of the `accept` method. + type Accept: Future>>; /// Specific implementation stream. type Stream: Stream; /// Reads external data to figure out if incoming requests can be accepted as WebSocket connections. - async fn accept(self) -> crate::Result>; + fn accept(self) -> Self::Accept; } /// Initial negotiation sent by a client to start a WebSocket connection. pub trait WebSocketConnect { + /// Future of the `accept` method. + type Connect: Future< + Output = crate::Result<(Self::Response, WebSocketClient)>, + >; /// Specific implementation response. type Response; /// Specific implementation stream. type Stream: Stream; /// Initial negotiation sent by a client to start a WebSocket connection. - async fn connect( - self, - ) -> crate::Result<(Self::Response, WebSocketClient)>; + fn connect(self) -> Self::Connect; } /// Manages the upgrade of already established requests into WebSocket connections. diff --git a/wtx/src/web_socket/handshake/raw.rs b/wtx/src/web_socket/handshake/raw.rs index 70f11232..208b1ced 100644 --- a/wtx/src/web_socket/handshake/raw.rs +++ b/wtx/src/web_socket/handshake/raw.rs @@ -53,7 +53,7 @@ mod httparse_impls { }, ExpectedHeader, PartitionedBuffer, Stream, UriParts, }; - use core::{borrow::BorrowMut, str}; + use core::{borrow::BorrowMut, future::Future, str}; use httparse::{Header, Request, Response, Status, EMPTY_HEADER}; const MAX_READ_LEN: usize = 2 * 1024; @@ -66,57 +66,59 @@ mod httparse_impls { RNG: Rng, S: Stream, { + type Accept = + impl Future>>; type Stream = S; #[inline] - async fn accept( - mut self, - ) -> crate::Result> { - let pb = self.pb.borrow_mut(); - pb._set_indices_through_expansion(0, 0, MAX_READ_LEN); - let mut read = 0; - loop { - let read_buffer = pb._following_mut().get_mut(read..).unwrap_or_default(); - let local_read = self.stream.read(read_buffer).await?; - if local_read == 0 { - return Err(crate::Error::UnexpectedEOF); - } - read = read.wrapping_add(local_read); - let mut req_buffer = [EMPTY_HEADER; MAX_READ_HEADER_LEN]; - let mut req = Request::new(&mut req_buffer); - match req.parse(pb._following())? { - Status::Complete(_) => { - if !_trim(req.method()).eq_ignore_ascii_case(b"get") { - return Err(crate::Error::UnexpectedHttpMethod); + fn accept(mut self) -> Self::Accept { + async { + let pb = self.pb.borrow_mut(); + pb._set_indices_through_expansion(0, 0, MAX_READ_LEN); + let mut read = 0; + loop { + let read_buffer = pb._following_mut().get_mut(read..).unwrap_or_default(); + let local_read = self.stream.read(read_buffer).await?; + if local_read == 0 { + return Err(crate::Error::UnexpectedEOF); + } + read = read.wrapping_add(local_read); + let mut req_buffer = [EMPTY_HEADER; MAX_READ_HEADER_LEN]; + let mut req = Request::new(&mut req_buffer); + match req.parse(pb._following())? { + Status::Complete(_) => { + if !_trim(req.method()).eq_ignore_ascii_case(b"get") { + return Err(crate::Error::UnexpectedHttpMethod); + } + verify_common_header(req.headers)?; + if !has_header_key_and_value(req.headers, b"sec-websocket-version", b"13") { + return Err(crate::Error::MissingHeader { + expected: ExpectedHeader::SecWebSocketVersion_13, + }); + }; + let Some(key) = req.headers.iter().find_map(|el| { + (el.name().eq_ignore_ascii_case(b"sec-websocket-key")).then_some(el.value()) + }) else { + return Err(crate::Error::MissingHeader { + expected: ExpectedHeader::SecWebSocketKey, + }); + }; + let compression = self.compression.negotiate(req.headers.iter())?; + let swa = derived_key(self.key_buffer, key); + let mut headers_buffer = HeadersBuffer::<_, 3>::default(); + headers_buffer.headers[0] = Header { name: "Connection", value: b"Upgrade" }; + headers_buffer.headers[1] = Header { name: "Sec-WebSocket-Accept", value: swa }; + headers_buffer.headers[2] = Header { name: "Upgrade", value: b"websocket" }; + let mut res = Response::new(&mut headers_buffer.headers); + res.code = Some(101); + res.version = Some(req.version().into()); + let res_bytes = build_res(&compression, res.headers, pb); + self.stream.write_all(res_bytes).await?; + pb.clear(); + return Ok(WebSocketServer::new(compression, self.pb, self.rng, self.stream)); } - verify_common_header(req.headers)?; - if !has_header_key_and_value(req.headers, b"sec-websocket-version", b"13") { - return Err(crate::Error::MissingHeader { - expected: ExpectedHeader::SecWebSocketVersion_13, - }); - }; - let Some(key) = req.headers.iter().find_map(|el| { - (el.name().eq_ignore_ascii_case(b"sec-websocket-key")).then_some(el.value()) - }) else { - return Err(crate::Error::MissingHeader { - expected: ExpectedHeader::SecWebSocketKey, - }); - }; - let compression = self.compression.negotiate(req.headers.iter())?; - let swa = derived_key(self.key_buffer, key); - let mut headers_buffer = HeadersBuffer::<_, 3>::default(); - headers_buffer.headers[0] = Header { name: "Connection", value: b"Upgrade" }; - headers_buffer.headers[1] = Header { name: "Sec-WebSocket-Accept", value: swa }; - headers_buffer.headers[2] = Header { name: "Upgrade", value: b"websocket" }; - let mut res = Response::new(&mut headers_buffer.headers); - res.code = Some(101); - res.version = Some(req.version().into()); - let res_bytes = build_res(&compression, res.headers, pb); - self.stream.write_all(res_bytes).await?; - pb.clear(); - return Ok(WebSocketServer::new(compression, self.pb, self.rng, self.stream)); + Status::Partial => {} } - Status::Partial => {} } } } @@ -132,53 +134,58 @@ mod httparse_impls { S: Stream, 'fb: 'hb, { + type Connect = impl Future< + Output = crate::Result<( + Self::Response, + WebSocketClient, + )>, + >; type Response = Response<'hb, 'fb>; type Stream = S; #[inline] - async fn connect( - mut self, - ) -> crate::Result<(Self::Response, WebSocketClient)> - { - let key_buffer = &mut <_>::default(); - let pb = self.pb.borrow_mut(); - pb.clear(); - let (key, req) = build_req(&self.compression, key_buffer, pb, &mut self.rng, self.uri); - self.stream.write_all(req).await?; - let mut read = 0; - self.fb._set_indices_through_expansion(0, 0, MAX_READ_LEN); - let len = loop { - let mut local_header = [EMPTY_HEADER; MAX_READ_HEADER_LEN]; - let read_buffer = self.fb.payload_mut().get_mut(read..).unwrap_or_default(); - let local_read = self.stream.read(read_buffer).await?; - if local_read == 0 { - return Err(crate::Error::UnexpectedEOF); + fn connect(mut self) -> Self::Connect { + async { + let key_buffer = &mut <_>::default(); + let pb = self.pb.borrow_mut(); + pb.clear(); + let (key, req) = build_req(&self.compression, key_buffer, pb, &mut self.rng, self.uri); + self.stream.write_all(req).await?; + let mut read = 0; + self.fb._set_indices_through_expansion(0, 0, MAX_READ_LEN); + let len = loop { + let mut local_header = [EMPTY_HEADER; MAX_READ_HEADER_LEN]; + let read_buffer = self.fb.payload_mut().get_mut(read..).unwrap_or_default(); + let local_read = self.stream.read(read_buffer).await?; + if local_read == 0 { + return Err(crate::Error::UnexpectedEOF); + } + read = read.wrapping_add(local_read); + match Response::new(&mut local_header).parse(self.fb.payload())? { + Status::Complete(len) => break len, + Status::Partial => {} + } + }; + let mut res = Response::new(&mut self.headers_buffer.headers); + let _status = res.parse(self.fb.payload())?; + if res.code != Some(101) { + return Err(WebSocketError::MissingSwitchingProtocols.into()); } - read = read.wrapping_add(local_read); - match Response::new(&mut local_header).parse(self.fb.payload())? { - Status::Complete(len) => break len, - Status::Partial => {} + verify_common_header(res.headers)?; + if !has_header_key_and_value( + res.headers, + b"sec-websocket-accept", + derived_key(&mut <_>::default(), key), + ) { + return Err(crate::Error::MissingHeader { + expected: crate::ExpectedHeader::SecWebSocketKey, + }); } - }; - let mut res = Response::new(&mut self.headers_buffer.headers); - let _status = res.parse(self.fb.payload())?; - if res.code != Some(101) { - return Err(WebSocketError::MissingSwitchingProtocols.into()); - } - verify_common_header(res.headers)?; - if !has_header_key_and_value( - res.headers, - b"sec-websocket-accept", - derived_key(&mut <_>::default(), key), - ) { - return Err(crate::Error::MissingHeader { - expected: crate::ExpectedHeader::SecWebSocketKey, - }); + let compression = self.compression.negotiate(res.headers.iter())?; + pb.borrow_mut()._set_indices_through_expansion(0, 0, read.wrapping_sub(len)); + pb._following_mut().copy_from_slice(self.fb.payload().get(len..read).unwrap_or_default()); + Ok((res, WebSocketClient::new(compression, self.pb, self.rng, self.stream))) } - let compression = self.compression.negotiate(res.headers.iter())?; - pb.borrow_mut()._set_indices_through_expansion(0, 0, read.wrapping_sub(len)); - pb._following_mut().copy_from_slice(self.fb.payload().get(len..read).unwrap_or_default()); - Ok((res, WebSocketClient::new(compression, self.pb, self.rng, self.stream))) } }