-
Notifications
You must be signed in to change notification settings - Fork 0
/
lib.rs
97 lines (81 loc) · 2.53 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
use std::{
num::NonZeroU32,
sync::Arc,
task::{Context, Poll},
};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use alloy_transport::{TransportError, TransportFut};
use governor::{
clock::{QuantaClock, QuantaInstant},
middleware::NoOpMiddleware,
state::{InMemoryState, NotKeyed},
Jitter, Quota, RateLimiter,
};
use thiserror::Error;
use tower::{Layer, Service};
/// Rate limiter type shared between a [`ThrottleLayer`] and the services it builds.
///
/// A direct (un-keyed) governor limiter with in-memory state, driven by the quanta clock,
/// with no middleware attached.
pub type Throttle =
    RateLimiter<NotKeyed, InMemoryState, QuantaClock, NoOpMiddleware<QuantaInstant>>;

/// A Tower layer that throttles RPC requests to a fixed number per second.
///
/// The limiter is held behind an `Arc`. Cloning the layer, or building several services from
/// it, shares one quota rather than creating independent ones.
#[derive(Debug, Clone)]
pub struct ThrottleLayer {
    /// Shared rate limiter handed to every service this layer builds.
    throttle: Arc<Throttle>,
    /// Optional jitter copied into every service this layer builds.
    jitter: Option<Jitter>,
}
#[derive(Debug, Error)]
pub enum ThrottleError {
#[error("Requests per second must be a non-zero positive integer")]
InvalidRequestsPerSecond,
}
impl ThrottleLayer {
    /// Builds a layer that admits at most `requests_per_second` requests per second.
    ///
    /// `jitter`, when present, is copied into every service built by this layer and used
    /// when waiting on the rate limiter.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::InvalidRequestsPerSecond`] when `requests_per_second` is `0`.
    pub fn new(requests_per_second: u32, jitter: Option<Jitter>) -> Result<Self, ThrottleError> {
        let rate = match NonZeroU32::new(requests_per_second) {
            Some(rate) => rate,
            None => return Err(ThrottleError::InvalidRequestsPerSecond),
        };

        let throttle = Arc::new(RateLimiter::direct(Quota::per_second(rate)));
        Ok(Self { throttle, jitter })
    }
}
impl<S> Layer<S> for ThrottleLayer {
type Service = ThrottleService<S>;
fn layer(&self, inner: S) -> Self::Service {
ThrottleService {
inner,
throttle: self.throttle.clone(),
jitter: self.jitter,
}
}
}
/// A Tower Service used by the ThrottleLayer that is responsible for throttling rpc requests.
///
/// Each call waits on the shared rate limiter before forwarding the request to the inner
/// service.
#[derive(Debug, Clone)]
pub struct ThrottleService<S> {
    /// The inner service
    inner: S,
    /// Rate limiter shared with the layer that built this service and its sibling services.
    throttle: Arc<Throttle>,
    /// When `Some`, passed to `until_ready_with_jitter`; otherwise `until_ready` is used.
    jitter: Option<Jitter>,
}

impl<S> Service<RequestPacket> for ThrottleService<S>
where
    S: Service<RequestPacket, Response = ResponsePacket, Error = TransportError>
        + Send
        + 'static
        + Clone,
    S::Future: Send + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    /// Delegates readiness to the inner service.
    ///
    /// The rate limit itself is enforced inside the future returned by `call`, not here.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Waits until the shared rate limiter admits one request, then forwards `request` to
    /// the inner service.
    fn call(&mut self, request: RequestPacket) -> Self::Future {
        let throttle = Arc::clone(&self.throttle);
        let jitter = self.jitter;
        // `poll_ready` drove `self.inner` (not a clone) to readiness, and a fresh clone carries
        // no such guarantee. Move the ready instance into the future and leave the clone behind
        // for the next `poll_ready`/`call` cycle.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move {
            match jitter {
                Some(jitter) => throttle.until_ready_with_jitter(jitter).await,
                None => throttle.until_ready().await,
            }
            inner.call(request).await
        })
    }
}