Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/engineioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,10 @@ where
s
}

pub async fn close_internal_rx(self: Arc<Self>) {
self.internal_rx.lock().await.close();
}

/// Create a dummy socket for testing purpose with a
/// receiver to get the packets sent to the client
pub fn new_dummy_piped(
Expand Down
108 changes: 108 additions & 0 deletions crates/socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,67 @@ impl<A: Adapter> Client<A> {
}
}

//just add latency to connect for test purpose
#[cfg(feature = "__test_harness")]
#[allow(unused)]
fn sock_connect_for_test(
self: &Arc<Self>,
auth: Option<Value>,
ns_path: &str,
esocket: &Arc<engineioxide::Socket<SocketData<A>>>,
) {
#[cfg(feature = "tracing")]
tracing::debug!("auth: {:?}", auth);
let protocol: ProtocolVersion = esocket.protocol.into();
let connect = async move |ns: Arc<Namespace<A>>, esocket: Arc<EIoSocket<SocketData<A>>>| {
// add latency to connect
tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() {
// cancel the connect timeout task for v5
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).ok();
}
}
};

if let Some(ns) = self.get_ns(ns_path) {
tokio::spawn(connect(ns, esocket.clone()));
} else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(ns_path) {
let path = Str::copy_from_slice(ns_path);
let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config);
let this = self.clone();
let esocket = esocket.clone();
let adapter = ns.adapter.clone();
let on_success = move || {
this.nsps.write().unwrap().insert(path, ns.clone());
tokio::spawn(connect(ns, esocket));
};
// We "ask" the adapter implementation to manage the init response itself
socketioxide_core::adapter::Spawnable::spawn(adapter.init(on_success));
} else if protocol == ProtocolVersion::V4 && ns_path == "/" {
#[cfg(feature = "tracing")]
tracing::error!(
"the root namespace \"/\" must be defined before any connection for protocol V4 (legacy)!"
);
esocket.close(EIoDisconnectReason::TransportClose);
} else {
let path = Str::copy_from_slice(ns_path);
let packet = self
.parser()
.encode(Packet::connect_error(path, "Invalid namespace"));
let _ = match packet {
Value::Str(p, _) => esocket.emit(p).map_err(|_e| {
#[cfg(feature = "tracing")]
tracing::error!("error while sending invalid namespace packet: {}", _e);
}),
Value::Bytes(p) => esocket.emit_binary(p).map_err(|_e| {
#[cfg(feature = "tracing")]
tracing::error!("error while sending invalid namespace packet: {}", _e);
}),
};
}
}

/// Propagate a packet to its target namespace
fn sock_propagate_packet(&self, packet: Packet, sid: Sid) -> Result<(), Error> {
if let Some(ns) = self.get_ns(&packet.ns) {
Expand Down Expand Up @@ -426,6 +487,7 @@ mod test {
use tokio::sync::mpsc;

use crate::adapter::LocalAdapter;
use std::time::Duration;
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(50);

fn create_client() -> Arc<super::Client<LocalAdapter>> {
Expand Down Expand Up @@ -490,4 +552,50 @@ mod test {
.await
.unwrap_err();
}
#[derive(Debug, Clone)]
struct MockConnectTimeoutHandler(Arc<super::Client<LocalAdapter>>);

impl EngineIoHandler for MockConnectTimeoutHandler {
type Data = <Client<LocalAdapter> as engineioxide::handler::EngineIoHandler>::Data;

fn on_connect(self: Arc<Self>, socket: Arc<EIoSocket<Self::Data>>) {
socket.data.io.set(SocketIo::from(self.0.clone())).ok();

self.0.sock_connect_for_test(None, "/", &socket);
}

fn on_disconnect(&self, socket: Arc<EIoSocket<Self::Data>>, reason: EIoDisconnectReason) {
self.0.clone().on_disconnect(socket, reason);
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<EIoSocket<Self::Data>>) {
self.0.clone().on_message(msg, socket);
}

fn on_binary(self: &Arc<Self>, data: Bytes, socket: Arc<EIoSocket<Self::Data>>) {
self.0.clone().on_binary(data, socket);
}
}

#[tokio::test]
async fn should_not_reserve_socket_if_connect_time_out() {
let client = create_client();
let client = Arc::new(MockConnectTimeoutHandler(client));
let sid = Sid::new();
let sock = EIoSocket::new_dummy(sid, Box::new(move |_, _| {}));
// connect to ns but spawned fn `connect` will be stuck
client.clone().on_connect(sock.clone());
//spawned fn `connect` is stuck ,so client don't keep heartbeat
client
.clone()
.on_disconnect(sock.clone(), EIoDisconnectReason::HeartbeatTimeout);
//equal to engineio.close_session(_,_)
sock.close_internal_rx().await;
//wait for spawned fn `connect(crates/socketioxide/src/client.rs:129)` finish
tokio::time::sleep(Duration::from_secs(4)).await;
let guard = client.0.nsps.read().unwrap();
let ns = guard.get("/").unwrap();
// ns should not reserve socket
assert!(ns.get_socket(sid).is_err());
}
}
9 changes: 9 additions & 0 deletions crates/socketioxide/src/ns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ impl<A: Adapter> Namespace<A> {
#[cfg(feature = "tracing")]
tracing::debug!("error sending connect packet: {:?}, closing conn", _e);
esocket.close(engineioxide::DisconnectReason::PacketParsingError);
// also remove sid inserted before
self.sockets
.write()
.map_err(|_| {
#[cfg(feature = "tracing")]
tracing::debug!("get lock err for {sid}");
ConnectFail
})?
.remove(&sid);
return Err(ConnectFail);
}

Expand Down
Loading