diff --git a/crates/engineioxide/src/socket.rs b/crates/engineioxide/src/socket.rs index 222e3d8f..33be44f3 100644 --- a/crates/engineioxide/src/socket.rs +++ b/crates/engineioxide/src/socket.rs @@ -539,6 +539,10 @@ where s } + pub async fn close_internal_rx(self: Arc) { + self.internal_rx.lock().await.close(); + } + /// Create a dummy socket for testing purpose with a /// receiver to get the packets sent to the client pub fn new_dummy_piped( diff --git a/crates/socketioxide/src/client.rs b/crates/socketioxide/src/client.rs index e4ac7bfd..674bf771 100644 --- a/crates/socketioxide/src/client.rs +++ b/crates/socketioxide/src/client.rs @@ -113,6 +113,67 @@ impl Client { } } + //just add latency to connect for test purpose + #[cfg(feature = "__test_harness")] + #[allow(unused)] + fn sock_connect_for_test( + self: &Arc, + auth: Option, + ns_path: &str, + esocket: &Arc>>, + ) { + #[cfg(feature = "tracing")] + tracing::debug!("auth: {:?}", auth); + let protocol: ProtocolVersion = esocket.protocol.into(); + let connect = async move |ns: Arc>, esocket: Arc>>| { + // add latency to connect + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() { + // cancel the connect timeout task for v5 + if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() { + tx.send(()).ok(); + } + } + }; + + if let Some(ns) = self.get_ns(ns_path) { + tokio::spawn(connect(ns, esocket.clone())); + } else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(ns_path) { + let path = Str::copy_from_slice(ns_path); + let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config); + let this = self.clone(); + let esocket = esocket.clone(); + let adapter = ns.adapter.clone(); + let on_success = move || { + this.nsps.write().unwrap().insert(path, ns.clone()); + tokio::spawn(connect(ns, esocket)); + }; + // We "ask" the adapter implementation to manage the init response itself + socketioxide_core::adapter::Spawnable::spawn(adapter.init(on_success)); + } else if protocol == ProtocolVersion::V4 && ns_path == "/" { + #[cfg(feature = "tracing")] + tracing::error!( + "the root namespace \"/\" must be defined before any connection for protocol V4 (legacy)!" + ); + esocket.close(EIoDisconnectReason::TransportClose); + } else { + let path = Str::copy_from_slice(ns_path); + let packet = self + .parser() + .encode(Packet::connect_error(path, "Invalid namespace")); + let _ = match packet { + Value::Str(p, _) => esocket.emit(p).map_err(|_e| { + #[cfg(feature = "tracing")] + tracing::error!("error while sending invalid namespace packet: {}", _e); + }), + Value::Bytes(p) => esocket.emit_binary(p).map_err(|_e| { + #[cfg(feature = "tracing")] + tracing::error!("error while sending invalid namespace packet: {}", _e); + }), + }; + } + } + /// Propagate a packet to its target namespace fn sock_propagate_packet(&self, packet: Packet, sid: Sid) -> Result<(), Error> { if let Some(ns) = self.get_ns(&packet.ns) { @@ -426,6 +487,7 @@ mod test { use tokio::sync::mpsc; use crate::adapter::LocalAdapter; + use std::time::Duration; const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(50); fn create_client() -> Arc> { @@ -490,4 +552,50 @@ mod test { .await .unwrap_err(); } + #[derive(Debug, Clone)] + struct MockConnectTimeoutHandler(Arc>); + + impl EngineIoHandler for MockConnectTimeoutHandler { + type Data = as engineioxide::handler::EngineIoHandler>::Data; + + fn on_connect(self: Arc, socket: Arc>) { + socket.data.io.set(SocketIo::from(self.0.clone())).ok(); + + self.0.sock_connect_for_test(None, "/", &socket); + } + + fn on_disconnect(&self, socket: Arc>, reason: EIoDisconnectReason) { + self.0.clone().on_disconnect(socket, reason); + } + + fn on_message(self: &Arc, msg: Str, socket: Arc>) { + self.0.clone().on_message(msg, socket); + } + + fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { + self.0.clone().on_binary(data, socket); + } + } + + #[tokio::test] + async fn should_not_reserve_socket_if_connect_time_out() { + let client = create_client(); + let client = Arc::new(MockConnectTimeoutHandler(client)); + let sid = Sid::new(); + let sock = EIoSocket::new_dummy(sid, Box::new(move |_, _| {})); + // connect to ns but spawned fn `connect` will be stuck + client.clone().on_connect(sock.clone()); + //spawned fn `connect` is stuck ,so client don't keep heartbeat + client + .clone() + .on_disconnect(sock.clone(), EIoDisconnectReason::HeartbeatTimeout); + //equal to engineio.close_session(_,_) + sock.close_internal_rx().await; + //wait for spawned fn `connect(crates/socketioxide/src/client.rs:129)` finish + tokio::time::sleep(Duration::from_secs(4)).await; + let guard = client.0.nsps.read().unwrap(); + let ns = guard.get("/").unwrap(); + // ns should not reserve socket + assert!(ns.get_socket(sid).is_err()); + } } diff --git a/crates/socketioxide/src/ns.rs b/crates/socketioxide/src/ns.rs index 51365cc6..221a7d95 100644 --- a/crates/socketioxide/src/ns.rs +++ b/crates/socketioxide/src/ns.rs @@ -134,6 +134,15 @@ impl Namespace { #[cfg(feature = "tracing")] tracing::debug!("error sending connect packet: {:?}, closing conn", _e); esocket.close(engineioxide::DisconnectReason::PacketParsingError); + // also remove sid inserted before + self.sockets + .write() + .map_err(|_| { + #[cfg(feature = "tracing")] + tracing::debug!("get lock err for {sid}"); + ConnectFail + })? + .remove(&sid); return Err(ConnectFail); }