diff --git a/examples/http.rs b/examples/http.rs index 80bdab93f9..30d12393e7 100644 --- a/examples/http.rs +++ b/examples/http.rs @@ -34,9 +34,10 @@ use std::net::SocketAddr; #[tokio::main] async fn main() -> anyhow::Result<()> { - // init tracing `FmtSubscriber`. - let subscriber = tracing_subscriber::FmtSubscriber::new(); - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); let (server_addr, _handle) = run_server().await?; let url = format!("http://{}", server_addr); diff --git a/examples/proc_macro.rs b/examples/proc_macro.rs index c5449b9c24..11825834b3 100644 --- a/examples/proc_macro.rs +++ b/examples/proc_macro.rs @@ -72,9 +72,10 @@ impl RpcServer for RpcServerImpl { #[tokio::main] async fn main() -> anyhow::Result<()> { - // init tracing `FmtSubscriber`. - let subscriber = tracing_subscriber::FmtSubscriber::builder().finish(); - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); let (server_addr, _handle) = run_server().await?; let url = format!("ws://{}", server_addr); diff --git a/examples/ws.rs b/examples/ws.rs index 6768c52b83..641a5cff5d 100644 --- a/examples/ws.rs +++ b/examples/ws.rs @@ -33,9 +33,10 @@ use std::net::SocketAddr; #[tokio::main] async fn main() -> anyhow::Result<()> { - // init tracing `FmtSubscriber`. - let subscriber = tracing_subscriber::FmtSubscriber::builder().finish(); - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); let addr = run_server().await?; let url = format!("ws://{}", addr); diff --git a/examples/ws_sub_with_params.rs b/examples/ws_sub_with_params.rs index d6720204ed..3c3c61c3d1 100644 --- a/examples/ws_sub_with_params.rs +++ b/examples/ws_sub_with_params.rs @@ -34,9 +34,10 @@ use std::net::SocketAddr; #[tokio::main] async fn main() -> anyhow::Result<()> { - // init tracing `FmtSubscriber`. - let subscriber = tracing_subscriber::FmtSubscriber::builder().finish(); - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); let addr = run_server().await?; let url = format!("ws://{}", addr); diff --git a/examples/ws_subscription.rs b/examples/ws_subscription.rs index 02308a7004..f9521992dc 100644 --- a/examples/ws_subscription.rs +++ b/examples/ws_subscription.rs @@ -36,9 +36,10 @@ const NUM_SUBSCRIPTION_RESPONSES: usize = 5; #[tokio::main] async fn main() -> anyhow::Result<()> { - // init tracing `FmtSubscriber`. - let subscriber = tracing_subscriber::FmtSubscriber::builder().finish(); - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); let addr = run_server().await?; let url = format!("ws://{}", addr); diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 74ceae0d51..bf9ff16001 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -15,6 +15,6 @@ hyper = { version = "0.14.10", features = ["full"] } tracing = "0.1" serde = { version = "1", default-features = false, features = ["derive"] } serde_json = "1" -soketto = { version = "0.7", features = ["http"] } +soketto = { version = "0.7.1", features = ["http"] } tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.6", features = ["compat"] } diff --git a/types/Cargo.toml b/types/Cargo.toml index cc488e43e9..9b4fffcae6 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -19,5 +19,5 @@ tracing = { version = "0.1", default-features = false } serde = { version = "1", default-features = false, features = ["derive"] } serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"] } thiserror = "1.0" -soketto = "0.7" +soketto = "0.7.1" hyper = "0.14.10" diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index b6943e5a47..b528b140b0 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -19,7 +19,7 @@ pin-project = "1" rustls-native-certs = "0.6.0" serde = "1" serde_json = "1" -soketto = "0.7" +soketto = "0.7.1" thiserror = "1" tokio = { version = "1", features = ["net", "time", "rt-multi-thread", "macros"] } tokio-rustls = "0.23" diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml index 68652d1437..377eaf1d01 100644 --- a/ws-server/Cargo.toml +++ b/ws-server/Cargo.toml @@ -16,12 +16,12 @@ jsonrpsee-types = { path = "../types", version = "0.4.1" } jsonrpsee-utils = { path = "../utils", version = "0.4.1", features = ["server"] } tracing = "0.1" serde_json = { version = "1", features = ["raw_value"] } -soketto = "0.7" +soketto = "0.7.1" tokio = { version = "1", features = ["net", "rt-multi-thread", "macros"] } tokio-util = { version = "0.6", features = ["compat"] } [dev-dependencies] anyhow = "1" -env_logger = "0.9" jsonrpsee-test-utils = { path = "../test-utils" } jsonrpsee = { path = "../jsonrpsee", features = ["full"] } +tracing-subscriber = "0.2.25" diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 3eb1a0cfe4..1356ccad06 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -39,6 +39,7 @@ use futures_channel::mpsc; use futures_util::future::FutureExt; use futures_util::io::{BufReader, BufWriter}; use futures_util::stream::{self, StreamExt}; +use soketto::connection::Error as SokettoError; use soketto::handshake::{server::Response, Server as SokettoServer}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; @@ -195,6 +196,7 @@ async fn handshake(socket: tokio::net::TcpStream, mode: HandshakeResponse<'_>) - Ok(()) } HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor } => { + tracing::debug!("Accepting new connection: {}", conn_id); let key = { let req = server.receive_request().await?; let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host)); @@ -243,7 +245,9 @@ async fn background_task( stop_server: StopMonitor, ) -> Result<(), Error> { // And we can finally transition to a websocket background_task. - let (mut sender, mut receiver) = server.into_builder().finish(); + let mut builder = server.into_builder(); + builder.set_max_message_size(max_request_body_size as usize); + let (mut sender, mut receiver) = builder.finish(); let (tx, mut rx) = mpsc::unbounded::(); let stop_server2 = stop_server.clone(); @@ -252,8 +256,10 @@ async fn background_task( while !stop_server2.shutdown_requested() { match rx.next().await { Some(response) => { - tracing::debug!("send: {}", response); - let _ = sender.send_text(response).await; + // TODO: check length of response https://github.com/paritytech/jsonrpsee/issues/536 + tracing::debug!("send {} bytes", response.len()); + tracing::trace!("send: {}", response); + let _ = sender.send_text_owned(response).await; let _ = sender.flush().await; } None => break, @@ -272,22 +278,38 @@ async fn background_task( while !stop_server.shutdown_requested() { data.clear(); - if let Err(e) = method_executors.select_with(receiver.receive_data(&mut data)).await { - tracing::error!("Could not receive WS data: {:?}; closing connection", e); - tx.close_channel(); - return Err(e.into()); - } + if let Err(err) = method_executors.select_with(receiver.receive_data(&mut data)).await { + match err { + SokettoError::Closed => { + tracing::debug!("Remote peer terminated the connection: {}", conn_id); + tx.close_channel(); + return Ok(()); + } + SokettoError::MessageTooLarge { current, maximum } => { + tracing::warn!( + "WS transport error: message is too big error ({} bytes, max is {})", + current, + maximum + ); + send_error(Id::Null, &tx, ErrorCode::OversizedRequest.into()); + continue; + } + // These errors can not be gracefully handled, so just log them and terminate the connection. + err => { + tracing::error!("WS transport error: {:?} => terminating connection {}", err, conn_id); + tx.close_channel(); + return Err(err.into()); + } + }; + }; - if data.len() > max_request_body_size as usize { - tracing::warn!("Request is too big ({} bytes, max is {})", data.len(), max_request_body_size); - send_error(Id::Null, &tx, ErrorCode::OversizedRequest.into()); - continue; - } + tracing::debug!("recv {} bytes", data.len()); match data.get(0) { Some(b'{') => { if let Ok(req) = serde_json::from_slice::(&data) { - tracing::debug!("recv: {:?}", req); + tracing::debug!("recv method call={}", req.method); + tracing::trace!("recv: req={:?}", req); if let Some(fut) = methods.execute_with_resources(&tx, req, conn_id, &resources) { method_executors.add(fut); } @@ -309,6 +331,8 @@ async fn background_task( // complete batch response back to the client over `tx`. let (tx_batch, mut rx_batch) = mpsc::unbounded(); if let Ok(batch) = serde_json::from_slice::>(&d) { + tracing::debug!("recv batch len={}", batch.len()); + tracing::trace!("recv: batch={:?}", batch); if !batch.is_empty() { let methods_stream = stream::iter(batch.into_iter().filter_map(|req| { diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index e0c1236412..20bd65e67c 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -34,9 +34,12 @@ use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::mocks::{Id, TestContext, WebSocketTestClient, WebSocketTestError}; use jsonrpsee_test_utils::TimeoutFutureExt; use serde_json::Value as JsonValue; -use std::fmt; -use std::net::SocketAddr; -use std::time::Duration; +use std::{fmt, net::SocketAddr, time::Duration}; +use tracing_subscriber::{EnvFilter, FmtSubscriber}; + +fn init_logger() { + let _ = FmtSubscriber::builder().with_env_filter(EnvFilter::from_default_env()).try_init(); +} /// Applications can/should provide their own error. #[derive(Debug)] @@ -156,6 +159,8 @@ async fn server_with_context() -> SocketAddr { #[tokio::test] async fn can_set_the_max_request_body_size() { + init_logger(); + let addr = "127.0.0.1:0"; // Rejects all requests larger than 10 bytes let server = WsServerBuilder::default().max_request_body_size(10).build(addr).await.unwrap(); @@ -225,6 +230,7 @@ async fn single_method_calls_works() { #[tokio::test] async fn async_method_calls_works() { + init_logger(); let addr = server().await; let mut client = WebSocketTestClient::new(addr).await.unwrap(); @@ -342,7 +348,6 @@ async fn single_method_call_with_params_works() { #[tokio::test] async fn single_method_call_with_faulty_params_returns_err() { - let _ = env_logger::try_init(); let addr = server().await; let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap(); let expected = r#"{"jsonrpc":"2.0","error":{"code":-32602,"message":"invalid type: string \"should be a number\", expected u64 at line 1 column 21"},"id":1}"#; @@ -539,7 +544,7 @@ async fn can_register_modules() { #[tokio::test] async fn stop_works() { - let _ = env_logger::try_init(); + init_logger(); let (_addr, server_handle) = server_with_handles().with_default_timeout().await.unwrap(); server_handle.clone().stop().unwrap().with_default_timeout().await.unwrap(); @@ -554,7 +559,7 @@ async fn stop_works() { async fn run_forever() { const TIMEOUT: Duration = Duration::from_millis(200); - let _ = env_logger::try_init(); + init_logger(); let (_addr, server_handle) = server_with_handles().with_default_timeout().await.unwrap(); assert!(matches!(server_handle.with_timeout(TIMEOUT).await, Err(_timeout_err)));