
Commit 2bd66e4

feat: add support for tower's load-shed layer (#2189)
Refs: #1616
1 parent 689a86d commit 2bd66e4

4 files changed: +104 -2 lines changed

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tonic::{transport::Server, Code, Request, Response, Status};

#[tokio::test]
async fn service_resource_exhausted() {
    let addr = run_service_in_background(0).await;

    let mut client = test_client::TestClient::connect(format!("http://{}", addr))
        .await
        .unwrap();

    let req = Request::new(Input {});
    let res = client.unary_call(req).await;

    let err = res.unwrap_err();
    assert_eq!(err.code(), Code::ResourceExhausted);
}

#[tokio::test]
async fn service_resource_not_exhausted() {
    let addr = run_service_in_background(1).await;

    let mut client = test_client::TestClient::connect(format!("http://{}", addr))
        .await
        .unwrap();

    let req = Request::new(Input {});
    let res = client.unary_call(req).await;

    assert!(res.is_ok());
}

async fn run_service_in_background(concurrency_limit: usize) -> SocketAddr {
    struct Svc;

    #[tonic::async_trait]
    impl test_server::Test for Svc {
        async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
            Ok(Response::new(Output {}))
        }
    }

    let svc = test_server::TestServer::new(Svc {});

    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    tokio::spawn(async move {
        Server::builder()
            .concurrency_limit_per_connection(concurrency_limit)
            .load_shed(true)
            .add_service(svc)
            .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
            .await
            .unwrap();
    });

    addr
}
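For context, here is a minimal sketch (not part of the commit) of how an application might wire the new builder option outside of a test, reusing the same generated test service; the listen address and the limit of 32 are arbitrary illustrative choices:

use integration_tests::pb::{test_server, Input, Output};
use tonic::{transport::Server, Request, Response, Status};

struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
    async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
        Ok(Response::new(Output {}))
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    Server::builder()
        // Allow at most 32 in-flight requests per connection ...
        .concurrency_limit_per_connection(32)
        // ... and reject anything beyond that instead of queueing it.
        .load_shed(true)
        .add_service(test_server::TestServer::new(Svc))
        .serve("127.0.0.1:50051".parse()?)
        .await?;

    Ok(())
}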

tonic/Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -39,12 +39,12 @@ server = [
     "dep:socket2",
     "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
     "tokio-stream/net",
-    "dep:tower", "tower?/util", "tower?/limit",
+    "dep:tower", "tower?/util", "tower?/limit", "tower?/load-shed",
 ]
 channel = [
     "dep:hyper", "hyper?/client",
     "dep:hyper-util", "hyper-util?/client-legacy",
-    "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util",
+    "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/load-shed", "tower?/util",
     "dep:tokio", "tokio?/time",
     "dep:hyper-timeout",
 ]

tonic/src/status.rs

Lines changed: 12 additions & 0 deletions
@@ -348,6 +348,18 @@ impl Status {
             Err(err) => err,
         };

+        // If the load shed middleware is enabled, respond to
+        // service overloaded with an appropriate grpc status.
+        #[cfg(feature = "server")]
+        let err = match err.downcast::<tower::load_shed::error::Overloaded>() {
+            Ok(_) => {
+                return Ok(Status::resource_exhausted(
+                    "Too many active requests for the connection",
+                ));
+            }
+            Err(err) => err,
+        };
+
         if let Some(mut status) = find_status_in_source_chain(&*err) {
             status.source = Some(err.into());
             return Ok(status);
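Because a shed request now surfaces to clients as Code::ResourceExhausted, callers can choose to treat it as retryable. A hedged sketch (not part of the commit; call_with_retry, the retry count, and the backoff values are illustrative) against the same generated test client:

use std::time::Duration;

use integration_tests::pb::{test_client::TestClient, Input};
use tonic::{transport::Channel, Code, Request};

// Retry a unary call when the server sheds load, backing off a little more
// on each attempt; any other error status is returned immediately.
async fn call_with_retry(
    client: &mut TestClient<Channel>,
    max_retries: u32,
) -> Result<(), tonic::Status> {
    let mut attempt = 0;
    loop {
        match client.unary_call(Request::new(Input {})).await {
            Ok(_) => return Ok(()),
            Err(status) if status.code() == Code::ResourceExhausted && attempt < max_retries => {
                attempt += 1;
                tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt))).await;
            }
            Err(status) => return Err(status),
        }
    }
}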

tonic/src/transport/server/mod.rs

Lines changed: 29 additions & 0 deletions
@@ -66,6 +66,7 @@ use tower::{
     layer::util::{Identity, Stack},
     layer::Layer,
     limit::concurrency::ConcurrencyLimitLayer,
+    load_shed::LoadShedLayer,
     util::BoxCloneService,
     Service, ServiceBuilder, ServiceExt,
 };
@@ -87,6 +88,7 @@ const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
 pub struct Server<L = Identity> {
     trace_interceptor: Option<TraceInterceptor>,
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     #[cfg(feature = "_tls-any")]
     tls: Option<TlsAcceptor>,
@@ -111,6 +113,7 @@ impl Default for Server<Identity> {
         Self {
             trace_interceptor: None,
             concurrency_limit: None,
+            load_shed: false,
             timeout: None,
             #[cfg(feature = "_tls-any")]
             tls: None,
@@ -179,6 +182,27 @@ impl<L> Server<L> {
         }
     }

+    /// Enable or disable load shedding. The default is disabled.
+    ///
+    /// When load shedding is enabled, if the service responds with not ready
+    /// the request will immediately be rejected with a
+    /// [`resource_exhausted`](https://docs.rs/tonic/latest/tonic/struct.Status.html#method.resource_exhausted) error.
+    /// The default is to buffer requests. This is especially useful in combination with
+    /// setting a concurrency limit per connection.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use tonic::transport::Server;
+    /// # use tower_service::Service;
+    /// # let builder = Server::builder();
+    /// builder.load_shed(true);
+    /// ```
+    #[must_use]
+    pub fn load_shed(self, load_shed: bool) -> Self {
+        Server { load_shed, ..self }
+    }
+
     /// Set a timeout on for all request handlers.
     ///
     /// # Example
@@ -514,6 +538,7 @@ impl<L> Server<L> {
             service_builder: self.service_builder.layer(new_layer),
             trace_interceptor: self.trace_interceptor,
             concurrency_limit: self.concurrency_limit,
+            load_shed: self.load_shed,
             timeout: self.timeout,
             #[cfg(feature = "_tls-any")]
             tls: self.tls,
@@ -643,6 +668,7 @@ impl<L> Server<L> {
     {
         let trace_interceptor = self.trace_interceptor.clone();
         let concurrency_limit = self.concurrency_limit;
+        let load_shed = self.load_shed;
         let init_connection_window_size = self.init_connection_window_size;
         let init_stream_window_size = self.init_stream_window_size;
         let max_concurrent_streams = self.max_concurrent_streams;
@@ -667,6 +693,7 @@ impl<L> Server<L> {
         let mut svc = MakeSvc {
             inner: svc,
             concurrency_limit,
+            load_shed,
             timeout,
             trace_interceptor,
             _io: PhantomData,
@@ -1047,6 +1074,7 @@ impl<S> fmt::Debug for Svc<S> {
 #[derive(Clone)]
 struct MakeSvc<S, IO> {
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     inner: S,
     trace_interceptor: Option<TraceInterceptor>,
@@ -1080,6 +1108,7 @@ where

         let svc = ServiceBuilder::new()
             .layer(RecoverErrorLayer::new())
+            .option_layer(self.load_shed.then_some(LoadShedLayer::new()))
             .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
             .layer_fn(|s| GrpcTimeout::new(s, timeout))
             .service(svc);
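The option_layer calls above use tower's ServiceBuilder::option_layer, which applies the given layer when the option is Some and an identity layer otherwise, so the shape of the stack stays the same either way. A standalone sketch of that pattern under the same assumption (illustrative only; service_fn stands in for the real gRPC service, and build_stack is a hypothetical helper):

use tower::{limit::ConcurrencyLimitLayer, load_shed::LoadShedLayer, ServiceBuilder};

// Build a stack in which load shedding and the concurrency limit are each
// applied only when configured, mirroring the MakeSvc::call composition above.
fn build_stack(load_shed: bool, concurrency_limit: Option<usize>) {
    // Placeholder service that just echoes its input.
    let svc = tower::service_fn(|req: ()| async move {
        Ok::<_, std::convert::Infallible>(req)
    });

    let _stack = ServiceBuilder::new()
        .option_layer(load_shed.then_some(LoadShedLayer::new()))
        .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
        .service(svc);
}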
