Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ws server: respect max limit for received messages #537

Merged
merged 21 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ use std::net::SocketAddr;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// init tracing `FmtSubscriber`.
let subscriber = tracing_subscriber::FmtSubscriber::new();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

let (server_addr, _handle) = run_server().await?;
let url = format!("http://{}", server_addr);
Expand Down
7 changes: 4 additions & 3 deletions examples/proc_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ impl RpcServer<ExampleHash, ExampleStorageKey> for RpcServerImpl {

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// init tracing `FmtSubscriber`.
let subscriber = tracing_subscriber::FmtSubscriber::builder().finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use RUST_LOG env variable

.try_init()
.expect("setting default subscriber failed");

let (server_addr, _handle) = run_server().await?;
let url = format!("ws://{}", server_addr);
Expand Down
7 changes: 4 additions & 3 deletions examples/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ use std::net::SocketAddr;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// init tracing `FmtSubscriber`.
let subscriber = tracing_subscriber::FmtSubscriber::builder().finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

let addr = run_server().await?;
let url = format!("ws://{}", addr);
Expand Down
7 changes: 4 additions & 3 deletions examples/ws_sub_with_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ use std::net::SocketAddr;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// init tracing `FmtSubscriber`.
let subscriber = tracing_subscriber::FmtSubscriber::builder().finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

let addr = run_server().await?;
let url = format!("ws://{}", addr);
Expand Down
7 changes: 4 additions & 3 deletions examples/ws_subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ const NUM_SUBSCRIPTION_RESPONSES: usize = 5;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// init tracing `FmtSubscriber`.
let subscriber = tracing_subscriber::FmtSubscriber::builder().finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

let addr = run_server().await?;
let url = format!("ws://{}", addr);
Expand Down
2 changes: 1 addition & 1 deletion test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ hyper = { version = "0.14.10", features = ["full"] }
tracing = "0.1"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
soketto = { version = "0.7", features = ["http"] }
soketto = { version = "0.7.1", features = ["http"] }
tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.6", features = ["compat"] }
2 changes: 1 addition & 1 deletion types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ tracing = { version = "0.1", default-features = false }
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"] }
thiserror = "1.0"
soketto = "0.7"
soketto = "0.7.1"
hyper = "0.14.10"
2 changes: 1 addition & 1 deletion ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pin-project = "1"
rustls-native-certs = "0.6.0"
serde = "1"
serde_json = "1"
soketto = "0.7"
soketto = "0.7.1"
thiserror = "1"
tokio = { version = "1", features = ["net", "time", "rt-multi-thread", "macros"] }
tokio-rustls = "0.23"
Expand Down
4 changes: 2 additions & 2 deletions ws-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ jsonrpsee-types = { path = "../types", version = "0.4.1" }
jsonrpsee-utils = { path = "../utils", version = "0.4.1", features = ["server"] }
tracing = "0.1"
serde_json = { version = "1", features = ["raw_value"] }
soketto = "0.7"
soketto = "0.7.1"
tokio = { version = "1", features = ["net", "rt-multi-thread", "macros"] }
tokio-util = { version = "0.6", features = ["compat"] }

[dev-dependencies]
anyhow = "1"
env_logger = "0.9"
jsonrpsee-test-utils = { path = "../test-utils" }
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tracing-subscriber = "0.2.25"
52 changes: 38 additions & 14 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use futures_channel::mpsc;
use futures_util::future::FutureExt;
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::{self, StreamExt};
use soketto::connection::Error as SokettoError;
use soketto::handshake::{server::Response, Server as SokettoServer};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
Expand Down Expand Up @@ -195,6 +196,7 @@ async fn handshake(socket: tokio::net::TcpStream, mode: HandshakeResponse<'_>) -
Ok(())
}
HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor } => {
tracing::debug!("Accepting new connection: {}", conn_id);
let key = {
let req = server.receive_request().await?;
let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host));
Expand Down Expand Up @@ -243,7 +245,9 @@ async fn background_task(
stop_server: StopMonitor,
) -> Result<(), Error> {
// And we can finally transition to a websocket background_task.
let (mut sender, mut receiver) = server.into_builder().finish();
let mut builder = server.into_builder();
builder.set_max_message_size(max_request_body_size as usize);
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
let (mut sender, mut receiver) = builder.finish();
let (tx, mut rx) = mpsc::unbounded::<String>();
let stop_server2 = stop_server.clone();

Expand All @@ -252,8 +256,10 @@ async fn background_task(
while !stop_server2.shutdown_requested() {
match rx.next().await {
Some(response) => {
tracing::debug!("send: {}", response);
let _ = sender.send_text(response).await;
// TODO: check length of response https://github.com/paritytech/jsonrpsee/issues/536
tracing::debug!("send {} bytes", response.len());
tracing::trace!("send: {}", response);
let _ = sender.send_text_owned(response).await;
Copy link
Contributor

@maciejhirsz maciejhirsz Nov 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 for skipping a clone :)

let _ = sender.flush().await;
}
None => break,
Expand All @@ -272,22 +278,38 @@ async fn background_task(
while !stop_server.shutdown_requested() {
data.clear();

if let Err(e) = method_executors.select_with(receiver.receive_data(&mut data)).await {
tracing::error!("Could not receive WS data: {:?}; closing connection", e);
tx.close_channel();
return Err(e.into());
}
if let Err(err) = method_executors.select_with(receiver.receive_data(&mut data)).await {
match err {
SokettoError::Closed => {
tracing::debug!("Remote peer terminated the connection: {}", conn_id);
tx.close_channel();
return Ok(());
}
SokettoError::MessageTooLarge { current, maximum } => {
tracing::warn!(
"WS transport error: message is too big error ({} bytes, max is {})",
current,
maximum
);
send_error(Id::Null, &tx, ErrorCode::OversizedRequest.into());
continue;
}
// These errors can not be gracefully handled, so just log them and terminate the connection.
err => {
tracing::error!("WS transport error: {:?} => terminate connection {}", err, conn_id);
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
tx.close_channel();
return Err(err.into());
}
};
};

if data.len() > max_request_body_size as usize {
tracing::warn!("Request is too big ({} bytes, max is {})", data.len(), max_request_body_size);
send_error(Id::Null, &tx, ErrorCode::OversizedRequest.into());
continue;
}
tracing::debug!("recv {} bytes", data.len());

match data.get(0) {
Some(b'{') => {
if let Ok(req) = serde_json::from_slice::<Request>(&data) {
tracing::debug!("recv: {:?}", req);
tracing::debug!("recv call={}", req.method);
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
tracing::trace!("recv: {:?}", req);
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
if let Some(fut) = methods.execute_with_resources(&tx, req, conn_id, &resources) {
method_executors.add(fut);
}
Expand All @@ -309,6 +331,8 @@ async fn background_task(
// complete batch response back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded();
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&d) {
tracing::debug!("recv batch={}", batch.len());
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
tracing::trace!("recv: {:?}", batch);
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
if !batch.is_empty() {
let methods_stream =
stream::iter(batch.into_iter().filter_map(|req| {
Expand Down
11 changes: 9 additions & 2 deletions ws-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ use jsonrpsee_test_utils::TimeoutFutureExt;
use serde_json::Value as JsonValue;
use std::fmt;
use std::net::SocketAddr;
use tracing_subscriber::{EnvFilter, FmtSubscriber};

fn init_logger() {
let _ = FmtSubscriber::builder().with_env_filter(EnvFilter::from_default_env()).try_init();
}

/// Applications can/should provide their own error.
#[derive(Debug)]
Expand Down Expand Up @@ -154,6 +159,8 @@ async fn server_with_context() -> SocketAddr {

#[tokio::test]
async fn can_set_the_max_request_body_size() {
init_logger();

let addr = "127.0.0.1:0";
// Rejects all requests larger than 10 bytes
let server = WsServerBuilder::default().max_request_body_size(10).build(addr).await.unwrap();
Expand Down Expand Up @@ -223,6 +230,7 @@ async fn single_method_calls_works() {

#[tokio::test]
async fn async_method_calls_works() {
init_logger();
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).await.unwrap();

Expand Down Expand Up @@ -340,7 +348,6 @@ async fn single_method_call_with_params_works() {

#[tokio::test]
async fn single_method_call_with_faulty_params_returns_err() {
let _ = env_logger::try_init();
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();
let expected = r#"{"jsonrpc":"2.0","error":{"code":-32602,"message":"invalid type: string \"should be a number\", expected u64 at line 1 column 21"},"id":1}"#;
Expand Down Expand Up @@ -537,7 +544,7 @@ async fn can_register_modules() {

#[tokio::test]
async fn stop_works() {
let _ = env_logger::try_init();
init_logger();
let (_addr, stop_handle) = server_with_handles().with_default_timeout().await.unwrap();
stop_handle.clone().stop().unwrap().with_default_timeout().await.unwrap();

Expand Down