
Commit c5564e7

Merge pull request #20 from G8XSU/retry
Add Retry utility with RetryPolicy definition
2 parents f4c561d + e7fa784 commit c5564e7

4 files changed: +323 -0 lines changed


Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ build = "build.rs"
 [dependencies]
 prost = "0.11.6"
 reqwest = { version = "0.11.13", features = ["rustls-tls"] }
+tokio = { version = "1", default-features = false, features = ["time"] }
+rand = "0.8.5"
 
 [target.'cfg(genproto)'.build-dependencies]
 prost-build = { version = "0.11.3" }

src/util/mod.rs

Lines changed: 3 additions & 0 deletions
@@ -2,3 +2,6 @@
 ///
 /// [`StorableBuilder`]: storable_builder::StorableBuilder
 pub mod storable_builder;
+
+/// Contains retry utilities.
+pub mod retry;

src/util/retry.rs

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
use rand::Rng;
use std::error::Error;
use std::future::Future;
use std::marker::PhantomData;
use std::time::Duration;

/// Performs and retries the given operation according to a retry policy.
///
/// **Caution**: A retry policy that does not cap the number of attempts with the
/// [`MaxAttemptsRetryPolicy`] decorator will result in infinite retries.
///
/// **Example**
/// ```rust
/// # use std::time::Duration;
/// # use vss_client::error::VssError;
/// # use vss_client::util::retry::{ExponentialBackoffRetryPolicy, retry, RetryPolicy};
/// #
/// # async fn operation() -> Result<i32, VssError> {
/// #     tokio::time::sleep(Duration::from_millis(10)).await;
/// #     Ok(42)
/// # }
/// #
/// let retry_policy = ExponentialBackoffRetryPolicy::new(Duration::from_millis(100))
///     .with_max_attempts(5)
///     .with_max_total_delay(Duration::from_secs(2))
///     .with_max_jitter(Duration::from_millis(30))
///     .skip_retry_on_error(|e| matches!(e, VssError::InvalidRequestError(..)));
///
/// let result = retry(operation, &retry_policy);
/// ```
pub async fn retry<R, F, Fut, T, E>(mut operation: F, retry_policy: &R) -> Result<T, E>
where
    R: RetryPolicy<E = E>,
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: Error,
{
    let mut attempts_made = 0;
    let mut accumulated_delay = Duration::ZERO;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(err) => {
                attempts_made += 1;
                if let Some(delay) =
                    retry_policy.next_delay(&RetryContext { attempts_made, accumulated_delay, error: &err })
                {
                    tokio::time::sleep(delay).await;
                    accumulated_delay += delay;
                } else {
                    return Err(err);
                }
            }
        }
    }
}

/// Provides the logic for how and when to perform retries.
pub trait RetryPolicy: Sized {
    /// The error type returned by the `operation` in `retry`.
    type E: Error;

    /// Returns the duration to wait before trying the next attempt.
    /// `context` represents the context of a retry operation.
    ///
    /// If `None` is returned then no further retry attempt is made.
    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration>;

    /// Returns a new `RetryPolicy` that respects the given maximum attempts.
    fn with_max_attempts(self, max_attempts: u32) -> MaxAttemptsRetryPolicy<Self> {
        MaxAttemptsRetryPolicy { inner_policy: self, max_attempts }
    }

    /// Returns a new `RetryPolicy` that respects the given total delay.
    fn with_max_total_delay(self, max_total_delay: Duration) -> MaxTotalDelayRetryPolicy<Self> {
        MaxTotalDelayRetryPolicy { inner_policy: self, max_total_delay }
    }

    /// Returns a new `RetryPolicy` that adds jitter (random delay) to the underlying policy.
    fn with_max_jitter(self, max_jitter: Duration) -> JitteredRetryPolicy<Self> {
        JitteredRetryPolicy { inner_policy: self, max_jitter }
    }

    /// Skips retrying on errors for which `function` evaluates to `true`.
    fn skip_retry_on_error<F>(self, function: F) -> FilteredRetryPolicy<Self, F>
    where
        F: 'static + Fn(&Self::E) -> bool,
    {
        FilteredRetryPolicy { inner_policy: self, function }
    }
}

/// Represents the context of a retry operation.
///
/// The context holds key information about the retry operation
/// such as how many attempts have been made until now, the accumulated
/// delay between retries, and the error that triggered the retry.
pub struct RetryContext<'a, E: Error> {
    /// The number of attempts made until now, before attempting the next retry.
    attempts_made: u32,

    /// The amount of artificial delay we have already waited in between previous
    /// attempts. Does not include the time taken to execute the operation.
    accumulated_delay: Duration,

    /// The error encountered in the previous attempt.
    error: &'a E,
}

/// The exponential backoff strategy is a retry approach that doubles the delay between retries.
/// A combined exponential backoff and jitter strategy is recommended; see ["Exponential Backoff and Jitter"](https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/).
/// This helps avoid the [Thundering Herd Problem](https://en.wikipedia.org/wiki/Thundering_herd_problem).
pub struct ExponentialBackoffRetryPolicy<E> {
    /// The base delay duration for the backoff algorithm. The first retry is `base_delay` after the first attempt.
    base_delay: Duration,
    phantom: PhantomData<E>,
}

impl<E: Error> ExponentialBackoffRetryPolicy<E> {
    /// Constructs a new instance using `base_delay`.
    ///
    /// `base_delay` is the base delay duration for the backoff algorithm. The first retry is `base_delay`
    /// after the first attempt.
    pub fn new(base_delay: Duration) -> ExponentialBackoffRetryPolicy<E> {
        Self { base_delay, phantom: PhantomData }
    }
}

impl<E: Error> RetryPolicy for ExponentialBackoffRetryPolicy<E> {
    type E = E;
    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        let backoff_factor = 2_u32.pow(context.attempts_made) - 1;
        let delay = self.base_delay * backoff_factor;
        Some(delay)
    }
}

/// Decorates the given `RetryPolicy` to respect the given maximum attempts.
pub struct MaxAttemptsRetryPolicy<T: RetryPolicy> {
    /// The underlying retry policy to use.
    inner_policy: T,
    /// The maximum number of attempts to retry.
    max_attempts: u32,
}

impl<T: RetryPolicy> RetryPolicy for MaxAttemptsRetryPolicy<T> {
    type E = T::E;
    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        if self.max_attempts == context.attempts_made {
            None
        } else {
            self.inner_policy.next_delay(context)
        }
    }
}

/// Decorates the given `RetryPolicy` to respect the given maximum total delay.
pub struct MaxTotalDelayRetryPolicy<T: RetryPolicy> {
    /// The underlying retry policy to use.
    inner_policy: T,
    /// The maximum accumulated delay that will be allowed over all attempts.
    max_total_delay: Duration,
}

impl<T: RetryPolicy> RetryPolicy for MaxTotalDelayRetryPolicy<T> {
    type E = T::E;
    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        let next_delay = self.inner_policy.next_delay(context);
        if let Some(next_delay) = next_delay {
            if self.max_total_delay < context.accumulated_delay + next_delay {
                return None;
            }
        }
        next_delay
    }
}

/// Decorates the given `RetryPolicy` and adds jitter (random delay) to it. This can make retries
/// more spread out and less likely to all fail at once.
pub struct JitteredRetryPolicy<T: RetryPolicy> {
    /// The underlying retry policy to use.
    inner_policy: T,
    /// The maximum amount of random jitter to apply to the delay.
    max_jitter: Duration,
}

impl<T: RetryPolicy> RetryPolicy for JitteredRetryPolicy<T> {
    type E = T::E;
    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        if let Some(base_delay) = self.inner_policy.next_delay(context) {
            let mut rng = rand::thread_rng();
            let jitter = Duration::from_micros(rng.gen_range(0..self.max_jitter.as_micros() as u64));
            Some(base_delay + jitter)
        } else {
            None
        }
    }
}

/// Decorates the given `RetryPolicy` by not retrying on errors that match the given function.
pub struct FilteredRetryPolicy<T: RetryPolicy, F> {
    inner_policy: T,
    function: F,
}

impl<T, F, E> RetryPolicy for FilteredRetryPolicy<T, F>
where
    T: RetryPolicy<E = E>,
    F: Fn(&E) -> bool,
    E: Error,
{
    type E = T::E;
    fn next_delay(&self, context: &RetryContext<E>) -> Option<Duration> {
        if (self.function)(&context.error) {
            None
        } else {
            self.inner_policy.next_delay(context)
        }
    }
}
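
Editor's note (not part of this commit): any type that implements `next_delay` picks up the `with_max_attempts`, `with_max_total_delay`, `with_max_jitter`, and `skip_retry_on_error` combinators for free from the trait's default methods. The sketch below shows a hypothetical `FixedIntervalRetryPolicy` composed with those decorators; the type and its constructor are illustrative only. For comparison, the committed `ExponentialBackoffRetryPolicy` waits `base_delay * (2^attempts_made - 1)` before each retry, i.e. 100 ms, 300 ms, 700 ms, … for a 100 ms base delay.

```rust
// Hypothetical sketch, not part of this commit: a constant-interval policy
// built on the `RetryPolicy` trait defined above.
use std::error::Error;
use std::marker::PhantomData;
use std::time::Duration;

use vss_client::util::retry::{RetryContext, RetryPolicy};

pub struct FixedIntervalRetryPolicy<E> {
    interval: Duration,
    phantom: PhantomData<E>,
}

impl<E: Error> FixedIntervalRetryPolicy<E> {
    pub fn new(interval: Duration) -> Self {
        Self { interval, phantom: PhantomData }
    }
}

impl<E: Error> RetryPolicy for FixedIntervalRetryPolicy<E> {
    type E = E;

    // Wait the same amount of time before every retry; the decorators from the
    // trait's default methods still decide when to stop.
    fn next_delay(&self, _context: &RetryContext<Self::E>) -> Option<Duration> {
        Some(self.interval)
    }
}

// Composing it with the provided decorators would look like:
// let policy = FixedIntervalRetryPolicy::new(Duration::from_millis(50))
//     .with_max_attempts(4)
//     .with_max_jitter(Duration::from_millis(10));
// let result = retry(some_fallible_async_operation, &policy).await;
```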

tests/retry_tests.rs

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
#[cfg(test)]
mod retry_tests {
    use std::io;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    use vss_client::error::VssError;
    use vss_client::util::retry::{retry, ExponentialBackoffRetryPolicy, RetryPolicy};

    #[tokio::test]
    async fn test_async_retry() {
        let base_delay = Duration::from_millis(10);
        let max_attempts = 3;
        let max_total_delay = Duration::from_secs(60);
        let max_jitter = Duration::from_millis(5);

        let exponential_backoff_jitter_policy = ExponentialBackoffRetryPolicy::new(base_delay)
            .skip_retry_on_error(|e| matches!(e, VssError::InvalidRequestError(..)))
            .with_max_attempts(max_attempts)
            .with_max_total_delay(max_total_delay)
            .with_max_jitter(max_jitter);

        let mut call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();
        let async_function = move || {
            let count = count.clone();
            async move {
                let attempts_made = count.fetch_add(1, Ordering::SeqCst);
                if attempts_made < max_attempts - 1 {
                    return Err(VssError::InternalServerError("Failure".to_string()));
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(42)
            }
        };

        let result = retry(async_function, &exponential_backoff_jitter_policy).await;
        assert_eq!(result.ok(), Some(42));
        assert_eq!(call_count.load(Ordering::SeqCst), max_attempts);

        call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();
        let failing_async_function = move || {
            let count = count.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                Err::<(), VssError>(VssError::InternalServerError("Failed".to_string()))
            }
        };

        let failed_result = retry(failing_async_function, &exponential_backoff_jitter_policy).await;
        assert!(failed_result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_on_all_errors() {
        let retry_policy = ExponentialBackoffRetryPolicy::new(Duration::from_millis(10)).with_max_attempts(3);

        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();
        let failing_async_function = move || {
            let count = count.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                Err::<(), io::Error>(io::Error::new(io::ErrorKind::InvalidData, "Failure"))
            }
        };

        let failed_result = retry(failing_async_function, &retry_policy).await;
        assert!(failed_result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_capped_by_max_total_delay() {
        let retry_policy = ExponentialBackoffRetryPolicy::new(Duration::from_millis(100))
            .with_max_total_delay(Duration::from_millis(350));

        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();
        let failing_async_function = move || {
            let count = count.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                Err::<(), VssError>(VssError::InternalServerError("Failed".to_string()))
            }
        };

        let failed_result = retry(failing_async_function, &retry_policy).await;
        assert!(failed_result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}
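
Editor's note on the expected call counts (not part of the commit): the exponential policy waits `base_delay * (2^attempts_made - 1)` before each retry, so in `test_retry_capped_by_max_total_delay` the waits would be 100 ms and then 300 ms; since 100 ms + 300 ms exceeds the 350 ms cap, `MaxTotalDelayRetryPolicy` returns `None` for the second retry and the operation runs exactly twice. A minimal standalone sketch of that arithmetic, with a hypothetical `expected_attempts` helper:

```rust
use std::time::Duration;

// Hypothetical helper mirroring ExponentialBackoffRetryPolicy::next_delay and
// MaxTotalDelayRetryPolicy::next_delay; not part of the commit.
fn expected_attempts(base_delay: Duration, max_total_delay: Duration) -> u32 {
    let mut attempts_made = 1; // the first call always happens
    let mut accumulated_delay = Duration::ZERO;
    loop {
        // Same formula as ExponentialBackoffRetryPolicy.
        let next_delay = base_delay * (2_u32.pow(attempts_made) - 1);
        // Same cap check as MaxTotalDelayRetryPolicy.
        if max_total_delay < accumulated_delay + next_delay {
            return attempts_made;
        }
        accumulated_delay += next_delay;
        attempts_made += 1;
    }
}

fn main() {
    // Waits of 100 ms, then 300 ms: the second wait would push the total past
    // the 350 ms cap, so the operation is attempted exactly twice.
    assert_eq!(
        expected_attempts(Duration::from_millis(100), Duration::from_millis(350)),
        2
    );
}
```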
