diff --git a/humphrey-ws/src/async_app.rs b/humphrey-ws/src/async_app.rs index e831642..c55221a 100644 --- a/humphrey-ws/src/async_app.rs +++ b/humphrey-ws/src/async_app.rs @@ -28,18 +28,15 @@ where /// Represents an asynchronous WebSocket stream. pub struct AsyncStream { addr: SocketAddr, - message_queue: Vec, + message_queue: Arc>>, connected: bool, } -pub trait EventHandler: Fn(&mut AsyncStream, Arc) + Send + Sync + 'static {} -impl EventHandler for T where T: Fn(&mut AsyncStream, Arc) + Send + Sync + 'static {} +pub trait EventHandler: Fn(AsyncStream, Arc) + Send + Sync + 'static {} +impl EventHandler for T where T: Fn(AsyncStream, Arc) + Send + Sync + 'static {} -pub trait MessageHandler: Fn(&mut AsyncStream, Message, Arc) + Send + Sync + 'static {} -impl MessageHandler for T where - T: Fn(&mut AsyncStream, Message, Arc) + Send + Sync + 'static -{ -} +pub trait MessageHandler: Fn(AsyncStream, Message, Arc) + Send + Sync + 'static {} +impl MessageHandler for T where T: Fn(AsyncStream, Message, Arc) + Send + Sync + 'static {} impl AsyncWebsocketApp where @@ -90,19 +87,21 @@ where match stream.recv_nonblocking() { Restion::Ok(message) => { - let mut async_stream = AsyncStream::new(addr); + let messages = Arc::new(Mutex::new(Vec::new())); + let async_stream = AsyncStream::new(addr, messages.clone()); if let Some(handler) = &self.on_message { - handler(&mut async_stream, message, self.state.clone()); + handler(async_stream, message, self.state.clone()); } - for message in async_stream.into_inner() { + for message in messages.lock().unwrap().drain(..) { stream.send(message).unwrap(); } } Restion::Err(_) => { - let mut async_stream = AsyncStream::disconnected(addr); + let messages = Arc::new(Mutex::new(Vec::new())); + let async_stream = AsyncStream::disconnected(addr, messages.clone()); if let Some(handler) = &self.on_disconnect { - handler(&mut async_stream, self.state.clone()) + handler(async_stream, self.state.clone()); } self.streams.remove(&addr); @@ -119,12 +118,13 @@ where .try_iter() .filter_map(|s| s.peer_addr().map(|a| (a, s)).ok()) { - let mut async_stream = AsyncStream::new(addr); + let messages = Arc::new(Mutex::new(Vec::new())); + let async_stream = AsyncStream::new(addr, messages.clone()); if let Some(handler) = &self.on_connect { - handler(&mut async_stream, self.state.clone()); + handler(async_stream, self.state.clone()); } - for message in async_stream.into_inner() { + for message in messages.lock().unwrap().drain(..) { stream.send(message).unwrap(); } @@ -135,18 +135,18 @@ where } impl AsyncStream { - pub fn new(addr: SocketAddr) -> Self { + pub fn new(addr: SocketAddr, messages: Arc>>) -> Self { Self { addr, - message_queue: vec![], + message_queue: messages, connected: true, } } - pub fn disconnected(addr: SocketAddr) -> Self { + pub fn disconnected(addr: SocketAddr, messages: Arc>>) -> Self { Self { addr, - message_queue: vec![], + message_queue: messages, connected: false, } } @@ -155,12 +155,8 @@ impl AsyncStream { self.addr } - pub fn send(&mut self, message: Message) { + pub fn send(&self, message: Message) { assert!(self.connected); - self.message_queue.push(message); - } - - fn into_inner(self) -> Vec { - self.message_queue + self.message_queue.lock().unwrap().push(message); } }