Skip to content

Commit

Permalink
fix: cache credential resolution with the AWS credential provider
Browse files Browse the repository at this point in the history
`object_store` invokes `get_credential` on _every_ invocation of a
get/list/put/etc. The provider invocation for environment based
credentials is practically zero-cost, so this has no/low overhead.

In the case of the AssumeRoleProvider or any provider which has _some_
cost, such as an invocation of the AWS STS APIs, this can result in
rate-limiting or service quota exhaustion.

In order to prevent this, the credentials are attempted to be cached
only so long as they have no expired, which is defined in the
`aws_credential_types::Credential` struct

Signed-off-by: R. Tyler Croy <rtyler@brokenco.de>
Sponsored-by: Scribd Inc
  • Loading branch information
rtyler committed Nov 12, 2024
1 parent 7a3b3ec commit 64d7eca
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 4 deletions.
2 changes: 1 addition & 1 deletion crates/aws/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "deltalake-aws"
version = "0.4.1"
version = "0.4.2"
authors.workspace = true
keywords.workspace = true
readme.workspace = true
Expand Down
101 changes: 98 additions & 3 deletions crates/aws/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_config::meta::credentials::CredentialsProviderChain;
Expand All @@ -19,6 +20,7 @@ use deltalake_core::storage::object_store::{
};
use deltalake_core::storage::StorageOptions;
use deltalake_core::DeltaResult;
use tokio::sync::Mutex;
use tracing::log::*;

use crate::constants;
Expand All @@ -27,12 +29,21 @@ use crate::constants;
/// into a necessary [AwsCredential] type for configuring [object_store::aws::AmazonS3]
#[derive(Clone, Debug)]
pub(crate) struct AWSForObjectStore {
/// TODO: replace this with something with a credential cache instead of the sdkConfig
sdk_config: SdkConfig,
cache: Arc<Mutex<Option<Credentials>>>,
}

impl AWSForObjectStore {
pub(crate) fn new(sdk_config: SdkConfig) -> Self {
Self { sdk_config }
let cache = Arc::new(Mutex::new(None));
Self { sdk_config, cache }
}

/// Return true if a credential has been cached
async fn has_cached_credentials(&self) -> bool {
let guard = self.cache.lock().await;
(*guard).is_some()
}
}

Expand All @@ -43,10 +54,34 @@ impl CredentialProvider for AWSForObjectStore {
/// Provide the necessary configured credentials from the AWS SDK for use by
/// [object_store::aws::AmazonS3]
async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
debug!("AWSForObjectStore is unlocking..");
let mut guard = self.cache.lock().await;

if let Some(cached) = guard.as_ref() {
debug!("Located cached credentials");
let now = SystemTime::now();

// Credentials such as assume role credentials will have an expiry on them, whereas
// environmental provided credentials will *not*. In the latter case, it's still
// useful avoid running through the provider chain again, so in both cases we should
// still treat credentials as useful
if cached.expiry().unwrap_or(now) >= now {
debug!("Cached credentials are still valid, returning");
return Ok(Arc::new(Self::Credential {
key_id: cached.access_key_id().into(),
secret_key: cached.secret_access_key().into(),
token: cached.session_token().map(|o| o.to_string()),
}));
} else {
debug!("Cached credentials appear to no longer be valid, re-resolving");
}
}

let provider = self
.sdk_config
.credentials_provider()
.ok_or(ObjectStoreError::NotImplemented)?;

let credentials =
provider
.provide_credentials()
Expand All @@ -60,11 +95,15 @@ impl CredentialProvider for AWSForObjectStore {
credentials.access_key_id()
);

Ok(Arc::new(Self::Credential {
let result = Ok(Arc::new(Self::Credential {
key_id: credentials.access_key_id().into(),
secret_key: credentials.secret_access_key().into(),
token: credentials.session_token().map(|o| o.to_string()),
}))
}));

// Update the mutex before exiting with the new Credentials from the AWS provider
*guard = Some(credentials);
return result;
}
}

Expand Down Expand Up @@ -324,4 +363,60 @@ mod tests {
panic!("Could not retrieve credentials from the SdkConfig: {config:?}");
}
}

#[tokio::test]
async fn test_object_store_credential_provider() -> DeltaResult<()> {
let options = StorageOptions(hashmap! {
constants::AWS_ACCESS_KEY_ID.to_string() => "test_id".to_string(),
constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret".to_string(),
});
let sdk_config = resolve_credentials(options)
.await
.expect("Failed to resolve credentijals for the test");
let provider = AWSForObjectStore::new(sdk_config);
let _credential = provider
.get_credential()
.await
.expect("Failed to produce a credential");
Ok(())
}

/// The [CredentialProvider] is called _repeatedly_ by the [object_store] create, in essence on
/// every get/put/list/etc operation, the `get_credential` function will be invoked.
///
/// In some cases, such as when assuming roles, this can result in an excessive amount of STS
/// API calls in the scenarios where the delta-rs process is performing a large number of S3
/// operations.
#[tokio::test]
async fn test_object_store_credential_provider_consistency() -> DeltaResult<()> {
let options = StorageOptions(hashmap! {
constants::AWS_ACCESS_KEY_ID.to_string() => "test_id".to_string(),
constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret".to_string(),
});
let sdk_config = resolve_credentials(options)
.await
.expect("Failed to resolve credentijals for the test");
let provider = AWSForObjectStore::new(sdk_config);
let credential_a = provider
.get_credential()
.await
.expect("Failed to produce a credential");

assert!(
provider.has_cached_credentials().await,
"The provider should have cached the credential on the first call!"
);

let credential_b = provider
.get_credential()
.await
.expect("Failed to produce a credential");

assert_ne!(
Arc::as_ptr(&credential_a),
Arc::as_ptr(&credential_b),
"Repeated calls to get_credential() produced different results!"
);
Ok(())
}
}

0 comments on commit 64d7eca

Please sign in to comment.