From e570c90e9b4ee3a9aea005d93028ab1caf380c8f Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Wed, 29 Mar 2023 17:22:15 -0400 Subject: [PATCH] feat(core): Default encoding/decoding limits This PR adds new defaults for both client and server max encoding/decoding message size limits. By default, the max message decoding size is `4MB` and the max message encoding size is `usize::MAX`. This is follow up work from https://github.com/hyperium/tonic/pull/1274 BREAKING: Default max message encoding/decoding limits --- tests/integration_tests/Cargo.toml | 3 +- tests/integration_tests/proto/test.proto | 11 + tests/integration_tests/src/lib.rs | 6 + .../tests/max_message_size.rs | 278 ++++++++++++++++++ tonic-build/src/client.rs | 4 + tonic-build/src/server.rs | 4 + tonic/src/codec/decode.rs | 6 +- tonic/src/codec/encode.rs | 4 +- tonic/src/codec/mod.rs | 3 +- tonic/src/lib.rs | 8 + 10 files changed, 321 insertions(+), 6 deletions(-) create mode 100644 tests/integration_tests/tests/max_message_size.rs diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 3bf261322..957d936fb 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -14,6 +14,7 @@ futures-util = "0.3" prost = "0.11" tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} tonic = {path = "../../tonic"} +tracing-subscriber = {version = "0.3", features = ["env-filter"]} [dev-dependencies] async-stream = "0.3" @@ -25,7 +26,7 @@ tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} tower-http = { version = "0.4", features = ["set-header", "trace"] } tower-service = "0.3" -tracing-subscriber = {version = "0.3", features = ["env-filter"]} +tracing = "0.1" [build-dependencies] tonic-build = {path = "../../tonic-build"} diff --git a/tests/integration_tests/proto/test.proto b/tests/integration_tests/proto/test.proto index 98452e962..53df288dd 100644 --- a/tests/integration_tests/proto/test.proto +++ b/tests/integration_tests/proto/test.proto @@ -8,3 +8,14 @@ service Test { message Input {} message Output {} + +service Test1 { + rpc UnaryCall(Input1) returns (Output1); +} + +message Input1 { + bytes buf = 1; +} +message Output1 { + bytes buf = 1; +} diff --git a/tests/integration_tests/src/lib.rs b/tests/integration_tests/src/lib.rs index 57691ed6b..7c6e23728 100644 --- a/tests/integration_tests/src/lib.rs +++ b/tests/integration_tests/src/lib.rs @@ -53,3 +53,9 @@ pub mod mock { } } } + +pub fn trace_init() { + let _ = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); +} diff --git a/tests/integration_tests/tests/max_message_size.rs b/tests/integration_tests/tests/max_message_size.rs new file mode 100644 index 000000000..14dbec3da --- /dev/null +++ b/tests/integration_tests/tests/max_message_size.rs @@ -0,0 +1,278 @@ +use integration_tests::{ + pb::{test1_client, test1_server, Input1, Output1}, + trace_init, +}; +use tonic::{ + transport::{Endpoint, Server}, + Code, Request, Response, Status, +}; + +#[test] +fn max_message_recv_size() { + trace_init(); + + // Server recv + assert_server_recv_max_success(128); + // 5 is the size of the gRPC header + assert_server_recv_max_success((4 * 1024 * 1024) - 5); + // 4mb is the max recv size + assert_server_recv_max_failure(4 * 1024 * 1024); + assert_server_recv_max_failure(4 * 1024 * 1024 + 1); + assert_server_recv_max_failure(8 * 1024 * 1024); + + // Client recv + assert_client_recv_max_success(128); + // 5 is the size of the gRPC header + assert_client_recv_max_success((4 * 1024 * 1024) - 5); + // 4mb is the max recv size + assert_client_recv_max_failure(4 * 1024 * 1024); + assert_client_recv_max_failure(4 * 1024 * 1024 + 1); + assert_client_recv_max_failure(8 * 1024 * 1024); + + // Custom limit settings + assert_test_case(TestCase { + // 5 is the size of the gRPC header + server_blob_size: 1024 - 5, + client_recv_max: Some(1024), + ..Default::default() + }); + assert_test_case(TestCase { + server_blob_size: 1024, + client_recv_max: Some(1024), + expected_code: Some(Code::OutOfRange), + ..Default::default() + }); + + assert_test_case(TestCase { + // 5 is the size of the gRPC header + client_blob_size: 1024 - 5, + server_recv_max: Some(1024), + ..Default::default() + }); + assert_test_case(TestCase { + client_blob_size: 1024, + server_recv_max: Some(1024), + expected_code: Some(Code::OutOfRange), + ..Default::default() + }); +} + +#[test] +fn max_message_send_size() { + trace_init(); + + // Check client send limit works + assert_test_case(TestCase { + client_blob_size: 4 * 1024 * 1024, + server_recv_max: Some(usize::MAX), + ..Default::default() + }); + assert_test_case(TestCase { + // 5 is the size of the gRPC header + client_blob_size: 1024 - 5, + server_recv_max: Some(usize::MAX), + client_send_max: Some(1024), + ..Default::default() + }); + assert_test_case(TestCase { + // 5 is the size of the gRPC header + client_blob_size: 4 * 1024 * 1024, + server_recv_max: Some(usize::MAX), + // Set client send limit to 1024 + client_send_max: Some(1024), + // TODO: This should return OutOfRange + // https://github.com/hyperium/tonic/issues/1334 + expected_code: Some(Code::Internal), + ..Default::default() + }); + + // Check server send limit works + assert_test_case(TestCase { + server_blob_size: 4 * 1024 * 1024, + client_recv_max: Some(usize::MAX), + ..Default::default() + }); + assert_test_case(TestCase { + // 5 is the gRPC header size + server_blob_size: 1024 - 5, + client_recv_max: Some(usize::MAX), + // Set server send limit to 1024 + server_send_max: Some(1024), + ..Default::default() + }); + assert_test_case(TestCase { + server_blob_size: 4 * 1024 * 1024, + client_recv_max: Some(usize::MAX), + // Set server send limit to 1024 + server_send_max: Some(1024), + expected_code: Some(Code::OutOfRange), + ..Default::default() + }); +} + +// Track caller doesn't work on async fn so we extract the async part +// into a sync version and assert the response there using track track_caller +// so that when this does panic it tells us which line in the test failed not +// where we placed the panic call. + +#[track_caller] +fn assert_server_recv_max_success(size: usize) { + let case = TestCase { + client_blob_size: size, + server_blob_size: 0, + ..Default::default() + }; + + assert_test_case(case); +} + +#[track_caller] +fn assert_server_recv_max_failure(size: usize) { + let case = TestCase { + client_blob_size: size, + server_blob_size: 0, + expected_code: Some(Code::OutOfRange), + ..Default::default() + }; + + assert_test_case(case); +} + +#[track_caller] +fn assert_client_recv_max_success(size: usize) { + let case = TestCase { + client_blob_size: 0, + server_blob_size: size, + ..Default::default() + }; + + assert_test_case(case); +} + +#[track_caller] +fn assert_client_recv_max_failure(size: usize) { + let case = TestCase { + client_blob_size: 0, + server_blob_size: size, + expected_code: Some(Code::OutOfRange), + ..Default::default() + }; + + assert_test_case(case); +} + +#[track_caller] +fn assert_test_case(case: TestCase) { + let res = max_message_run(&case); + + match (case.expected_code, res) { + (Some(_), Ok(())) => panic!("Expected failure, but got success"), + (Some(code), Err(status)) => { + if status.code() != code { + panic!( + "Expected failure, got failure but wrong code, got: {:?}", + status + ) + } + } + + (None, Err(status)) => panic!("Expected success, but got failure, got: {:?}", status), + + _ => (), + } +} + +#[derive(Default)] +struct TestCase { + client_blob_size: usize, + server_blob_size: usize, + client_recv_max: Option, + server_recv_max: Option, + client_send_max: Option, + server_send_max: Option, + + expected_code: Option, +} + +#[tokio::main] +async fn max_message_run(case: &TestCase) -> Result<(), Status> { + let client_blob = vec![0; case.client_blob_size]; + let server_blob = vec![0; case.server_blob_size]; + + let (client, server) = tokio::io::duplex(1024); + + struct Svc(Vec); + + #[tonic::async_trait] + impl test1_server::Test1 for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + Ok(Response::new(Output1 { + buf: self.0.clone(), + })) + } + } + + let svc = test1_server::Test1Server::new(Svc(server_blob)); + + let svc = if let Some(size) = case.server_recv_max { + svc.max_decoding_message_size(size) + } else { + svc + }; + + let svc = if let Some(size) = case.server_send_max { + svc.max_encoding_message_size(size) + } else { + svc + }; + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)])) + .await + .unwrap(); + }); + + // Move client to an option so we can _move_ the inner value + // on the first attempt to connect. All other attempts will fail. + let mut client = Some(client); + let channel = Endpoint::try_from("http://[::]:50051") + .unwrap() + .connect_with_connector(tower::service_fn(move |_| { + let client = client.take(); + + async move { + if let Some(client) = client { + Ok(client) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Client already taken", + )) + } + } + })) + .await + .unwrap(); + + let client = test1_client::Test1Client::new(channel); + + let client = if let Some(size) = case.client_recv_max { + client.max_decoding_message_size(size) + } else { + client + }; + + let mut client = if let Some(size) = case.client_send_max { + client.max_encoding_message_size(size) + } else { + client + }; + + let req = Request::new(Input1 { + buf: client_blob.clone(), + }); + + client.unary_call(req).await.map(|_| ()) +} diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index 417474916..17854a1fe 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -136,6 +136,8 @@ pub(crate) fn generate_internal( } /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` #[must_use] pub fn max_decoding_message_size(mut self, limit: usize) -> Self { self.inner = self.inner.max_decoding_message_size(limit); @@ -143,6 +145,8 @@ pub(crate) fn generate_internal( } /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` #[must_use] pub fn max_encoding_message_size(mut self, limit: usize) -> Self { self.inner = self.inner.max_encoding_message_size(limit); diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index c521ff7c1..c86ca68f8 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -84,6 +84,8 @@ pub(crate) fn generate_internal( let configure_max_message_size_methods = quote! { /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` #[must_use] pub fn max_decoding_message_size(mut self, limit: usize) -> Self { self.max_decoding_message_size = Some(limit); @@ -91,6 +93,8 @@ pub(crate) fn generate_internal( } /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` #[must_use] pub fn max_encoding_message_size(mut self, limit: usize) -> Self { self.max_encoding_message_size = Some(limit); diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index e4d9799e5..b6d24f530 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,5 +1,5 @@ use super::compression::{decompress, CompressionEncoding}; -use super::{DecodeBuf, Decoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE}; +use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use futures_core::Stream; @@ -174,7 +174,9 @@ impl StreamingInner { }; let len = self.buf.get_u32() as usize; - let limit = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE); + let limit = self + .max_message_size + .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE); if len > limit { return Err(Status::new( Code::OutOfRange, diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 5efc40ef3..48080182b 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,5 +1,5 @@ use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}; -use super::{EncodeBuf, Encoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE}; +use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE}; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use futures_core::{Stream, TryStream}; @@ -141,7 +141,7 @@ fn finish_encoding( buf: &mut BytesMut, ) -> Result { let len = buf.len() - HEADER_SIZE; - let limit = max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE); + let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE); if len > limit { return Err(Status::new( Code::OutOfRange, diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 30ca36a2c..306621329 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -30,7 +30,8 @@ const HEADER_SIZE: usize = std::mem::size_of::(); // The default maximum uncompressed size in bytes for a message. Defaults to 4MB. -const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 * 1024 * 1024; +const DEFAULT_MAX_RECV_MESSAGE_SIZE: usize = 4 * 1024 * 1024; +const DEFAULT_MAX_SEND_MESSAGE_SIZE: usize = usize::MAX; /// Trait that knows how to encode and decode gRPC messages. pub trait Codec { diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 4b00ff335..ab91d60b2 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -53,6 +53,14 @@ //! to build even more feature rich clients and servers. This module also provides the ability to //! enable TLS using [`rustls`], via the `tls` feature flag. //! +//! # Code generated client/server configuration +//! +//! ## Max Message Size +//! +//! Currently, both servers and clients can be configured to set the max message encoding and +//! decoding size. This will ensure that an incoming gRPC message will not exahust the systems +//! memory. By default, the decoding message limit is `4MB` and the encoding limit is `usize::MAX`. +//! //! [gRPC]: https://grpc.io //! [`tonic`]: https://github.com/hyperium/tonic //! [`tokio`]: https://docs.rs/tokio