Skip to content

Commit 11029ed

Browse files
authored
Implement initial lazy caching credentials provider (#578)
* Implement initial lazy caching credentials provider * Rename TimeProvider to TimeSource * Move TimeSource to its own module * Eliminate Inner layer and add expiry_mut to Credentials * Move Cache to its own module and fix multithreading issue * Add comments * Make refresh_timeout unimplemented * Combine Provider with LazyCachingCredentialsProvider * CR feedback
1 parent f08a60c commit 11029ed

File tree

7 files changed

+558
-8
lines changed

7 files changed

+558
-8
lines changed
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
[package]
22
name = "aws-auth"
33
version = "0.1.0"
4-
authors = ["Russell Cohen <rcoh@amazon.com>"]
4+
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Russell Cohen <rcoh@amazon.com>"]
55
license = "Apache-2.0"
66
edition = "2018"
77

88
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
9-
109
[dependencies]
1110
smithy-http = { path = "../../../rust-runtime/smithy-http" }
11+
tokio = { version = "1", features = ["sync"] }
12+
tracing = "0.1.25"
1213
zeroize = "1.2.0"
1314

1415
[dev-dependencies]
15-
http = "0.2.3"
16-
tokio = { version = "1.0", features = ["rt", "macros"] }
1716
async-trait = "0.1.50"
17+
env_logger = "*"
18+
http = "0.2.3"
19+
test-env-log = { version = "0.2.7", features = ["trace"] }
20+
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "test-util"] }
21+
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

aws/rust-runtime/aws-auth/src/credentials.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use zeroize::Zeroizing;
1919
#[derive(Clone)]
2020
pub struct Credentials(Arc<Inner>);
2121

22+
#[derive(Clone)]
2223
struct Inner {
2324
access_key_id: Zeroizing<String>,
2425
secret_access_key: Zeroizing<String>,
@@ -89,6 +90,10 @@ impl Credentials {
8990
self.0.expires_after
9091
}
9192

93+
pub fn expiry_mut(&mut self) -> &mut Option<SystemTime> {
94+
&mut Arc::make_mut(&mut self.0).expires_after
95+
}
96+
9297
pub fn session_token(&self) -> Option<&str> {
9398
self.0.session_token.as_deref()
9499
}

aws/rust-runtime/aws-auth/src/provider.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
* SPDX-License-Identifier: Apache-2.0.
44
*/
55

6+
mod cache;
67
pub mod env;
8+
pub mod lazy_caching;
9+
mod time;
710

811
use crate::Credentials;
912
use smithy_http::property_bag::PropertyBag;
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
use crate::provider::CredentialsResult;
7+
use crate::Credentials;
8+
use std::future::Future;
9+
use std::sync::Arc;
10+
use std::time::{Duration, SystemTime};
11+
use tokio::sync::{OnceCell, RwLock};
12+
13+
#[derive(Clone)]
14+
pub(super) struct Cache {
15+
/// Amount of time before the actual credential expiration time
16+
/// where credentials are considered expired.
17+
buffer_time: Duration,
18+
value: Arc<RwLock<OnceCell<Credentials>>>,
19+
}
20+
21+
impl Cache {
22+
pub fn new(buffer_time: Duration) -> Cache {
23+
Cache {
24+
buffer_time,
25+
value: Arc::new(RwLock::new(OnceCell::new())),
26+
}
27+
}
28+
29+
#[cfg(test)]
30+
async fn get(&self) -> Option<Credentials> {
31+
self.value.read().await.get().cloned()
32+
}
33+
34+
/// Attempts to refresh the cached credentials with the given async future.
35+
/// If multiple threads attempt to refresh at the same time, one of them will win,
36+
/// and the others will await that thread's result rather than multiple refreshes occurring.
37+
/// The function given to acquire a credentials future, `f`, will not be called
38+
/// if another thread is chosen to load the credentials.
39+
pub async fn get_or_load<F, Fut>(&self, f: F) -> CredentialsResult
40+
where
41+
F: FnOnce() -> Fut,
42+
Fut: Future<Output = CredentialsResult>,
43+
{
44+
let lock = self.value.read().await;
45+
let future = lock.get_or_try_init(f);
46+
future.await.map(|credentials| credentials.clone())
47+
}
48+
49+
/// If the credentials are expired, clears the cache. Otherwise, yields the current credentials value.
50+
pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<Credentials> {
51+
// Short-circuit if the credential is not expired
52+
if let Some(credentials) = self.value.read().await.get() {
53+
if !expired(credentials, self.buffer_time, now) {
54+
return Some(credentials.clone());
55+
}
56+
}
57+
58+
// Acquire a write lock to clear the cache, but then once the lock is acquired,
59+
// check again that the credential is not already cleared. If it has been cleared,
60+
// then another thread is refreshing the cache by the time the write lock was acquired.
61+
let mut lock = self.value.write().await;
62+
if let Some(credentials) = lock.get() {
63+
// Also check that we're clearing the expired credentials and not credentials
64+
// that have been refreshed by another thread.
65+
if expired(credentials, self.buffer_time, now) {
66+
*lock = OnceCell::new();
67+
}
68+
}
69+
None
70+
}
71+
}
72+
73+
fn expired(credentials: &Credentials, buffer_time: Duration, now: SystemTime) -> bool {
74+
credentials
75+
.expiry()
76+
.map(|expiration| now >= (expiration - buffer_time))
77+
.expect("Cached credentials don't have an expiration time. This is a bug in aws-auth.")
78+
}
79+
80+
#[cfg(test)]
81+
mod tests {
82+
use super::{expired, Cache};
83+
use crate::Credentials;
84+
use std::time::{Duration, SystemTime};
85+
86+
fn credentials(expired_secs: u64) -> Credentials {
87+
Credentials::new("test", "test", None, Some(epoch_secs(expired_secs)), "test")
88+
}
89+
90+
fn epoch_secs(secs: u64) -> SystemTime {
91+
SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
92+
}
93+
94+
#[test]
95+
fn expired_check() {
96+
let creds = credentials(100);
97+
assert!(expired(&creds, Duration::from_secs(10), epoch_secs(1000)));
98+
assert!(expired(&creds, Duration::from_secs(10), epoch_secs(90)));
99+
assert!(!expired(&creds, Duration::from_secs(10), epoch_secs(10)));
100+
}
101+
102+
#[test_env_log::test(tokio::test)]
103+
async fn cache_clears_if_expired_only() {
104+
let cache = Cache::new(Duration::from_secs(10));
105+
assert!(cache
106+
.yield_or_clear_if_expired(epoch_secs(100))
107+
.await
108+
.is_none());
109+
110+
cache
111+
.get_or_load(|| async { Ok(credentials(100)) })
112+
.await
113+
.unwrap();
114+
assert_eq!(Some(epoch_secs(100)), cache.get().await.unwrap().expiry());
115+
116+
// It should not clear the credentials if they're not expired
117+
assert_eq!(
118+
Some(epoch_secs(100)),
119+
cache
120+
.yield_or_clear_if_expired(epoch_secs(10))
121+
.await
122+
.unwrap()
123+
.expiry()
124+
);
125+
126+
// It should clear the credentials if they're expired
127+
assert!(cache
128+
.yield_or_clear_if_expired(epoch_secs(500))
129+
.await
130+
.is_none());
131+
assert!(cache.get().await.is_none());
132+
}
133+
}

0 commit comments

Comments
 (0)