diff --git a/Cargo.lock b/Cargo.lock index a583dc99..3bb9a4fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -652,7 +652,6 @@ name = "engineioxide" version = "0.17.1" dependencies = [ "axum", - "base64 0.22.1", "bytes", "criterion", "engineioxide-core", @@ -663,8 +662,6 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", - "itoa", - "memchr", "pin-project-lite", "serde", "serde_json", @@ -678,7 +675,31 @@ dependencies = [ "tower-service", "tracing", "tracing-subscriber", - "unicode-segmentation", +] + +[[package]] +name = "engineioxide-client" +version = "0.17.0" +dependencies = [ + "bytes", + "engineioxide", + "engineioxide-core", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "serde", + "serde_json", + "smallvec", + "thiserror 2.0.15", + "tokio", + "tokio-tungstenite", + "tracing", + "tracing-subscriber", ] [[package]] @@ -687,8 +708,20 @@ version = "0.2.0" dependencies = [ "base64 0.22.1", "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "memchr", + "pin-project-lite", "rand 0.9.1", "serde", + "serde_json", + "smallvec", + "tokio", + "tracing", + "unicode-segmentation", ] [[package]] @@ -1044,6 +1077,7 @@ dependencies = [ "pin-utils", "smallvec", "tokio", + "want", ] [[package]] @@ -2694,6 +2728,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "tungstenite" version = "0.26.2" @@ -2830,6 +2870,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/crates/engineioxide-client/Cargo.toml b/crates/engineioxide-client/Cargo.toml new file mode 100644 index 00000000..2d7b7fcb --- /dev/null +++ b/crates/engineioxide-client/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "engineioxide-client" +description = "Engine IO client implementation in rust" +version = "0.17.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +engineioxide-core = { path = "../engineioxide-core", version = "0.2" } +bytes.workspace = true +futures-core.workspace = true +futures-util.workspace = true +http.workspace = true +http-body.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +tokio = { workspace = true, features = ["rt", "time"] } +hyper = { workspace = true, features = ["client", "http1"] } +tokio-tungstenite.workspace = true +http-body-util.workspace = true +pin-project-lite.workspace = true +smallvec.workspace = true +hyper-util = { workspace = true, features = ["tokio"] } + +# Tracing +tracing = { workspace = true, optional = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "parking_lot"] } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +engineioxide = { path = "../engineioxide", features = ["tracing", "v3"] } + +[features] +v3 = ["engineioxide-core/v3"] +tracing = ["dep:tracing", "engineioxide-core/tracing"] +__test_harness = [] diff --git a/crates/engineioxide-client/README.md b/crates/engineioxide-client/README.md new file mode 100644 index 00000000..e69de29b diff --git a/crates/engineioxide-client/src/client.rs b/crates/engineioxide-client/src/client.rs new file mode 100644 index 00000000..ea6f4223 --- /dev/null +++ b/crates/engineioxide-client/src/client.rs @@ -0,0 +1,112 @@ +use std::{ + fmt, + pin::Pin, + sync::Mutex, + task::{Context, Poll}, +}; + +use engineioxide_core::{Packet, PacketBuf, PacketParseError, Sid}; +use futures_core::Stream; +use futures_util::{ + Sink, SinkExt, StreamExt, + stream::{SplitSink, SplitStream}, +}; +use tokio::sync::mpsc::{self, error::TrySendError}; + +use crate::{ + HttpClient, poll, + transport::{Transport, polling::PollingSvc}, +}; + +type SendPongFut = Pin< + Box< + dyn Future, Packet> as Sink>::Error>> + + 'static, + >, +>; + +pin_project_lite::pin_project! { + pub struct Client { + #[pin] + pub transport_rx: SplitStream>, + // TODO: is this the right implementation? We need something that can be driven itself. + // Otherwise we need a way to drive the transport_tx. Normally it should be driven by the user. + // But what if we need to send a PONG packet from the inner lib? + #[pin] + pub transport_tx: SplitSink, Packet>, + + pub sid: Sid, + // pub tx: mpsc::Sender, + // pub(crate) rx: Mutex>, + } +} + +impl Client +where + S::Error: fmt::Debug, + ::Error: fmt::Debug, +{ + pub async fn connect(svc: S) -> Result { + let mut inner = HttpClient::new(svc); + let packet = inner.handshake().await.unwrap(); + + let transport = Transport::Polling { inner }; + let (transport_tx, transport_rx) = transport.split(); + let client = Client { + transport_tx, + transport_rx, + sid: packet.sid, + }; + + Ok(client) + } +} + +impl Stream for Client +where + S::Error: fmt::Debug, + ::Error: fmt::Debug, +{ + type Item = Result; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + match poll!(this.transport_rx.poll_next(cx)) { + Some(Ok(Packet::Ping)) => { + cx.waker().wake_by_ref(); + // let mut tx = self.transport_tx.clone(); + // let fut = async move { + // tx.send(Packet::Pong).await?; + // tx.flush().await + // }; + // this.pending_pong.set(Some(Box::pin(fut))); + + Poll::Pending + } + packet => Poll::Ready(packet), + } + } +} + +impl Sink for Client +where + S::Error: fmt::Debug, + ::Error: fmt::Debug, +{ + type Error = (); + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().transport_tx.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { + self.project().transport_tx.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().transport_tx.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().transport_tx.poll_close(cx) + } +} diff --git a/crates/engineioxide-client/src/io.rs b/crates/engineioxide-client/src/io.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/crates/engineioxide-client/src/io.rs @@ -0,0 +1 @@ + diff --git a/crates/engineioxide-client/src/lib.rs b/crates/engineioxide-client/src/lib.rs new file mode 100644 index 00000000..fb39e455 --- /dev/null +++ b/crates/engineioxide-client/src/lib.rs @@ -0,0 +1,19 @@ +// #![warn(clippy::pedantic)] +#![allow(clippy::similar_names)] +//! Engine.IO client library for Rust. + +mod client; +mod io; +mod transport; +pub use crate::client::Client; +pub use crate::transport::polling::HttpClient; + +#[macro_export] +macro_rules! poll { + ($expr:expr) => { + match $expr { + std::task::Poll::Pending => return std::task::Poll::Pending, + std::task::Poll::Ready(value) => value, + } + }; +} diff --git a/crates/engineioxide-client/src/transport/mod.rs b/crates/engineioxide-client/src/transport/mod.rs new file mode 100644 index 00000000..d7bd1b1f --- /dev/null +++ b/crates/engineioxide-client/src/transport/mod.rs @@ -0,0 +1,81 @@ +use std::{ + fmt, + pin::Pin, + task::{Context, Poll}, +}; + +use engineioxide_core::{Packet, PacketParseError, TransportType}; +use futures_core::Stream; +use futures_util::Sink; + +use crate::{HttpClient, transport::polling::PollingSvc}; + +pub mod polling; + +pin_project_lite::pin_project! { + #[project = TransportProj] + pub enum Transport { + Polling { + #[pin] + inner: HttpClient + }, + Websocket { + #[pin] + inner: HttpClient + } + } +} + +impl Transport { + pub fn transport_type(&self) -> TransportType { + match self { + Transport::Polling { .. } => TransportType::Polling, + Transport::Websocket { .. } => TransportType::Websocket, + } + } +} + +impl Stream for Transport +where + S::Error: fmt::Debug, + ::Error: fmt::Debug, +{ + type Item = Result; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.as_mut().project() { + TransportProj::Polling { inner } => inner.poll_next(cx), + TransportProj::Websocket { inner } => inner.poll_next(cx), + } + } +} +impl Sink for Transport { + type Error = (); + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + TransportProj::Polling { inner } => inner.poll_ready(cx), + TransportProj::Websocket { inner } => inner.poll_ready(cx), + } + } + + fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { + match self.project() { + TransportProj::Polling { inner } => inner.start_send(item), + TransportProj::Websocket { inner } => inner.start_send(item), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + TransportProj::Polling { inner } => inner.poll_flush(cx), + TransportProj::Websocket { inner } => inner.poll_flush(cx), + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + TransportProj::Polling { inner } => inner.poll_close(cx), + TransportProj::Websocket { inner } => inner.poll_close(cx), + } + } +} diff --git a/crates/engineioxide-client/src/transport/polling.rs b/crates/engineioxide-client/src/transport/polling.rs new file mode 100644 index 00000000..a365988a --- /dev/null +++ b/crates/engineioxide-client/src/transport/polling.rs @@ -0,0 +1,304 @@ +use std::collections::VecDeque; +use std::fmt; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use bytes::Bytes; +use engineioxide_core::OpenPacket; +use engineioxide_core::Packet; +use engineioxide_core::PacketBuf; +use engineioxide_core::PacketParseError; +use engineioxide_core::ProtocolVersion; +use engineioxide_core::Sid; +use engineioxide_core::payload; +use engineioxide_core::payload::Payload; +use futures_core::Stream; +use futures_util::Sink; +use futures_util::StreamExt; +use futures_util::stream; +use http::Request; +use http::Response; +use http::StatusCode; +use http_body_util::BodyExt; +use http_body_util::Full; +use hyper::service::Service as HyperSvc; +use pin_project_lite::pin_project; +use smallvec::smallvec; + +use crate::poll; + +pub trait PollingSvc: HyperSvc>, Response = Response> { + type Body: hyper::body::Body + 'static; +} + +impl PollingSvc for S +where + S: HyperSvc>, Response = Response>, + >>>::Error: fmt::Debug, + B: hyper::body::Body + 'static, + ::Error: std::fmt::Debug + 'static, + ::Data: Send + std::fmt::Debug + 'static, +{ + type Body = B; +} + +pin_project! { + #[project = PollStateProj] + #[derive(Default)] + enum PollState { + #[default] + No, + Pending { + #[pin] + fut: F + }, + Decoding { + stream: Pin>>> + } + } +} + +pin_project! { + #[project = PostStateProj] + enum PostState { + /// TODO: Ideally the queue should gradually encode packets in an async fashion + Queuing { + queue: VecDeque + }, + Encoding { + #[pin] + fut: Pin>>, + }, + Pending { + #[pin] + fut: F + } + } +} + +impl Default for PostState { + fn default() -> Self { + PostState::Queuing { + queue: VecDeque::new(), + } + } +} + +pin_project! { + pub struct HttpClient + { + svc: S, + + #[pin] + poll_state: PollState, + + #[pin] + post_state: PostState, + + sid: Option, + } +} + +impl HttpClient +where + S::Error: fmt::Debug, + ::Error: fmt::Debug, +{ + pub fn new(svc: S) -> Self { + Self { + svc, + poll_state: PollState::default(), + post_state: PostState::default(), + sid: None, + } + } + + pub async fn handshake(&mut self) -> Result { + #[cfg(feature = "tracing")] + tracing::trace!(?self, "handshake request"); + + let req = Request::builder() + .method("GET") + .uri("http://localhost:3000/engine.io?EIO=4&transport=polling") + .body(Full::default()) + .unwrap(); + + let res = self.svc.call(req).await; + let body = res.unwrap().collect().await.unwrap(); + let packet = Packet::try_from(String::from_utf8(body.to_bytes().to_vec()).unwrap())?; + + match packet { + Packet::Open(open) => { + self.sid = Some(open.sid); + Ok(open) + } + _ => Err(PacketParseError::InvalidPacketType(Some('1'))), + } + } +} + +impl Stream for HttpClient +where + S::Error: fmt::Debug, + ::Error: fmt::Debug, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[cfg(feature = "tracing")] + tracing::trace!(poll_state = ?self.poll_state, "polling"); + + let mut poll_state_proj = self.as_mut().project().poll_state.project(); + match poll_state_proj { + PollStateProj::No => { + let id = self.sid.unwrap(); + let uri = + format!("http://localhost:3000/engine.io?EIO=4&transport=polling&sid={id}"); + let req = Request::get(uri).body(Full::new(Bytes::new())).unwrap(); + let fut = self.svc.call(req); + self.project().poll_state.set(PollState::Pending { fut }); + cx.waker().wake_by_ref(); + Poll::Pending + } + PollStateProj::Pending { ref mut fut } => { + match poll!(fut.as_mut().poll(cx)) { + Ok(res) => { + let (parts, body) = res.into_parts(); + dbg!(&parts); + assert!(parts.status == StatusCode::OK); + let body = Box::pin(body); + //TODO: implement limited body + Content-Type + let stream = + payload::decoder(body, None, ProtocolVersion::V4, 200).boxed_local(); + + self.project() + .poll_state + .set(PollState::Decoding { stream }); + + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(err) => { + #[cfg(feature = "tracing")] + tracing::debug!(?err, "got body error"); + Poll::Ready(Some(Err(PacketParseError::InvalidPacketPayload))) + } + } + } + PollStateProj::Decoding { ref mut stream } => { + if let Some(packet) = poll!(stream.poll_next_unpin(cx)) { + dbg!(&packet); + Poll::Ready(Some(packet)) + } else { + self.project().poll_state.set(PollState::No); + // Should not be needed. + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + } +} + +impl Sink for HttpClient { + type Error = (); + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + match self.post_state { + PostState::Queuing { .. } => Poll::Ready(Ok(())), + _ => Poll::Pending, + } + } + + fn start_send(self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { + #[cfg(feature = "tracing")] + tracing::trace!(post_state = ?self.post_state, "sending packet"); + + match self.project().post_state.project() { + PostStateProj::Queuing { queue } => queue.push_back(smallvec![item]), + _ => panic!("unexpected post state"), + } + + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[cfg(feature = "tracing")] + tracing::trace!(post_state = ?self.post_state, "flushing"); + + match self.as_mut().project().post_state.project() { + PostStateProj::Queuing { queue } if queue.is_empty() => Poll::Ready(Ok(())), + PostStateProj::Queuing { queue } => { + let fut = payload::encoder( + stream::iter(queue.clone()), + ProtocolVersion::V4, + true, + 102400000, + ); + self.project() + .post_state + .set(PostState::Encoding { fut: Box::pin(fut) }); + cx.waker().wake_by_ref(); + Poll::Pending + } + PostStateProj::Encoding { fut } => { + let body = poll!(fut.poll(cx)); + let req = Request::post( + "http://localhost:3000/engine.io?EIO=4&transport=polling&sid={id}", + ) + .body(Full::new(body.data)) //TODO: fix cloning + .unwrap(); + let fut = self.svc.call(req); + self.project().post_state.set(PostState::Pending { fut }); + cx.waker().wake_by_ref(); + Poll::Pending + } + PostStateProj::Pending { fut } => { + match poll!(fut.poll(cx)) { + Ok(res) => { + assert!(res.status().is_success()); + self.project().post_state.set(PostState::default()); + Poll::Ready(Ok(())) // TODO: check response == ok + } + Err(err) => { + self.project().post_state.set(PostState::default()); + todo!("handle error") + } + } + } + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl fmt::Debug for PollState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::No => write!(f, "No"), + Self::Pending { .. } => write!(f, "Pending"), + Self::Decoding { .. } => write!(f, "Decoding"), + } + } +} +impl fmt::Debug for PostState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Encoding { .. } => f.debug_struct("Encoding").finish(), + Self::Pending { .. } => write!(f, "Pending"), + Self::Queuing { .. } => write!(f, "Queuing"), + } + } +} +impl fmt::Debug for HttpClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpClient") + .field("poll_state", &self.poll_state) + .field("post_state", &self.post_state) + .field("sid", &self.sid) + .finish() + } +} diff --git a/crates/engineioxide-client/tests/handshake.rs b/crates/engineioxide-client/tests/handshake.rs new file mode 100644 index 00000000..70d9d50b --- /dev/null +++ b/crates/engineioxide-client/tests/handshake.rs @@ -0,0 +1,121 @@ +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use engineioxide::handler::EngineIoHandler; +use engineioxide::{DisconnectReason, service::EngineIoService}; +use engineioxide::{Socket, Str}; +use engineioxide_client::{Client, HttpClient}; +use engineioxide_core::{Packet, Sid}; +use futures_util::{SinkExt, StreamExt, TryFutureExt}; +use tokio::sync::mpsc; +use tracing_subscriber::EnvFilter; + +#[derive(Debug, PartialEq, Eq)] +enum Event { + Connect(Sid), + Disconnect(Sid, DisconnectReason), + Message(Sid, Str), + Binary(Sid, Bytes), +} + +#[derive(Debug)] +struct Handler { + tx: mpsc::Sender, +} + +impl Handler { + fn new() -> (Self, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(100); + (Self { tx }, rx) + } +} + +impl EngineIoHandler for Handler { + type Data = (); + fn on_connect(self: Arc, socket: Arc>) { + self.tx.try_send(Event::Connect(socket.id)).unwrap(); + } + + fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { + self.tx + .try_send(Event::Disconnect(socket.id, reason)) + .unwrap(); + } + + fn on_message(self: &Arc, msg: Str, socket: Arc>) { + self.tx.try_send(Event::Message(socket.id, msg)).unwrap(); + } + + fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { + self.tx.try_send(Event::Binary(socket.id, data)).unwrap(); + } +} + +#[tokio::test] +async fn handshake() { + tracing_subscriber::fmt::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .try_init() + .ok(); + let (handler, mut rx) = Handler::new(); + let svc = EngineIoService::new(Arc::new(handler)); + let packet = HttpClient::new(svc).handshake().await.unwrap(); + assert_eq!(rx.recv().await.unwrap(), Event::Connect(packet.sid)); +} + +#[tokio::test] +async fn connect() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .try_init() + .ok(); + let (handler, mut rx) = Handler::new(); + let svc = EngineIoService::new(Arc::new(handler)); + let client = Client::connect(svc).await.unwrap(); + assert_eq!(rx.recv().await.unwrap(), Event::Connect(client.sid)); + let (ctx, mut crx) = client.split::(); + + while let Some(event) = crx.next().await { + match event { + Ok(event) => { + dbg!(event); + } + Err(e) => panic!("Error: {e}"), + } + } +} + +#[tokio::test] +async fn spaaam() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .try_init() + .ok(); + let (handler, mut rx) = Handler::new(); + let svc = EngineIoService::new(Arc::new(handler)); + let client = Client::connect(svc).await.unwrap(); + assert_eq!(rx.recv().await.unwrap(), Event::Connect(client.sid)); + let (mut ctx, mut crx) = client.split::(); + + tokio::task::LocalSet::new() + .run_until(async move { + tokio::task::spawn_local(async move { + loop { + ctx.send(Packet::Pong).await.unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + + while let Some(event) = crx.next().await { + match event { + Ok(event) => { + dbg!(event); + } + Err(e) => panic!("Error: {e}"), + } + } + }) + .await; +} diff --git a/crates/engineioxide-core/Cargo.toml b/crates/engineioxide-core/Cargo.toml index 58797493..ea146aff 100644 --- a/crates/engineioxide-core/Cargo.toml +++ b/crates/engineioxide-core/Cargo.toml @@ -17,3 +17,26 @@ rand = "0.9" base64 = "0.22" serde.workspace = true bytes.workspace = true +serde_json.workspace = true +http-body.workspace = true +http-body-util.workspace = true +http.workspace = true +futures-util.workspace = true +smallvec.workspace = true +pin-project-lite.workspace = true + +# Engine.io V3 payload +itoa = { workspace = true, optional = true } +memchr = { version = "2.7", optional = true } +unicode-segmentation = { version = "1.12", optional = true } + +# Tracing +tracing = { workspace = true, optional = true } + + +[features] +v3 = ["dep:memchr", "dep:unicode-segmentation", "dep:itoa"] +tracing = ["dep:tracing"] + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt"] } diff --git a/crates/engineioxide-core/src/lib.rs b/crates/engineioxide-core/src/lib.rs index 2d3c66cb..9cd36e90 100644 --- a/crates/engineioxide-core/src/lib.rs +++ b/crates/engineioxide-core/src/lib.rs @@ -29,8 +29,13 @@ )] #![doc = include_str!("../README.md")] +mod packet; +mod protocol; mod sid; mod str; +pub use packet::{OpenPacket, Packet, PacketBuf, PacketParseError}; +pub use protocol::{ProtocolVersion, TransportType, UnknownTransportError}; pub use sid::Sid; pub use str::Str; +pub mod payload; diff --git a/crates/engineioxide/src/packet.rs b/crates/engineioxide-core/src/packet.rs similarity index 65% rename from crates/engineioxide/src/packet.rs rename to crates/engineioxide-core/src/packet.rs index b65a8cf1..6cb1950d 100644 --- a/crates/engineioxide/src/packet.rs +++ b/crates/engineioxide-core/src/packet.rs @@ -1,11 +1,11 @@ +use std::fmt; + use base64::{Engine, engine::general_purpose}; use bytes::Bytes; -use engineioxide_core::{Sid, Str}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; -use crate::TransportType; -use crate::config::EngineIoConfig; -use crate::errors::Error; +use crate::{Sid, Str}; /// A Packet type to use when receiving and sending data from the client #[derive(Debug, Clone, PartialEq, PartialOrd)] @@ -51,6 +51,67 @@ pub enum Packet { BinaryV3(Bytes), // Not part of the protocol, used internally } +/// An error that occurs when parsing a packet. +#[derive(Debug)] +pub enum PacketParseError { + /// Invalid connect packet + InvalidConnectPacket(serde_json::Error), + /// The packet type is invalid. + InvalidPacketType(Option), + /// The packet payload is invalid. + InvalidPacketPayload, + /// The packet length is invalid. + InvalidPacketLen, + /// The packet chunk is invalid + InvalidUtf8Boundary(std::str::Utf8Error), + /// The base64 decoding failed. + Base64Decode(base64::DecodeError), + /// The payload is too large. + PayloadTooLarge { + /// The maximum allowed payload size. + max: u64, + }, +} +impl fmt::Display for PacketParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PacketParseError::InvalidConnectPacket(e) => write!(f, "invalid connect packet: {e}"), + PacketParseError::InvalidPacketType(c) => write!(f, "invalid packet type: {c:?}"), + PacketParseError::InvalidPacketPayload => write!(f, "invalid packet payload"), + PacketParseError::InvalidPacketLen => write!(f, "invalid packet length"), + PacketParseError::InvalidUtf8Boundary(err) => write!( + f, + "invalid utf8 boundary when parsing payload into packet chunks: {err}" + ), + PacketParseError::Base64Decode(err) => write!(f, "base64 decode error: {err}"), + PacketParseError::PayloadTooLarge { max } => { + write!(f, "payload too large: max {max}") + } + } + } +} +impl From for PacketParseError { + fn from(err: base64::DecodeError) -> Self { + PacketParseError::Base64Decode(err) + } +} +impl From for PacketParseError { + fn from(err: std::string::FromUtf8Error) -> Self { + PacketParseError::InvalidUtf8Boundary(err.utf8_error()) + } +} +impl From for PacketParseError { + fn from(err: std::str::Utf8Error) -> Self { + PacketParseError::InvalidUtf8Boundary(err) + } +} +impl From for PacketParseError { + fn from(err: serde_json::Error) -> Self { + PacketParseError::InvalidConnectPacket(err) + } +} +impl std::error::Error for PacketParseError {} + impl Packet { /// Check if the packet is a binary packet pub fn is_binary(&self) -> bool { @@ -58,7 +119,7 @@ impl Packet { } /// If the packet is a message packet (text), it returns the message - pub(crate) fn into_message(self) -> Str { + pub fn into_message(self) -> Str { match self { Packet::Message(msg) => msg, _ => panic!("Packet is not a message"), @@ -66,7 +127,7 @@ impl Packet { } /// If the packet is a binary packet, it returns the binary data - pub(crate) fn into_binary(self) -> Bytes { + pub fn into_binary(self) -> Bytes { match self { Packet::Binary(data) => data, Packet::BinaryV3(data) => data, @@ -79,7 +140,7 @@ impl Packet { /// If b64 is true, it returns the max size when serialized to base64 /// /// The base64 max size factor is `ceil(n / 3) * 4` - pub(crate) fn get_size_hint(&self, b64: bool) -> usize { + pub fn get_size_hint(&self, b64: bool) -> usize { match self { Packet::Open(_) => 156, // max possible size for the open packet serialized Packet::Close => 1, @@ -108,6 +169,12 @@ impl Packet { } } +impl From for Bytes { + fn from(value: Packet) -> Self { + String::from(value).into() + } +} + /// Serialize a [Packet] to a [String] according to the Engine.IO protocol impl From for String { fn from(packet: Packet) -> String { @@ -141,26 +208,18 @@ impl From for String { buffer } } -impl From for tokio_tungstenite::tungstenite::Utf8Bytes { - fn from(value: Packet) -> Self { - String::from(value).into() - } -} -impl From for Bytes { - fn from(value: Packet) -> Self { - String::from(value).into() - } -} + /// Deserialize a [Packet] from a [String] according to the Engine.IO protocol impl TryFrom for Packet { - type Error = Error; + type Error = PacketParseError; fn try_from(value: Str) -> Result { let packet_type = value .as_bytes() .first() - .ok_or(Error::InvalidPacketType(None))?; + .ok_or(PacketParseError::InvalidPacketType(None))?; let is_upgrade = value.len() == 6 && &value[1..6] == "probe"; let res = match packet_type { + b'0' => Packet::Open(serde_json::from_str(value.slice(1..).as_str())?), b'1' => Packet::Close, b'2' if is_upgrade => Packet::PingUpgrade, b'2' => Packet::Ping, @@ -179,59 +238,58 @@ impl TryFrom for Packet { .decode(value.slice(1..).as_bytes())? .into(), ), - c => Err(Error::InvalidPacketType(Some(*c as char)))?, + c => Err(PacketParseError::InvalidPacketType(Some(*c as char)))?, }; Ok(res) } } -impl TryFrom for Packet { - type Error = Error; - fn try_from(value: tokio_tungstenite::tungstenite::Utf8Bytes) -> Result { - // SAFETY: The utf8 bytes are guaranteed to be valid utf8 - Packet::try_from(unsafe { Str::from_bytes_unchecked(value.into()) }) - } -} impl TryFrom for Packet { - type Error = Error; + type Error = PacketParseError; fn try_from(value: String) -> Result { Packet::try_from(Str::from(value)) } } /// An OpenPacket is used to initiate a connection -#[derive(Debug, Clone, Serialize, PartialEq, PartialOrd)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, PartialOrd)] #[serde(rename_all = "camelCase")] pub struct OpenPacket { - sid: Sid, - upgrades: Vec, - ping_interval: u64, - ping_timeout: u64, - max_payload: u64, + /// The session ID. + pub sid: Sid, + /// The list of available transport upgrades. + pub upgrades: Vec, + /// The ping interval, used in the heartbeat mechanism (in milliseconds). + pub ping_interval: u64, + /// The ping timeout, used in the heartbeat mechanism (in milliseconds). + pub ping_timeout: u64, + /// The maximum number of bytes per chunk, used by the client to + /// aggregate packets into payloads. + pub max_payload: u64, } -impl OpenPacket { - /// Create a new [OpenPacket] - /// If the current transport is polling, the server will always allow the client to upgrade to websocket - pub fn new(transport: TransportType, sid: Sid, config: &EngineIoConfig) -> Self { - let upgrades = if transport == TransportType::Polling { - vec!["websocket".to_string()] - } else { - vec![] - }; - OpenPacket { - sid, - upgrades, - ping_interval: config.ping_interval.as_millis() as u64, - ping_timeout: config.ping_timeout.as_millis() as u64, - max_payload: config.max_payload, +/// This default implementation should only be used for testing purposes. +impl Default for OpenPacket { + fn default() -> Self { + Self { + sid: Sid::ZERO, + upgrades: vec!["websocket".to_string()], + ping_interval: 25000, + ping_timeout: 20000, + max_payload: 100000, } } } +/// Buffered packets to send to the client. +/// It is used to ensure atomicity when sending multiple packets to the client. +/// +/// The [`PacketBuf`] stack size will impact the dynamically allocated buffer +/// of the internal mpsc channel. +pub type PacketBuf = SmallVec<[Packet; 2]>; + #[cfg(test)] mod tests { - use crate::config::EngineIoConfig; use super::*; use std::{convert::TryInto, time::Duration}; @@ -239,11 +297,13 @@ mod tests { #[test] fn test_open_packet() { let sid = Sid::new(); - let packet = Packet::Open(OpenPacket::new( - TransportType::Polling, + let packet = Packet::Open(OpenPacket { sid, - &EngineIoConfig::default(), - )); + upgrades: vec!["websocket".to_string()], + ping_interval: Duration::from_millis(25000).as_millis() as u64, + ping_timeout: Duration::from_millis(20000).as_millis() as u64, + max_payload: 100000, + }); let packet_str: String = packet.into(); assert_eq!( packet_str, @@ -253,6 +313,23 @@ mod tests { ); } + #[test] + fn test_open_packet_deserialize() { + let sid = Sid::new(); + let ref_packet = OpenPacket { + sid, + upgrades: vec!["websocket".to_string()], + ping_interval: Duration::from_millis(25000).as_millis() as u64, + ping_timeout: Duration::from_millis(20000).as_millis() as u64, + max_payload: 100000, + }; + let packet_str = format!( + "0{{\"sid\":\"{sid}\",\"upgrades\":[\"websocket\"],\"pingInterval\":25000,\"pingTimeout\":20000,\"maxPayload\":100000}}" + ); + let packet = Packet::try_from(packet_str).unwrap(); + assert!(matches!(packet, Packet::Open(p) if p == ref_packet)); + } + #[test] fn test_message_packet() { let packet = Packet::Message("hello".into()); @@ -298,18 +375,13 @@ mod tests { #[test] fn test_packet_get_size_hint() { // Max serialized packet - let open = OpenPacket::new( - TransportType::Polling, - Sid::new(), - &EngineIoConfig { - max_buffer_size: usize::MAX, - max_payload: u64::MAX, - ping_interval: Duration::MAX, - ping_timeout: Duration::MAX, - transports: TransportType::Polling as u8 | TransportType::Websocket as u8, - ..Default::default() - }, - ); + let open = OpenPacket { + sid: Sid::new(), + ping_interval: u64::MAX, + ping_timeout: u64::MAX, + max_payload: u64::MAX, + upgrades: vec!["websocket".to_string()], + }; let size = serde_json::to_string(&open).unwrap().len(); let packet = Packet::Open(open); assert_eq!(packet.get_size_hint(false), size); diff --git a/crates/engineioxide/src/transport/polling/payload/buf.rs b/crates/engineioxide-core/src/payload/buf.rs similarity index 100% rename from crates/engineioxide/src/transport/polling/payload/buf.rs rename to crates/engineioxide-core/src/payload/buf.rs diff --git a/crates/engineioxide/src/transport/polling/payload/decoder.rs b/crates/engineioxide-core/src/payload/decoder.rs similarity index 88% rename from crates/engineioxide/src/transport/polling/payload/decoder.rs rename to crates/engineioxide-core/src/payload/decoder.rs index 3ef454d3..555a94ba 100644 --- a/crates/engineioxide/src/transport/polling/payload/decoder.rs +++ b/crates/engineioxide-core/src/payload/decoder.rs @@ -5,11 +5,9 @@ //! - v3_decoder: Decodes the payload stream according to the [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) //! -use futures_core::Stream; -use futures_util::StreamExt; -use http::StatusCode; +use crate::{Packet, PacketParseError}; +use futures_util::{Stream, StreamExt}; -use crate::{errors::Error, packet::Packet}; use bytes::Buf; use http_body::Body; use http_body_util::BodyStream; @@ -42,7 +40,7 @@ impl Payload { /// Polls the body stream for data and adds it to the chunk list in the state /// Returns an error if the packet length exceeds the maximum allowed payload size -async fn poll_body(state: &mut Payload, max_payload: u64) -> Result<(), Error> +async fn poll_body(state: &mut Payload, max_payload: u64) -> Result<(), PacketParseError> where B: Body + Unpin, E: std::fmt::Debug, @@ -59,7 +57,7 @@ where Err(_e) => { #[cfg(feature = "tracing")] tracing::debug!("error reading body stream: {:?}", _e); - Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST)) + Err(PacketParseError::InvalidPacketPayload) } }?; if state.current_payload_size + (data.remaining() as u64) <= max_payload { @@ -67,11 +65,14 @@ where state.buffer.push(data); Ok(()) } else { - Err(Error::PayloadTooLarge) + Err(PacketParseError::PayloadTooLarge { max: max_payload }) } } -pub fn v4_decoder(body: B, max_payload: u64) -> impl Stream> +pub fn v4_decoder( + body: B, + max_payload: u64, +) -> impl Stream> where B: Body + Unpin, E: std::fmt::Debug, @@ -93,11 +94,14 @@ where } // Read from the buffer until the packet separator is found - if let Err(e) = (&mut state.buffer) + if let Err(_err) = (&mut state.buffer) .reader() .read_until(PACKET_SEPARATOR_V4, &mut packet_buf) { - break Some((Err(Error::Io(e)), state)); + #[cfg(feature = "tracing")] + tracing::debug!("failed to read packet payload: {_err}"); + + break Some((Err(PacketParseError::InvalidPacketPayload), state)); } let separator_found = packet_buf.ends_with(&[PACKET_SEPARATOR_V4]); @@ -111,7 +115,7 @@ where || (state.end_of_stream && state.buffer.remaining() == 0 && !packet_buf.is_empty()) { let packet = String::from_utf8(packet_buf) - .map_err(|_| Error::InvalidPacketLength) + .map_err(PacketParseError::from) .and_then(Packet::try_from); // Convert the packet buffer to a Packet object break Some((packet, state)); // Emit the packet and the updated state } else if state.end_of_stream && state.buffer.remaining() == 0 { @@ -125,14 +129,14 @@ where pub fn v3_binary_decoder( body: B, max_payload: u64, -) -> impl Stream> +) -> impl Stream> where B: Body + Unpin, E: std::fmt::Debug, { use std::io::Read; - use crate::transport::polling::payload::{ + use crate::payload::{ BINARY_PACKET_IDENTIFIER_V3, BINARY_PACKET_SEPARATOR_V3, STRING_PACKET_IDENTIFIER_V3, }; @@ -156,11 +160,14 @@ where // If there is no packet_type found if packet_type.is_none() && state.buffer.remaining() > 0 { // Read from the buffer until the packet separator is found - if let Err(e) = (&mut state.buffer) + if let Err(_err) = (&mut state.buffer) .reader() .read_until(BINARY_PACKET_SEPARATOR_V3, &mut packet_buf) { - break Some((Err(Error::Io(e)), state)); + #[cfg(feature = "tracing")] + tracing::debug!("failed to read packet payload: {_err}"); + + break Some((Err(PacketParseError::InvalidPacketPayload), state)); } // Extract packet_type and packet_size @@ -173,11 +180,11 @@ where Some(&STRING_PACKET_IDENTIFIER_V3) => { packet_type = Some(STRING_PACKET_IDENTIFIER_V3) } - _ => break Some((Err(Error::InvalidPacketLength), state)), + _ => break Some((Err(PacketParseError::InvalidPacketLen), state)), } if packet_buf.len() > 9 { - break Some((Err(Error::InvalidPacketLength), state)); + break Some((Err(PacketParseError::InvalidPacketLen), state)); } let size_str = &packet_buf[1..] @@ -187,7 +194,7 @@ where if let Ok(size) = size_str.parse() { packet_size = size; } else { - break Some((Err(Error::InvalidPacketLength), state)); + break Some((Err(PacketParseError::InvalidPacketLen), state)); } packet_buf.clear(); } @@ -204,10 +211,10 @@ where // Read the packet data let packet = match packet_type.unwrap() { STRING_PACKET_IDENTIFIER_V3 => String::from_utf8(packet_buf) - .map_err(|_| Error::InvalidPacketLength) + .map_err(PacketParseError::from) .and_then(Packet::try_from), // Convert the packet buffer to a Packet object BINARY_PACKET_IDENTIFIER_V3 => Ok(Packet::BinaryV3(packet_buf.into())), - _ => Err(Error::InvalidPacketLength), + _ => Err(PacketParseError::InvalidPacketLen), }; break Some((packet, state)); @@ -222,11 +229,11 @@ where pub fn v3_string_decoder( body: impl Body + Unpin, max_payload: u64, -) -> impl Stream> { +) -> impl Stream> { use std::io::ErrorKind; use unicode_segmentation::UnicodeSegmentation; - use crate::transport::polling::payload::STRING_PACKET_SEPARATOR_V3; + use crate::payload::STRING_PACKET_SEPARATOR_V3; #[cfg(feature = "tracing")] tracing::debug!("decoding payload with v3 string decoder"); @@ -245,7 +252,7 @@ pub fn v3_string_decoder( if state.end_of_stream && state.buffer.remaining() == 0 && state.yield_packets > 0 { break None; // Reached end of stream with no more data, end the stream } else if state.end_of_stream && state.buffer.remaining() == 0 { - return Some((Err(Error::InvalidPacketLength), state)); + return Some((Err(PacketParseError::InvalidPacketLen), state)); } let mut reader = (&mut state.buffer).reader(); @@ -258,7 +265,12 @@ pub fn v3_string_decoder( let available = match reader.fill_buf() { Ok(n) => n, Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, - Err(e) => return Some((Err(Error::Io(e)), state)), + Err(_err) => { + #[cfg(feature = "tracing")] + tracing::debug!("failed to read packet payload: {_err}"); + + return Some((Err(PacketParseError::InvalidPacketPayload), state)); + } }; let old_len = packet_buf.len(); packet_buf.extend_from_slice(available); @@ -267,9 +279,10 @@ pub fn v3_string_decoder( Some(i) => { // Extract the packet length from the available data packet_graphemes_len = match std::str::from_utf8(&packet_buf[..i]) - .map_err(|_| Error::InvalidPacketLength) + .map_err(PacketParseError::from) .and_then(|s| { - s.parse::().map_err(|_| Error::InvalidPacketLength) + s.parse::() + .map_err(|_| PacketParseError::InvalidPacketLen) }) { Ok(size) => size, Err(e) => return Some((Err(e), state)), @@ -279,7 +292,7 @@ pub fn v3_string_decoder( (true, i + 1 - old_len) // Mark as done and set the used bytes count } None if state.end_of_stream && remaining - available.len() == 0 => { - return Some((Err(Error::InvalidPacketLength), state)); + return Some((Err(PacketParseError::InvalidPacketLen), state)); } // Reached end of stream and end of bufferered chunks without finding the separator None => (false, available.len()), // Continue reading more data } @@ -336,8 +349,9 @@ pub fn v3_string_decoder( if let Ok(packet) = std::str::from_utf8(&packet_buf) { if packet.graphemes(true).count() == packet_graphemes_len { // SAFETY: packet_buf is a valid utf8 string checkd above + let packet = unsafe { String::from_utf8_unchecked(packet_buf) }; - let packet = Packet::try_from(packet).map_err(|_| Error::InvalidPacketLength); + let packet = Packet::try_from(packet); state.yield_packets += 1; break Some((packet, state)); // Emit the packet and the updated state } @@ -356,8 +370,6 @@ mod tests { use http_body::Frame; use http_body_util::{Full, StreamBody}; - use crate::packet::Packet; - use super::*; const MAX_PAYLOAD: u64 = 100_000; @@ -423,7 +435,10 @@ mod tests { let payload = v4_decoder(stream, MAX_PAYLOAD); futures_util::pin_mut!(payload); let packet = payload.next().await.unwrap(); - assert!(matches!(packet, Err(Error::PayloadTooLarge))); + assert!(matches!( + packet, + Err(PacketParseError::PayloadTooLarge { max: MAX_PAYLOAD }) + )); } } @@ -551,7 +566,10 @@ mod tests { let payload = v3_binary_decoder(stream, MAX_PAYLOAD); futures_util::pin_mut!(payload); let packet = payload.next().await.unwrap(); - assert!(matches!(packet, Err(Error::PayloadTooLarge))); + assert!(matches!( + packet, + Err(PacketParseError::PayloadTooLarge { max: MAX_PAYLOAD }) + )); } for i in 1..DATA.len() { let stream = StreamBody::new(futures_util::stream::iter( @@ -562,7 +580,10 @@ mod tests { let payload = v3_string_decoder(stream, MAX_PAYLOAD); futures_util::pin_mut!(payload); let packet = payload.next().await.unwrap(); - assert!(matches!(packet, Err(Error::PayloadTooLarge))); + assert!(matches!( + packet, + Err(PacketParseError::PayloadTooLarge { max: MAX_PAYLOAD }) + )); } } } diff --git a/crates/engineioxide/src/transport/polling/payload/encoder.rs b/crates/engineioxide-core/src/payload/encoder.rs similarity index 61% rename from crates/engineioxide/src/transport/polling/payload/encoder.rs rename to crates/engineioxide-core/src/payload/encoder.rs index c40fc3ce..30de4190 100644 --- a/crates/engineioxide/src/transport/polling/payload/encoder.rs +++ b/crates/engineioxide-core/src/payload/encoder.rs @@ -7,12 +7,12 @@ //! * binary encoder (used when there are binary packets and the client supports binary) //! -use tokio::sync::MutexGuard; +use std::pin::Pin; -use crate::{ - errors::Error, packet::Packet, peekable::PeekableReceiver, socket::PacketBuf, - transport::polling::payload::Payload, -}; +use futures_util::{FutureExt, Stream, StreamExt, stream::Peekable}; +use smallvec::smallvec; + +use crate::{Packet, packet::PacketBuf, payload::Payload}; /// Try to immediately poll a new packet buf from the rx channel and check that the new packet can be added to the payload /// @@ -25,12 +25,12 @@ use crate::{ /// * `max_payload` - The maximum payload length /// * `b64` - If binary packets should be encoded in base64 fn try_recv_packet( - rx: &mut MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, payload_len: usize, max_payload: u64, b64: bool, ) -> Option { - if let Some(packets) = rx.peek() { + if let Some(packets) = rx.as_mut().peek().now_or_never().flatten() { let size = packets.iter().map(|p| p.get_size_hint(b64)).sum::(); if (payload_len + size) as u64 > max_payload { #[cfg(feature = "tracing")] @@ -39,14 +39,14 @@ fn try_recv_packet( } } - let packets = rx.try_recv().ok(); + let packets = rx.next().now_or_never().flatten(); - if Some(&Packet::Close) == packets.as_ref().and_then(|p| p.first()) { - #[cfg(feature = "tracing")] - tracing::debug!("Received close packet, closing channel"); - rx.try_recv().ok(); - rx.close(); - } + // if Some(&Packet::Close) == packets.as_ref().and_then(|p| p.first()) { + // #[cfg(feature = "tracing")] + // tracing::debug!("Received close packet, closing channel"); + // rx.try_recv().ok(); + // rx.close(); + // } #[cfg(feature = "tracing")] tracing::debug!("sending packet: {:?}", packets); @@ -55,28 +55,28 @@ fn try_recv_packet( /// Same as [`try_recv_packet`] /// but wait for a new packet if there is no packet in the buffer -async fn recv_packet( - rx: &mut MutexGuard<'_, PeekableReceiver>, -) -> Result { - let packet = rx.recv().await.ok_or(Error::Aborted)?; +async fn recv_packet(mut rx: Pin<&mut Peekable>>) -> PacketBuf { + let packet = rx.next().await.unwrap_or(smallvec![]); + if Some(&Packet::Close) == packet.first() { #[cfg(feature = "tracing")] tracing::debug!("Received close packet, closing channel"); - rx.close(); + + // rx.close(); } #[cfg(feature = "tracing")] tracing::debug!("sending packet: {:?}", packet); - Ok(packet) + packet } /// Encode multiple packets into a string payload according to the /// [engine.io v4 protocol](https://socket.io/fr/docs/v4/engine-io-protocol/#http-long-polling-1) pub async fn v4_encoder( - mut rx: MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, max_payload: u64, -) -> Result { - use crate::transport::polling::payload::PACKET_SEPARATOR_V4; +) -> Payload { + use crate::payload::PACKET_SEPARATOR_V4; #[cfg(feature = "tracing")] tracing::debug!("encoding payload with v4 encoder"); @@ -85,7 +85,7 @@ pub async fn v4_encoder( // Send all packets in the buffer const PUNCTUATION_LEN: usize = 1; while let Some(packets) = - try_recv_packet(&mut rx, data.len() + PUNCTUATION_LEN, max_payload, true) + try_recv_packet(rx.as_mut(), data.len() + PUNCTUATION_LEN, max_payload, true) { for packet in packets { let packet: String = packet.into(); @@ -99,21 +99,21 @@ pub async fn v4_encoder( // If there is no packet in the buffer, wait for the next packet if data.is_empty() { - let packets = recv_packet(&mut rx).await?; + let packets = recv_packet(rx.as_mut()).await; for packet in packets { let packet: String = packet.into(); data.push_str(&packet); } } - Ok(Payload::new(data.into(), false)) + Payload::new(data.into(), false) } /// Encode one packet into a *binary* payload according to the /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub fn v3_bin_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { - use crate::transport::polling::payload::BINARY_PACKET_SEPARATOR_V3; + use crate::payload::BINARY_PACKET_SEPARATOR_V3; use bytes::BufMut; let mut itoa = itoa::Buffer::new(); @@ -153,7 +153,7 @@ pub fn v3_bin_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub fn v3_string_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { - use crate::transport::polling::payload::STRING_PACKET_SEPARATOR_V3; + use crate::payload::STRING_PACKET_SEPARATOR_V3; use bytes::BufMut; let packet: String = packet.into(); let packet = format!( @@ -169,9 +169,9 @@ pub fn v3_string_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { /// according to the [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub async fn v3_binary_encoder( - mut rx: MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, max_payload: u64, -) -> Result { +) -> Payload { let mut data = bytes::BytesMut::new(); let mut packet_buffer: Vec = Vec::new(); @@ -185,7 +185,7 @@ pub async fn v3_binary_encoder( // buffer all packets to find if there is binary packets let mut has_binary = false; - while let Some(packets) = try_recv_packet(&mut rx, estimated_size, max_payload, false) { + while let Some(packets) = try_recv_packet(rx.as_mut(), estimated_size, max_payload, false) { for packet in packets { if packet.is_binary() { has_binary = true; @@ -210,7 +210,7 @@ pub async fn v3_binary_encoder( // If there is no packet in the buffer, wait for the next packet if data.is_empty() { - let packets = recv_packet(&mut rx).await?; + let packets = recv_packet(rx.as_mut()).await; for packet in packets { match packet { Packet::BinaryV3(_) | Packet::Binary(_) => { @@ -226,16 +226,16 @@ pub async fn v3_binary_encoder( #[cfg(feature = "tracing")] tracing::debug!("sending packet: {:?}", &data); - Ok(Payload::new(data.freeze(), has_binary)) + Payload::new(data.freeze(), has_binary) } /// Encode multiple packet packet into a *string* payload according to the /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub async fn v3_string_encoder( - mut rx: MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, max_payload: u64, -) -> Result { +) -> Payload { let mut data = bytes::BytesMut::new(); #[cfg(feature = "tracing")] @@ -246,7 +246,7 @@ pub async fn v3_string_encoder( let max_packet_size_len = max_payload.checked_ilog10().unwrap_or(0) as usize + 1; // Current size of the payload let current_size = data.len() + PUNCTUATION_LEN + max_packet_size_len; - while let Some(packets) = try_recv_packet(&mut rx, current_size, max_payload, true) { + while let Some(packets) = try_recv_packet(rx.as_mut(), current_size, max_payload, true) { for packet in packets { v3_string_packet_encoder(packet, &mut data); } @@ -254,21 +254,19 @@ pub async fn v3_string_encoder( // If there is no packet in the buffer, wait for the next packet if data.is_empty() { - let packets = recv_packet(&mut rx).await?; + let packets = recv_packet(rx.as_mut()).await; for packet in packets { v3_string_packet_encoder(packet, &mut data); } } - Ok(Payload::new(data.freeze(), false)) + Payload::new(data.freeze(), false) } #[cfg(test)] mod tests { use bytes::Bytes; - use tokio::sync::Mutex; - - use PacketBuf; + use futures_util::stream; use super::*; const MAX_PAYLOAD: u64 = 100_000; @@ -276,49 +274,41 @@ mod tests { #[tokio::test] async fn encode_v4_payload() { const PAYLOAD: &str = "4hello€\x1ebAQIDBA==\x1e4hello€"; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let rx = Mutex::new(PeekableReceiver::new(rx)); - let rx = rx.lock().await; - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Binary(Bytes::from_static(&[ - 1, 2, 3, 4 - ]))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); + + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Binary(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + ]); + let rx = std::pin::pin!(rx.peekable()); + + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await; assert_eq!(data, PAYLOAD.as_bytes()); } #[tokio::test] async fn max_payload_v4() { const MAX_PAYLOAD: u64 = 10; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Binary(Bytes::from_static(&[ - 1, 2, 3, 4 - ]))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); + + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Binary(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Message("hello€".into())], + ]); + + let mut rx = std::pin::pin!(rx.peekable()); + { - let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(data, "4hello€".as_bytes()); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx.as_mut(), MAX_PAYLOAD + 10).await; assert_eq!(data, "bAQIDBA==\x1e4hello€".as_bytes()); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx.as_mut(), MAX_PAYLOAD + 10).await; assert_eq!(data, "4hello€".as_bytes()); } } @@ -327,21 +317,14 @@ mod tests { #[tokio::test] async fn encode_v3b64_payload() { const PAYLOAD: &str = "7:4hello€10:b4AQIDBA==7:4hello€"; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - let rx = mutex.lock().await; - - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - let Payload { - data, has_binary, .. - } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + ]); + + let rx = std::pin::pin!(rx.peekable()); + let Payload { data, has_binary } = v3_string_encoder(rx, MAX_PAYLOAD).await; assert_eq!(data, PAYLOAD.as_bytes()); assert!(!has_binary); } @@ -351,26 +334,20 @@ mod tests { async fn max_payload_v3_b64() { const MAX_PAYLOAD: u64 = 10; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); + let rx = stream::iter(vec![ + smallvec::smallvec![Packet::Message("hello€".into())], + smallvec::smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec::smallvec![Packet::Message("hello€".into())], + smallvec::smallvec![Packet::Message("hello€".into())], + ]); + let mut rx = std::pin::pin!(rx.peekable()); + { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_string_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(data, "7:4hello€".as_bytes()); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v3_string_encoder(rx.as_mut(), MAX_PAYLOAD + 10).await; assert_eq!(data, "10:b4AQIDBA==7:4hello€7:4hello€".as_bytes()); } } @@ -381,19 +358,14 @@ mod tests { const PAYLOAD: [u8; 20] = [ 0, 9, 255, 52, 104, 101, 108, 108, 111, 226, 130, 172, 1, 5, 255, 4, 1, 2, 3, 4, ]; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - let rx = mutex.lock().await; - - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - let Payload { - data, has_binary, .. - } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + ]); + let rx = std::pin::pin!(rx.peekable()); + + let Payload { data, has_binary } = v3_binary_encoder(rx, MAX_PAYLOAD).await; assert_eq!(*data, PAYLOAD); assert!(has_binary); } @@ -407,26 +379,21 @@ mod tests { 0, 1, 1, 255, 52, 104, 101, 108, 108, 111, 111, 111, 226, 130, 172, 1, 5, 255, 4, 1, 2, 3, 4, ]; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - tx.try_send(smallvec::smallvec![Packet::Message("hellooo€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); + + let rx = stream::iter([ + smallvec![Packet::Message("hellooo€".into())], + smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Message("hello€".into())], + ]); + let mut rx = std::pin::pin!(rx.peekable()); + { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_binary_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(*data, PAYLOAD); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_binary_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(data, "7:4hello€7:4hello€".as_bytes()); } } diff --git a/crates/engineioxide/src/transport/polling/payload/mod.rs b/crates/engineioxide-core/src/payload/mod.rs similarity index 79% rename from crates/engineioxide/src/transport/polling/payload/mod.rs rename to crates/engineioxide-core/src/payload/mod.rs index cfd349d0..32e4a700 100644 --- a/crates/engineioxide/src/transport/polling/payload/mod.rs +++ b/crates/engineioxide-core/src/payload/mod.rs @@ -1,13 +1,10 @@ //! Payload encoder and decoder for polling transport. -use crate::{ - errors::Error, packet::Packet, peekable::PeekableReceiver, service::ProtocolVersion, - socket::PacketBuf, -}; +use crate::{Packet, PacketParseError, ProtocolVersion, packet::PacketBuf}; + use bytes::Bytes; -use futures_core::Stream; -use http::Request; -use tokio::sync::MutexGuard; +use futures_util::{Stream, StreamExt}; +use http::HeaderValue; mod buf; mod decoder; @@ -23,19 +20,19 @@ const STRING_PACKET_IDENTIFIER_V3: u8 = 0x00; #[cfg(feature = "v3")] const BINARY_PACKET_IDENTIFIER_V3: u8 = 0x01; +/// Decode a payload into a stream of packets. pub fn decoder( - body: Request + Unpin>, + body: impl http_body::Body + Unpin, + content_type: Option<&HeaderValue>, #[allow(unused_variables)] protocol: ProtocolVersion, max_payload: u64, -) -> impl Stream> { +) -> impl Stream> { #[cfg(feature = "v3")] { use futures_util::future::Either; - use http::header::CONTENT_TYPE; #[cfg(feature = "tracing")] - tracing::debug!("decoding payload {:?}", body.headers().get(CONTENT_TYPE)); - let is_binary = - body.headers().get(CONTENT_TYPE) == Some(&"application/octet-stream".parse().unwrap()); + tracing::debug!("decoding payload {:?}", content_type); + let is_binary = content_type == Some(&"application/octet-stream".parse().unwrap()); match protocol { ProtocolVersion::V4 => Either::Left(decoder::v4_decoder(body, max_payload)), ProtocolVersion::V3 if is_binary => { @@ -54,23 +51,28 @@ pub fn decoder( } /// A payload to transmit to the client through http polling -#[non_exhaustive] pub struct Payload { + /// The data of the payload. pub data: Bytes, + /// Whether the payload contains binary data. pub has_binary: bool, } impl Payload { + /// Creates a new payload with the given data and binary flag. pub fn new(data: Bytes, has_binary: bool) -> Self { Self { data, has_binary } } } +/// Encodes a payload into a byte stream. pub async fn encoder( - rx: MutexGuard<'_, PeekableReceiver>, + rx: impl Stream, #[allow(unused_variables)] protocol: ProtocolVersion, #[cfg(feature = "v3")] supports_binary: bool, max_payload: u64, -) -> Result { +) -> Payload { + let rx = std::pin::pin!(rx.peekable()); + #[cfg(feature = "v3")] { match protocol { diff --git a/crates/engineioxide-core/src/protocol.rs b/crates/engineioxide-core/src/protocol.rs new file mode 100644 index 00000000..1a7474c6 --- /dev/null +++ b/crates/engineioxide-core/src/protocol.rs @@ -0,0 +1,97 @@ +use std::str::FromStr; + +/// The type of `transport` used to connect to the client/server. +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum TransportType { + /// Polling transport + Polling = 0x01, + /// Websocket transport + Websocket = 0x02, +} + +impl From for TransportType { + fn from(t: u8) -> Self { + match t { + 0x01 => TransportType::Polling, + 0x02 => TransportType::Websocket, + _ => panic!("unknown transport type"), + } + } +} + +/// Cannot determine the transport type to connect to the client/server. +#[derive(Debug, Copy, Clone)] +pub struct UnknownTransportError; +impl std::fmt::Display for UnknownTransportError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "unknown transport type") + } +} +impl std::error::Error for UnknownTransportError {} + +impl FromStr for TransportType { + type Err = UnknownTransportError; + + fn from_str(s: &str) -> Result { + match s { + "websocket" => Ok(TransportType::Websocket), + "polling" => Ok(TransportType::Polling), + _ => Err(UnknownTransportError), + } + } +} +impl From for &'static str { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling", + TransportType::Websocket => "websocket", + } + } +} +impl From for String { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling".into(), + TransportType::Websocket => "websocket".into(), + } + } +} + +#[derive(Debug)] +pub struct UnknownProtocolVersionError; +impl std::fmt::Display for UnknownProtocolVersionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "unknown protocol version") + } +} +impl std::error::Error for UnknownProtocolVersionError {} + +/// The engine.io protocol version +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ProtocolVersion { + /// The protocol version 3 + V3 = 3, + /// The protocol version 4 + V4 = 4, +} + +impl FromStr for ProtocolVersion { + type Err = UnknownProtocolVersionError; + + #[cfg(feature = "v3")] + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + "4" => Ok(ProtocolVersion::V4), + _ => Err(UnknownProtocolVersionError), + } + } + + #[cfg(not(feature = "v3"))] + fn from_str(s: &str) -> Result { + match s { + "4" => Ok(ProtocolVersion::V4), + _ => Err(UnknownProtocolVersionError), + } + } +} diff --git a/crates/engineioxide/Cargo.toml b/crates/engineioxide/Cargo.toml index 0490fa58..be2d3ab2 100644 --- a/crates/engineioxide/Cargo.toml +++ b/crates/engineioxide/Cargo.toml @@ -32,22 +32,15 @@ tokio = { workspace = true, features = ["rt", "time"] } tower-service.workspace = true tower-layer.workspace = true hyper.workspace = true -tokio-tungstenite.workspace = true +tokio-tungstenite = { workspace = true, features = ["handshake"] } http-body-util.workspace = true pin-project-lite.workspace = true smallvec.workspace = true hyper-util = { workspace = true, features = ["tokio"] } -base64 = "0.22" - # Tracing tracing = { workspace = true, optional = true } -# Engine.io V3 payload -itoa = { workspace = true, optional = true } -memchr = { version = "2.7", optional = true } -unicode-segmentation = { version = "1.12", optional = true } - [dev-dependencies] tokio = { workspace = true, features = ["macros", "parking_lot"] } tracing-subscriber.workspace = true @@ -56,8 +49,9 @@ criterion.workspace = true axum.workspace = true tokio-stream.workspace = true tokio-util.workspace = true + [features] -v3 = ["memchr", "unicode-segmentation", "itoa"] +v3 = ["engineioxide-core/v3"] tracing = ["dep:tracing"] __test_harness = [] diff --git a/crates/engineioxide/benches/packet_encode.rs b/crates/engineioxide/benches/packet_encode.rs index 0cfc8842..f1d8f483 100644 --- a/crates/engineioxide/benches/packet_encode.rs +++ b/crates/engineioxide/benches/packet_encode.rs @@ -1,15 +1,11 @@ use bytes::Bytes; use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; -use engineioxide::{OpenPacket, Packet, TransportType, config::EngineIoConfig, socket::Sid}; +use engineioxide::{OpenPacket, Packet}; fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("engineio_packet/encode"); group.bench_function("Encode packet open", |b| { - let packet = Packet::Open(OpenPacket::new( - black_box(TransportType::Polling), - black_box(Sid::ZERO), - &EngineIoConfig::default(), - )); + let packet = Packet::Open(OpenPacket::default()); b.iter_batched( || packet.clone(), TryInto::::try_into, diff --git a/crates/engineioxide/src/config.rs b/crates/engineioxide/src/config.rs index 3eeb1fbc..abb8dece 100644 --- a/crates/engineioxide/src/config.rs +++ b/crates/engineioxide/src/config.rs @@ -32,7 +32,7 @@ use std::{borrow::Cow, time::Duration}; -use crate::service::TransportType; +use engineioxide_core::TransportType; /// Configuration for the engine.io engine & transports #[derive(Debug, Clone)] diff --git a/crates/engineioxide/src/engine.rs b/crates/engineioxide/src/engine.rs index 25f2d573..38e2e19d 100644 --- a/crates/engineioxide/src/engine.rs +++ b/crates/engineioxide/src/engine.rs @@ -3,13 +3,12 @@ use std::{ sync::{Arc, RwLock}, }; -use engineioxide_core::Sid; +use engineioxide_core::{ProtocolVersion, Sid, TransportType}; use http::request::Parts; use crate::{ config::EngineIoConfig, handler::EngineIoHandler, - service::{ProtocolVersion, TransportType}, socket::{DisconnectReason, Socket}, }; diff --git a/crates/engineioxide/src/errors.rs b/crates/engineioxide/src/errors.rs index 90b976dc..38a6495e 100644 --- a/crates/engineioxide/src/errors.rs +++ b/crates/engineioxide/src/errors.rs @@ -3,48 +3,38 @@ use tokio::sync::mpsc; use tokio_tungstenite::tungstenite; use crate::body::ResponseBody; -use crate::packet::Packet; -use engineioxide_core::Sid; +use engineioxide_core::{Packet, Sid}; + +pub use engineioxide_core::PacketParseError; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("error decoding binary packet from polling request: {0:?}")] - Base64(#[from] base64::DecodeError), - #[error("error decoding packet: {0:?}")] - StrUtf8(#[from] std::str::Utf8Error), - #[error("io error: {0:?}")] - Io(#[from] std::io::Error), + #[error("error decoding packet from request: {0}")] + PacketParse(#[from] PacketParseError), #[error("bad packet received")] BadPacket(Packet), - #[error("ws transport error: {0:?}")] + #[error("ws transport error: {0}")] WsTransport(#[from] Box), - #[error("http error: {0:?}")] + #[error("http error: {0}")] Http(#[from] http::Error), - #[error("internal channel error: {0:?}")] + #[error("internal channel error: {0}")] SendChannel(#[from] mpsc::error::TrySendError), - #[error("internal channel error: {0:?}")] + #[error("internal channel error: {0}")] RecvChannel(#[from] mpsc::error::TryRecvError), #[error("heartbeat timeout")] HeartbeatTimeout, #[error("upgrade error")] Upgrade, - #[error("aborted connection")] - Aborted, - #[error("http error response: {0:?}")] - HttpErrorResponse(StatusCode), + #[error("multiple http polling error")] + MultipleHttpPolling, + #[error("invalid websocket Sec-WebSocket-Key http header")] + InvalidWebSocketKey, #[error("unknown session id")] UnknownSessionID(Sid), #[error("transport mismatch")] TransportMismatch, - #[error("payload too large")] - PayloadTooLarge, - - #[error("Invalid packet length")] - InvalidPacketLength, - #[error("Invalid packet type")] - InvalidPacketType(Option), } /// Convert an error into an http response @@ -60,18 +50,15 @@ impl From for Response> { .unwrap() }; match err { - Error::HttpErrorResponse(code) => Response::builder() - .status(code) + Error::PacketParse(PacketParseError::PayloadTooLarge { .. }) => Response::builder() + .status(413) .body(ResponseBody::empty_response()) .unwrap(), - Error::BadPacket(_) | Error::InvalidPacketLength | Error::InvalidPacketType(_) => { - Response::builder() - .status(400) - .body(ResponseBody::empty_response()) - .unwrap() - } - Error::PayloadTooLarge => Response::builder() - .status(413) + Error::BadPacket(_) + | Error::PacketParse(_) + | Error::MultipleHttpPolling + | Error::InvalidWebSocketKey => Response::builder() + .status(400) .body(ResponseBody::empty_response()) .unwrap(), diff --git a/crates/engineioxide/src/lib.rs b/crates/engineioxide/src/lib.rs index 3bf3469a..b0c9f0ef 100644 --- a/crates/engineioxide/src/lib.rs +++ b/crates/engineioxide/src/lib.rs @@ -31,13 +31,12 @@ )] #![doc = include_str!("../Readme.md")] -pub use engineioxide_core::Str; -pub use service::{ProtocolVersion, TransportType}; +pub use engineioxide_core::{ProtocolVersion, Str, TransportType}; pub use socket::{DisconnectReason, Socket}; #[doc(hidden)] #[cfg(feature = "__test_harness")] -pub use packet::*; +pub use engineioxide_core::{OpenPacket, Packet, PacketParseError}; pub mod config; pub mod handler; @@ -54,6 +53,4 @@ pub mod sid { mod body; mod engine; mod errors; -mod packet; -mod peekable; mod transport; diff --git a/crates/engineioxide/src/peekable.rs b/crates/engineioxide/src/peekable.rs deleted file mode 100644 index bc8ddb6b..00000000 --- a/crates/engineioxide/src/peekable.rs +++ /dev/null @@ -1,79 +0,0 @@ -use tokio::sync::mpsc::{Receiver, error::TryRecvError}; - -/// Peekable receiver for polling transport -/// It is a thin wrapper around a [`Receiver`](tokio::sync::mpsc::Receiver) that allows to peek the next packet without consuming it -/// -/// Its main goal is to be able to peek the next packet without consuming it to calculate the -/// packet length when using polling transport to check if it fits according to the max_payload setting -#[derive(Debug)] -pub struct PeekableReceiver { - rx: Receiver, - next: Option, -} -impl PeekableReceiver { - pub fn new(rx: Receiver) -> Self { - Self { rx, next: None } - } - pub fn peek(&mut self) -> Option<&T> { - if self.next.is_none() { - self.next = self.rx.try_recv().ok(); - } - self.next.as_ref() - } - pub async fn recv(&mut self) -> Option { - if self.next.is_none() { - self.rx.recv().await - } else { - self.next.take() - } - } - pub fn try_recv(&mut self) -> Result { - if self.next.is_none() { - self.rx.try_recv() - } else { - Ok(self.next.take().unwrap()) - } - } - - pub fn close(&mut self) { - self.rx.close() - } -} - -#[cfg(test)] -mod tests { - use tokio::sync::Mutex; - - #[tokio::test] - async fn peek() { - use super::PeekableReceiver; - use crate::packet::Packet; - use tokio::sync::mpsc::channel; - - let (tx, rx) = channel(1); - let rx = Mutex::new(PeekableReceiver::new(rx)); - let mut rx = rx.lock().await; - - assert!(rx.peek().is_none()); - - tx.send(Packet::Ping).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Ping)); - assert_eq!(rx.recv().await, Some(Packet::Ping)); - assert!(rx.peek().is_none()); - - tx.send(Packet::Pong).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Pong)); - assert_eq!(rx.recv().await, Some(Packet::Pong)); - assert!(rx.peek().is_none()); - - tx.send(Packet::Close).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Close)); - assert_eq!(rx.recv().await, Some(Packet::Close)); - assert!(rx.peek().is_none()); - - tx.send(Packet::Close).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Close)); - assert_eq!(rx.recv().await, Some(Packet::Close)); - assert!(rx.peek().is_none()); - } -} diff --git a/crates/engineioxide/src/service/mod.rs b/crates/engineioxide/src/service/mod.rs index 4bce8fe6..6b9d1c13 100644 --- a/crates/engineioxide/src/service/mod.rs +++ b/crates/engineioxide/src/service/mod.rs @@ -33,6 +33,8 @@ use std::{ }; use bytes::Bytes; +#[cfg(feature = "__test_harness")] +use engineioxide_core::ProtocolVersion; use futures_util::future::{self, Ready}; use http::{Request, Response}; use http_body::Body; @@ -47,7 +49,6 @@ use crate::{ mod futures; mod parser; -pub use self::parser::{ProtocolVersion, TransportType}; use self::{futures::ResponseFuture, parser::dispatch_req}; /// A `Service` that handles engine.io requests as a middleware. diff --git a/crates/engineioxide/src/service/parser.rs b/crates/engineioxide/src/service/parser.rs index e911d3d5..c015f5e2 100644 --- a/crates/engineioxide/src/service/parser.rs +++ b/crates/engineioxide/src/service/parser.rs @@ -1,10 +1,10 @@ //! A Parser module to parse any `EngineIo` query -use std::{future::Future, str::FromStr, sync::Arc}; +use std::{future::Future, sync::Arc}; use http::{Method, Request, Response}; -use engineioxide_core::Sid; +use engineioxide_core::{ProtocolVersion, Sid, TransportType}; use crate::{ body::ResponseBody, @@ -90,6 +90,11 @@ pub enum ParseError { #[error("unsupported protocol version")] UnsupportedProtocolVersion, } +impl From for ParseError { + fn from(_: engineioxide_core::UnknownTransportError) -> Self { + ParseError::UnknownTransport + } +} /// Convert an error into an http response /// If it is a known error, return the appropriate http status code @@ -117,83 +122,6 @@ impl From for Response> { } } -/// The engine.io protocol version -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum ProtocolVersion { - /// The protocol version 3 - V3 = 3, - /// The protocol version 4 - V4 = 4, -} - -impl FromStr for ProtocolVersion { - type Err = ParseError; - - #[cfg(feature = "v3")] - fn from_str(s: &str) -> Result { - match s { - "3" => Ok(ProtocolVersion::V3), - "4" => Ok(ProtocolVersion::V4), - _ => Err(ParseError::UnsupportedProtocolVersion), - } - } - - #[cfg(not(feature = "v3"))] - fn from_str(s: &str) -> Result { - match s { - "4" => Ok(ProtocolVersion::V4), - _ => Err(ParseError::UnsupportedProtocolVersion), - } - } -} - -/// The type of `transport` used by the client. -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum TransportType { - /// Polling transport - Polling = 0x01, - /// Websocket transport - Websocket = 0x02, -} - -impl From for TransportType { - fn from(t: u8) -> Self { - match t { - 0x01 => TransportType::Polling, - 0x02 => TransportType::Websocket, - _ => panic!("unknown transport type"), - } - } -} - -impl FromStr for TransportType { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - match s { - "websocket" => Ok(TransportType::Websocket), - "polling" => Ok(TransportType::Polling), - _ => Err(ParseError::UnknownTransport), - } - } -} -impl From for &'static str { - fn from(t: TransportType) -> Self { - match t { - TransportType::Polling => "polling", - TransportType::Websocket => "websocket", - } - } -} -impl From for String { - fn from(t: TransportType) -> Self { - match t { - TransportType::Polling => "polling".into(), - TransportType::Websocket => "websocket".into(), - } - } -} - /// The request information extracted from the request URI. #[derive(Debug)] pub struct RequestInfo { @@ -220,8 +148,8 @@ impl RequestInfo { .split('&') .find(|s| s.starts_with("EIO=")) .and_then(|s| s.split('=').nth(1)) - .ok_or(UnsupportedProtocolVersion) - .and_then(|t| t.parse())?; + .and_then(|t| t.parse().ok()) + .ok_or(UnsupportedProtocolVersion)?; let sid = query .split('&') @@ -233,8 +161,8 @@ impl RequestInfo { .split('&') .find(|s| s.starts_with("transport=")) .and_then(|s| s.split('=').nth(1)) - .ok_or(UnknownTransport) - .and_then(|t| t.parse())?; + .and_then(|t| t.parse().ok()) + .ok_or(UnknownTransport)?; if !config.allowed_transport(transport) { return Err(TransportMismatch); diff --git a/crates/engineioxide/src/socket.rs b/crates/engineioxide/src/socket.rs index 56776426..846233d6 100644 --- a/crates/engineioxide/src/socket.rs +++ b/crates/engineioxide/src/socket.rs @@ -63,15 +63,9 @@ use std::{ time::Duration, }; -use crate::{ - config::EngineIoConfig, - errors::Error, - packet::Packet, - peekable::PeekableReceiver, - service::{ProtocolVersion, TransportType}, -}; +use crate::{config::EngineIoConfig, errors::Error}; use bytes::Bytes; -use engineioxide_core::Str; +use engineioxide_core::{Packet, PacketBuf, ProtocolVersion, Str, TransportType}; use http::request::Parts; use smallvec::{SmallVec, smallvec}; use tokio::{ @@ -110,9 +104,8 @@ impl From<&Error> for Option { fn from(err: &Error) -> Self { use Error::*; match err { - WsTransport(_) | Io(_) => Some(DisconnectReason::TransportError), - BadPacket(_) | Base64(_) | StrUtf8(_) | PayloadTooLarge | InvalidPacketLength - | InvalidPacketType(_) => Some(DisconnectReason::PacketParsingError), + WsTransport(_) => Some(DisconnectReason::TransportError), + BadPacket(_) | PacketParse(_) => Some(DisconnectReason::PacketParsingError), HeartbeatTimeout => Some(DisconnectReason::HeartbeatTimeout), _ => None, } @@ -161,13 +154,6 @@ impl Permit<'_> { } } -/// Buffered packets to send to the client. -/// It is used to ensure atomicity when sending multiple packets to the client. -/// -/// The [`PacketBuf`] stack size will impact the dynamically allocated buffer -/// of the internal mpsc channel. -pub(crate) type PacketBuf = SmallVec<[Packet; 2]>; - /// A [`Socket`] represents a client connection to the server. /// It is agnostic to the [`TransportType`]. /// @@ -209,7 +195,7 @@ where /// Because with polling transport, if the client is not currently polling then the encoder will never be able to close the channel /// /// The channel is made of a [`SmallVec`] of [`Packet`]s so that adjacent packets can be sent atomically. - pub(crate) internal_rx: Mutex>, + pub(crate) internal_rx: Mutex>, /// Channel to send [PacketBuf] to the internal connection internal_tx: mpsc::Sender, @@ -257,7 +243,7 @@ where transport: AtomicU8::new(transport as u8), upgrading: AtomicBool::new(false), - internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)), + internal_rx: Mutex::new(internal_rx), internal_tx, heartbeat_rx: Mutex::new(heartbeat_rx), @@ -556,7 +542,7 @@ where transport: AtomicU8::new(TransportType::Websocket as u8), upgrading: AtomicBool::new(false), - internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)), + internal_rx: Mutex::new(internal_rx), internal_tx, heartbeat_rx: Mutex::new(heartbeat_rx), diff --git a/crates/engineioxide/src/transport/mod.rs b/crates/engineioxide/src/transport/mod.rs index 94d843a5..bbdfc589 100644 --- a/crates/engineioxide/src/transport/mod.rs +++ b/crates/engineioxide/src/transport/mod.rs @@ -1,4 +1,23 @@ //! All transports modules available in engineioxide +use engineioxide_core::{OpenPacket, Sid, TransportType}; + +use crate::config::EngineIoConfig; + pub mod polling; pub mod ws; + +fn make_open_packet(transport: TransportType, id: Sid, config: &EngineIoConfig) -> OpenPacket { + let upgrades = if transport == TransportType::Polling { + vec!["websocket".to_string()] + } else { + vec![] + }; + OpenPacket { + sid: id, + upgrades, + ping_timeout: config.ping_timeout.as_millis() as u64, + ping_interval: config.ping_interval.as_millis() as u64, + max_payload: config.max_payload, + } +} diff --git a/crates/engineioxide/src/transport/polling/mod.rs b/crates/engineioxide/src/transport/polling.rs similarity index 77% rename from crates/engineioxide/src/transport/polling/mod.rs rename to crates/engineioxide/src/transport/polling.rs index e5207d35..c759ca0a 100644 --- a/crates/engineioxide/src/transport/polling/mod.rs +++ b/crates/engineioxide/src/transport/polling.rs @@ -2,26 +2,19 @@ use std::sync::Arc; use bytes::Bytes; +use engineioxide_core::payload::{self, Payload}; use futures_util::StreamExt; -use http::{Request, Response, StatusCode}; +use http::{Request, Response, StatusCode, header::CONTENT_TYPE}; use http_body::Body; use http_body_util::Full; -use engineioxide_core::Sid; +use engineioxide_core::{Packet, ProtocolVersion, Sid, TransportType}; use crate::{ - DisconnectReason, - body::ResponseBody, - engine::EngineIo, - errors::Error, - handler::EngineIoHandler, - packet::{OpenPacket, Packet}, - service::{ProtocolVersion, TransportType}, - transport::polling::payload::Payload, + DisconnectReason, body::ResponseBody, engine::EngineIo, errors::Error, + handler::EngineIoHandler, transport::make_open_packet, }; -mod payload; - /// Create a response for http request fn http_response( code: StatusCode, @@ -62,7 +55,7 @@ where supports_binary, ); - let packet = OpenPacket::new(TransportType::Polling, socket.id, &engine.config); + let packet = make_open_packet(TransportType::Polling, socket.id, &engine.config); socket.spawn_heartbeat(engine.config.ping_interval, engine.config.ping_timeout); @@ -117,24 +110,29 @@ where // If the socket is already locked, it means that the socket is being used by another request // In case of multiple http polling, session should be closed - let rx = match socket.internal_rx.try_lock() { + let mut rx = match socket.internal_rx.try_lock() { Ok(s) => s, Err(_) => { socket.close(DisconnectReason::MultipleHttpPollingError); - return Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST)); + return Err(Error::MultipleHttpPolling); } }; + //TODO: handle closing channel packet better than in the encoding process. + #[cfg(feature = "tracing")] tracing::debug!("[sid={sid}] polling request"); let max_payload = engine.config.max_payload; + let rx = rx_stream::ReceiverStream::new(&mut rx); + #[cfg(feature = "v3")] let Payload { data, has_binary } = - payload::encoder(rx, protocol, socket.supports_binary, max_payload).await?; + payload::encoder(rx, protocol, socket.supports_binary, max_payload).await; + #[cfg(not(feature = "v3"))] - let Payload { data, has_binary } = payload::encoder(rx, protocol, max_payload).await?; + let Payload { data, has_binary } = payload::encoder(rx, protocol, max_payload).await; #[cfg(feature = "tracing")] tracing::debug!("[sid={sid}] sending data: {:?}", data); @@ -148,7 +146,7 @@ pub async fn post_req( engine: Arc>, protocol: ProtocolVersion, sid: Sid, - body: Request, + req: Request, ) -> Result>, Error> where H: EngineIoHandler, @@ -162,7 +160,9 @@ where return Err(Error::TransportMismatch); } - let packets = payload::decoder(body, protocol, engine.config.max_payload); + let (parts, body) = req.into_parts(); + let content_type = parts.headers.get(CONTENT_TYPE); + let packets = payload::decoder(body, content_type, protocol, engine.config.max_payload); futures_util::pin_mut!(packets); while let Some(packet) = packets.next().await { @@ -195,9 +195,39 @@ where #[cfg(feature = "tracing")] tracing::debug!("[sid={sid}] error parsing packet: {:?}", e); engine.close_session(sid, DisconnectReason::PacketParsingError); - return Err(e); + return Err(e.into()); } }?; } Ok(http_response(StatusCode::OK, "ok", false)?) } + +mod rx_stream { + use std::{ + pin::Pin, + task::{Context, Poll}, + }; + + use futures_core::Stream; + use tokio::sync::mpsc::Receiver; + + /// [`ReceiverStream`] is a stream that wraps a tokio::sync::mpsc::Receiver by reference. + /// Allowing to use it as a stream even if it is behind a mutex. + pub struct ReceiverStream<'a, T> { + inner: &'a mut Receiver, + } + + impl<'a, T> ReceiverStream<'a, T> { + pub fn new(inner: &'a mut Receiver) -> Self { + Self { inner } + } + } + + impl<'a, T> Stream for ReceiverStream<'a, T> { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } + } +} diff --git a/crates/engineioxide/src/transport/ws.rs b/crates/engineioxide/src/transport/ws.rs index 29b6c7ce..38156037 100644 --- a/crates/engineioxide/src/transport/ws.rs +++ b/crates/engineioxide/src/transport/ws.rs @@ -18,24 +18,17 @@ use tokio::{ use tokio_tungstenite::{ WebSocketStream, tungstenite::{ - Message, + self, Message, handshake::derive_accept_key, protocol::{Role, WebSocketConfig}, }, }; -use engineioxide_core::Sid; +use engineioxide_core::{Packet, ProtocolVersion, Sid, Str, TransportType}; use crate::{ - DisconnectReason, Socket, - body::ResponseBody, - config::EngineIoConfig, - engine::EngineIo, - errors::Error, - handler::EngineIoHandler, - packet::{OpenPacket, Packet}, - service::ProtocolVersion, - service::TransportType, + DisconnectReason, Socket, body::ResponseBody, config::EngineIoConfig, engine::EngineIo, + errors::Error, handler::EngineIoHandler, transport::make_open_packet, }; /// Create a response for websocket upgrade @@ -69,7 +62,7 @@ pub fn new_req( let ws_key = parts .headers .get("Sec-WebSocket-Key") - .ok_or(Error::HttpErrorResponse(StatusCode::BAD_REQUEST))? + .ok_or(Error::InvalidWebSocketKey)? .clone(); tokio::spawn(async move { @@ -171,7 +164,7 @@ where { while let Some(msg) = rx.try_next().await? { match msg { - Message::Text(msg) => match Packet::try_from(msg)? { + Message::Text(msg) => match Packet::try_from(ws_bytes_to_str(msg))? { Packet::Close => { #[cfg(feature = "tracing")] tracing::debug!("[sid={}] closing session", socket.id); @@ -283,7 +276,8 @@ async fn init_handshake( where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let packet = Packet::Open(OpenPacket::new(TransportType::Websocket, sid, config)); + let packet = Packet::Open(make_open_packet(TransportType::Websocket, sid, config)); + let packet: String = packet.into(); ws.send(Message::Text(packet.into())).await?; Ok(()) } @@ -338,13 +332,13 @@ where Some(Ok(Message::Text(d))) => d, _ => Err(Error::Upgrade)?, }; - match Packet::try_from(msg)? { + match Packet::try_from(ws_bytes_to_str(msg))? { Packet::PingUpgrade => { #[cfg(feature = "tracing")] tracing::debug!("received first ping upgrade"); - // Respond with a PongUpgrade packet - ws.send(Message::Text(Packet::PongUpgrade.into())).await?; + ws.send(Message::Text(String::from(Packet::PongUpgrade).into())) + .await?; } p => Err(Error::BadPacket(p))?, }; @@ -363,7 +357,7 @@ where Err(Error::Upgrade)? } }; - match Packet::try_from(msg)? { + match Packet::try_from(ws_bytes_to_str(msg))? { Packet::Upgrade => { #[cfg(feature = "tracing")] tracing::debug!("ws upgraded successfully") @@ -374,3 +368,9 @@ where socket.upgrade_to_websocket(); Ok(()) } + +fn ws_bytes_to_str(bytes: tungstenite::Utf8Bytes) -> Str { + // SAFETY: We are converting a valid UTF-8 byte slice + // to a string without checking its validity. + unsafe { Str::from_bytes_unchecked(bytes.into()) } +}