Skip to content

Commit

Permalink
Temporal fix for poll leak in tungstenite
Browse files Browse the repository at this point in the history
  • Loading branch information
lemunozm committed May 11, 2021
1 parent 288fa8a commit f033988
Showing 1 changed file with 106 additions and 42 deletions.
148 changes: 106 additions & 42 deletions src/adapters/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tungstenite::error::{Error};

use url::Url;

use std::sync::{Mutex};
use std::sync::{Mutex, Arc};
use std::net::{SocketAddr};
use std::io::{self, ErrorKind};
use std::ops::{DerefMut};
Expand All @@ -36,14 +36,15 @@ impl Adapter for WsAdapter {
}

enum PendingHandshake {
Connect(Url, TcpStream),
Client(MidHandshake<ClientHandshake<TcpStream>>),
Server(MidHandshake<ServerHandshake<TcpStream, NoCallback>>),
Connect(Url, ArcTcpStream),
Client(MidHandshake<ClientHandshake<ArcTcpStream>>),
Server(MidHandshake<ServerHandshake<ArcTcpStream, NoCallback>>),
}

enum RemoteState {
WebSocket(WebSocket<TcpStream>),
WebSocket(WebSocket<ArcTcpStream>),
Handshake(Option<PendingHandshake>),
Error(ArcTcpStream),
}

pub(crate) struct RemoteResource {
Expand All @@ -53,13 +54,20 @@ pub(crate) struct RemoteResource {
impl Resource for RemoteResource {
fn source(&mut self) -> &mut dyn Source {
match self.state.get_mut().unwrap() {
RemoteState::WebSocket(web_socket) => web_socket.get_mut(),
RemoteState::WebSocket(web_socket) => {
Arc::get_mut(&mut web_socket.get_mut().0).unwrap()
}
RemoteState::Handshake(Some(handshake)) => match handshake {
PendingHandshake::Connect(_, stream) => stream,
PendingHandshake::Client(mid_handshake) => mid_handshake.get_mut().get_mut(),
PendingHandshake::Server(mid_handshake) => mid_handshake.get_mut().get_mut(),
PendingHandshake::Connect(_, stream) => Arc::get_mut(&mut stream.0).unwrap(),
PendingHandshake::Client(handshake) => {
Arc::get_mut(&mut handshake.get_mut().get_mut().0).unwrap()
}
PendingHandshake::Server(handshake) => {
Arc::get_mut(&mut handshake.get_mut().get_mut().0).unwrap()
}
},
RemoteState::Handshake(None) => unreachable!(),
RemoteState::Error(stream) => Arc::get_mut(&mut stream.0).unwrap(),
}
}
}
Expand Down Expand Up @@ -89,7 +97,8 @@ impl Remote for RemoteResource {
Ok(ConnectionInfo {
remote: RemoteResource {
state: Mutex::new(RemoteState::Handshake(Some(PendingHandshake::Connect(
url, stream,
url,
stream.into(),
)))),
},
local_addr,
Expand All @@ -112,7 +121,7 @@ impl Remote for RemoteResource {
// Seems like windows consume the `WouldBlock` notification
// at peek() when it happens, and the poll never wakes it again.
#[cfg(not(target_os = "windows"))]
let _peek_result = web_socket.get_ref().peek(&mut [0; 0]);
let _peek_result = web_socket.get_ref().0.peek(&mut [0; 0]);

// We can not call process_data while the socket is blocked.
// The user could lock it again if sends from the callback.
Expand All @@ -134,6 +143,7 @@ impl Remote for RemoteResource {
}
},
RemoteState::Handshake(_) => unreachable!(),
RemoteState::Error(_) => unreachable!(),
}
}
}
Expand All @@ -160,6 +170,7 @@ impl Remote for RemoteResource {
}
}
RemoteState::Handshake(_) => unreachable!(),
RemoteState::Error(_) => unreachable!(),
}
}

Expand All @@ -169,7 +180,8 @@ impl Remote for RemoteResource {
RemoteState::WebSocket(_) => PendingStatus::Ready,
RemoteState::Handshake(pending) => match pending.take().unwrap() {
PendingHandshake::Connect(url, stream) => {
match super::tcp::check_stream_ready(&stream) {
let stream_backup = stream.clone();
match super::tcp::check_stream_ready(&stream.0) {
PendingStatus::Ready => match ws_connect(url, stream) {
Ok((web_socket, _)) => {
*state = RemoteState::WebSocket(web_socket);
Expand All @@ -180,47 +192,64 @@ impl Remote for RemoteResource {
PendingStatus::Incomplete
}
Err(HandshakeError::Failure(Error::Io(_))) => {
*state = RemoteState::Error(stream_backup);
PendingStatus::Disconnected
}
Err(HandshakeError::Failure(err)) => {
*state = RemoteState::Error(stream_backup);
log::error!("WS connect handshake error: {}", err);
PendingStatus::Disconnected // should not happen
}
},
other => other,
}
}
PendingHandshake::Client(mid_handshake) => match mid_handshake.handshake() {
Ok((web_socket, _)) => {
*state = RemoteState::WebSocket(web_socket);
PendingStatus::Ready
}
Err(HandshakeError::Interrupted(mid_handshake)) => {
*pending = Some(PendingHandshake::Client(mid_handshake));
PendingStatus::Incomplete
}
Err(HandshakeError::Failure(Error::Io(_))) => PendingStatus::Disconnected,
Err(HandshakeError::Failure(err)) => {
log::error!("WS client handshake error: {}", err);
PendingStatus::Disconnected // should not happen
}
},
PendingHandshake::Server(mid_handshake) => match mid_handshake.handshake() {
Ok(web_socket) => {
*state = RemoteState::WebSocket(web_socket);
PendingStatus::Ready
}
Err(HandshakeError::Interrupted(mid_handshake)) => {
*pending = Some(PendingHandshake::Server(mid_handshake));
PendingStatus::Incomplete
PendingHandshake::Client(mid_handshake) => {
let stream_backup = mid_handshake.get_ref().get_ref().clone();
match mid_handshake.handshake() {
Ok((web_socket, _)) => {
*state = RemoteState::WebSocket(web_socket);
PendingStatus::Ready
}
Err(HandshakeError::Interrupted(mid_handshake)) => {
*pending = Some(PendingHandshake::Client(mid_handshake));
PendingStatus::Incomplete
}
Err(HandshakeError::Failure(Error::Io(_))) => {
*state = RemoteState::Error(stream_backup);
PendingStatus::Disconnected
}
Err(HandshakeError::Failure(err)) => {
*state = RemoteState::Error(stream_backup);
log::error!("WS client handshake error: {}", err);
PendingStatus::Disconnected // should not happen
}
}
Err(HandshakeError::Failure(Error::Io(_))) => PendingStatus::Disconnected,
Err(HandshakeError::Failure(err)) => {
log::error!("WS server handshake error: {}", err);
PendingStatus::Disconnected // should not happen
}
PendingHandshake::Server(mid_handshake) => {
let stream_backup = mid_handshake.get_ref().get_ref().clone();
match mid_handshake.handshake() {
Ok(web_socket) => {
*state = RemoteState::WebSocket(web_socket);
PendingStatus::Ready
}
Err(HandshakeError::Interrupted(mid_handshake)) => {
*pending = Some(PendingHandshake::Server(mid_handshake));
PendingStatus::Incomplete
}
Err(HandshakeError::Failure(Error::Io(_))) => {
*state = RemoteState::Error(stream_backup);
PendingStatus::Disconnected
}
Err(HandshakeError::Failure(err)) => {
*state = RemoteState::Error(stream_backup);
log::error!("WS server handshake error: {}", err);
PendingStatus::Disconnected // should not happen
}
}
},
}
},
RemoteState::Error(_) => unreachable!(),
}
}

Expand All @@ -230,9 +259,10 @@ impl Remote for RemoteResource {
Ok(_) => true,
Err(Error::Io(ref err)) if err.kind() == ErrorKind::WouldBlock => true,
Err(_) => false, // Will be disconnected,
}
},
// This function is only call on ready resources.
RemoteState::Handshake(_) => unreachable!(),
RemoteState::Error(_) => unreachable!(),
}
}
}
Expand Down Expand Up @@ -275,7 +305,7 @@ impl Local for LocalResource {
loop {
match self.listener.accept() {
Ok((stream, addr)) => {
let remote_state = match ws_accept(stream) {
let remote_state = match ws_accept(stream.into()) {
Ok(web_socket) => Some(RemoteState::WebSocket(web_socket)),
Err(HandshakeError::Interrupted(mid_handshake)) => Some(
RemoteState::Handshake(Some(PendingHandshake::Server(mid_handshake))),
Expand All @@ -298,3 +328,37 @@ impl Local for LocalResource {
}
}
}

/// This struct is used to avoid the tungstenite handshake to take the ownership of the stream
/// an drop it without allow to the driver to deregister from the poll.
/// It can be removed when this issue is resolved:
/// https://github.com/snapview/tungstenite-rs/issues/51
struct ArcTcpStream(Arc<TcpStream>);

impl From<TcpStream> for ArcTcpStream {
fn from(stream: TcpStream) -> Self {
Self(Arc::new(stream))
}
}

impl io::Read for ArcTcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
(&*self.0).read(buf)
}
}

impl io::Write for ArcTcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
(&*self.0).write(buf)
}

fn flush(&mut self) -> io::Result<()> {
(&*self.0).flush()
}
}

impl Clone for ArcTcpStream {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

0 comments on commit f033988

Please sign in to comment.