Synchronization-less async connections in ws-server #388

Merged · 14 commits · Jun 30, 2021
179 changes: 127 additions & 52 deletions ws-server/src/server.rs
@@ -24,21 +24,25 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{net::SocketAddr, sync::Arc};

use futures_channel::mpsc;
use futures_util::future::{join_all, FutureExt};
use futures_util::stream::StreamExt;
use futures_util::{
io::{BufReader, BufWriter},
SinkExt,
};
use jsonrpsee_types::TEN_MB_SIZE_BYTES;
use soketto::handshake::{server::Response, Server as SokettoServer};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
net::{TcpListener, ToSocketAddrs},
net::{TcpListener, TcpStream, ToSocketAddrs},
sync::RwLock,
};
use tokio_stream::wrappers::TcpListenerStream;
use tokio_util::compat::TokioAsyncReadCompatExt;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

use jsonrpsee_types::error::Error;
use jsonrpsee_types::v2::error::JsonRpcErrorCode;
@@ -91,64 +95,114 @@ impl Server {
// Acquire read access to the lock such that additional reader(s) may share this lock.
// Write access to this lock will only be possible after the server and all background tasks have stopped.
let _stop_handle = self.stop_handle.read().await;
let shutdown = self.stop_pair.0;

let mut incoming = TcpListenerStream::new(self.listener).fuse();
let methods = self.methods;
let conn_counter = Arc::new(());
let mut id = 0;
let mut stop_receiver = self.stop_pair.1;
let shutdown = self.stop_pair.0;

let mut driver = ConnDriver::new(self.listener, self.stop_pair.1);

loop {
futures_util::select! {
socket = incoming.next() => {
if let Some(Ok(socket)) = socket {
if let Err(e) = socket.set_nodelay(true) {
log::error!("Could not set NODELAY on socket: {:?}", e);
continue;
}

if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while");
continue;
}

let conn_counter2 = conn_counter.clone();
let shutdown2 = shutdown.clone();
let methods = methods.clone();
let cfg = self.cfg.clone();
let stop_handle2 = self.stop_handle.clone();

tokio::spawn(async move {
let _ = background_task(socket, id, methods, cfg, shutdown2, stop_handle2).await;
drop(conn_counter2);
});

id = id.wrapping_add(1);
} else {
break;
match Pin::new(&mut driver).await {
Ok((socket, _addr)) => {
if let Err(e) = socket.set_nodelay(true) {
log::error!("Could not set NODELAY on socket: {:?}", e);
continue;
}
},
stop = stop_receiver.next() => {
if stop.is_some() {
break;

if driver.connection_count() >= self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while.");
continue;
Contributor:
This is where we'd send a CLOSE with a reason, right?

Member (@niklasad1, Jun 30, 2021):
No, we just ignore the socket. The close reason is only sent after the websocket handshake has been completed, AFAIU.

Contributor Author:
I think you're confusing this PR with telemetry :). At this stage the WS connection hasn't been established yet, so you can't send a frame. A nice thing to do would be to send an HTTP response with an appropriate status code.

Contributor:
Indeed, I'm confused. And a strong yes for returning a proper HTTP status code.

Contributor:
Is the proper status code 429 Too Many Requests? Should we include a Retry-After header too?

Contributor Author:
Let's do that in a separate PR, since that would need extra changes to soketto too.
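A minimal sketch of what that follow-up could look like, assuming the raw tokio::net::TcpStream is still in hand before the soketto handshake starts (the helper name and the Retry-After value are hypothetical, not part of this PR):

// Hypothetical follow-up (not part of this PR): reject a connection with a
// plain HTTP 429 before the WebSocket handshake. A CLOSE frame is not possible
// here because soketto has not performed the handshake yet.
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

async fn reject_too_many_connections(mut socket: TcpStream) -> std::io::Result<()> {
    let response = b"HTTP/1.1 429 Too Many Requests\r\nRetry-After: 5\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
    socket.write_all(response).await?;
    socket.shutdown().await
}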

}
},
complete => break,

let methods = &methods;
let cfg = &self.cfg;

driver.add(Box::pin(handshake(socket, id, methods, cfg, &shutdown, &self.stop_handle)));

id += 1;
}
Err(DriverError::Io(err)) => {
log::error!("Error while awaiting a new connection: {:?}", err);
}
Err(DriverError::Shutdown) => break,
}
}
}
}

async fn background_task(
/// This is a glorified select `Future` that will attempt to drive all
/// connection futures `F` to completion on each `poll`, while also
/// handling incoming connections.
struct ConnDriver<F> {
listener: TcpListener,
stop_receiver: mpsc::Receiver<()>,
connections: Vec<F>,
}

impl<F> ConnDriver<F>
where
F: Future + Unpin,
{
fn new(listener: TcpListener, stop_receiver: mpsc::Receiver<()>) -> Self {
ConnDriver { listener, stop_receiver, connections: Vec::new() }
}

fn connection_count(&self) -> usize {
self.connections.len()
}

fn add(&mut self, conn: F) {
self.connections.push(conn);
}
}

enum DriverError {
Shutdown,
Io(std::io::Error),
}

impl<F> Future for ConnDriver<F>
where
F: Future + Unpin,
{
type Output = Result<(TcpStream, SocketAddr), DriverError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = Pin::into_inner(self);

let mut i = 0;

while i < this.connections.len() {
if this.connections[i].poll_unpin(cx).is_ready() {
// Using `swap_remove` since we don't care about ordering
// but we do care about removing being `O(1)`.
//
// We don't increment `i` in this branch, since we now
// have a shorter length, and potentially a new value at
// the current index.
this.connections.swap_remove(i);
} else {
i += 1;
}
}

if let Poll::Ready(Some(())) = this.stop_receiver.next().poll_unpin(cx) {
return Poll::Ready(Err(DriverError::Shutdown));
}
Comment on lines +190 to +192

Contributor Author:
This bit replaces the select that was added on master.

Contributor:
It would be good to document how the shutdown process works and what connected clients should expect to happen to their pending requests.

Contributor:
So I'll try again: is this where we should send a CLOSE frame?

Contributor Author (@maciejhirsz, Jun 30, 2021):
Yeah, I reckon @niklasad1's planned refactor will make the code much easier to follow here too.

this.listener.poll_accept(cx).map_err(DriverError::Io)
}
}
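The accept loop in `start` awaits this same driver repeatedly via `Pin::new(&mut driver).await`; that works because the driver is `Unpin`, so pinning a mutable reference yields a `Future` that can be awaited on each iteration without consuming the driver. A standalone sketch of the pattern, using a hypothetical `Tick` future for illustration:

// Re-awaiting pattern: because `Tick` is `Unpin`, `Pin::new(&mut driver)` is
// itself a `Future`, so the same driver value is polled once per iteration.
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

struct Tick(u32);

impl Future for Tick {
    type Output = u32;

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<u32> {
        self.0 += 1;
        Poll::Ready(self.0)
    }
}

async fn run() {
    let mut driver = Tick(0);
    for _ in 0..3 {
        // Mirrors `Pin::new(&mut driver).await` in the accept loop.
        let n = Pin::new(&mut driver).await;
        println!("tick {}", n);
    }
}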

async fn handshake(
socket: tokio::net::TcpStream,
conn_id: ConnectionId,
methods: Methods,
cfg: Settings,
shutdown: mpsc::Sender<()>,
stop_handle: Arc<RwLock<()>>,
methods: &Methods,
cfg: &Settings,
shutdown: &mpsc::Sender<()>,
stop_handle: &Arc<RwLock<()>>,
) -> Result<(), Error> {
let _lock = stop_handle.read().await;
// For each incoming connection we perform a handshake.
let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));

@@ -170,6 +224,27 @@ async fn background_task(
}
}

tokio::spawn(background_task(
server,
conn_id,
methods.clone(),
cfg.max_request_body_size,
shutdown.clone(),
stop_handle.clone(),
))
.await
.unwrap()
}

async fn background_task(
server: SokettoServer<'_, BufReader<BufWriter<Compat<tokio::net::TcpStream>>>>,
conn_id: ConnectionId,
methods: Methods,
max_request_body_size: u32,
shutdown: mpsc::Sender<()>,
stop_handle: Arc<RwLock<()>>,
) -> Result<(), Error> {
let _lock = stop_handle.read().await;
// And we can finally transition to a websocket connection.
let (mut sender, mut receiver) = server.into_builder().finish();
let (tx, mut rx) = mpsc::unbounded::<String>();
@@ -199,8 +274,8 @@ async fn background_task(

receiver.receive_data(&mut data).await?;

if data.len() > cfg.max_request_body_size as usize {
log::warn!("Request is too big ({} bytes, max is {})", data.len(), cfg.max_request_body_size);
if data.len() > max_request_body_size as usize {
log::warn!("Request is too big ({} bytes, max is {})", data.len(), max_request_body_size);
send_error(Id::Null, &tx, JsonRpcErrorCode::OversizedRequest.into());
continue;
}
Expand All @@ -219,9 +294,9 @@ async fn background_task(
// batch and read the results off of a new channel, `rx_batch`, and then send the complete batch response
// back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded::<String>();
for req in batch {
methods.execute(&tx_batch, req, conn_id).await;
}

join_all(batch.into_iter().map(|req| methods.execute(&tx_batch, req, conn_id))).await;

// Closes the receiving half of a channel without dropping it. This prevents any further messages from
// being sent on the channel.
rx_batch.close();
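Switching from the sequential `for` loop to `join_all` above means the requests in a batch are now polled concurrently while their responses accumulate on `tx_batch`. A standalone sketch of that gather pattern, with a hypothetical `execute` standing in for `methods.execute`:

// Each request pushes its response onto the unbounded channel; `join_all`
// waits for every request future, then the receiver is closed and drained.
use futures_channel::mpsc;
use futures_util::{future::join_all, StreamExt};

async fn execute(tx: &mpsc::UnboundedSender<String>, req: u64) {
    let _ = tx.unbounded_send(format!("response for request {}", req));
}

async fn run_batch(batch: Vec<u64>) -> Vec<String> {
    let (tx_batch, mut rx_batch) = mpsc::unbounded::<String>();

    join_all(batch.into_iter().map(|req| execute(&tx_batch, req))).await;

    // Closing the receiver lets the collect below terminate once the buffered
    // responses have been drained, even though `tx_batch` is still alive.
    rx_batch.close();
    rx_batch.collect().await
}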