Skip to content

Commit

Permalink
Merge branch 'main' into fix_json_parse
Browse files Browse the repository at this point in the history
  • Loading branch information
rtyler authored Apr 12, 2024
2 parents 6f4bccd + 64b3e54 commit 9fad206
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 7 deletions.
67 changes: 67 additions & 0 deletions crates/aws/src/credentials.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::time::Duration;

use aws_config::{
ecs::EcsCredentialsProvider, environment::EnvironmentVariableCredentialsProvider,
imds::credentials::ImdsCredentialsProvider, meta::credentials::CredentialsProviderChain,
profile::ProfileFileCredentialsProvider, provider_config::ProviderConfig,
web_identity_token::WebIdentityTokenCredentialsProvider,
};
use aws_credential_types::provider::{self, ProvideCredentials};
use tracing::Instrument;

/// Credentials provider backed by a [`CredentialsProviderChain`] that is
/// assembled by [`ConfiguredCredentialChain::new`].
#[derive(Debug)]
pub struct ConfiguredCredentialChain {
/// The underlying chain that credential requests are delegated to.
provider_chain: CredentialsProviderChain,
}

impl ConfiguredCredentialChain {
    /// Assembles the credential provider chain.
    ///
    /// Providers are registered in this order: `Environment`, `Profile`,
    /// `WebIdentityToken`, `EcsContainer` and, unless `disable_imds` is set,
    /// `Ec2InstanceMetadata`.
    ///
    /// * `disable_imds` - when `true`, the EC2 instance metadata provider is
    ///   not added to the chain.
    /// * `imds_timeout` - timeout in milliseconds used for both the connect
    ///   and the read timeout of the IMDS client.
    /// * `config` - provider configuration shared by the configurable
    ///   providers; `ProviderConfig::default()` is used when `None`.
    pub fn new(disable_imds: bool, imds_timeout: u64, config: Option<ProviderConfig>) -> Self {
        let provider_config = config.unwrap_or_default();

        let from_env = EnvironmentVariableCredentialsProvider::default();
        let from_profile = ProfileFileCredentialsProvider::builder()
            .configure(&provider_config)
            .build();
        let from_web_identity = WebIdentityTokenCredentialsProvider::builder()
            .configure(&provider_config)
            .build();
        let from_ecs = EcsCredentialsProvider::builder()
            .configure(&provider_config)
            .build();

        let base_chain = CredentialsProviderChain::first_try("Environment", from_env)
            .or_else("Profile", from_profile)
            .or_else("WebIdentityToken", from_web_identity)
            .or_else("EcsContainer", from_ecs);

        let provider_chain = if disable_imds {
            base_chain
        } else {
            // The same timeout bounds both connecting to and reading from IMDS.
            let timeout = Duration::from_millis(imds_timeout);
            let imds_client = aws_config::imds::Client::builder()
                .connect_timeout(timeout)
                .read_timeout(timeout)
                .build();
            let from_imds = ImdsCredentialsProvider::builder()
                .configure(&provider_config)
                .imds_client(imds_client)
                .build();
            base_chain.or_else("Ec2InstanceMetadata", from_imds)
        };

        Self { provider_chain }
    }

    /// Requests credentials from the wrapped chain inside a
    /// `provide_credentials` debug span.
    async fn credentials(&self) -> provider::Result {
        let span = tracing::debug_span!("provide_credentials", provider = %"default_chain");
        self.provider_chain
            .provide_credentials()
            .instrument(span)
            .await
    }
}

impl ProvideCredentials for ConfiguredCredentialChain {
fn provide_credentials<'a>(
&'a self,
) -> aws_credential_types::provider::future::ProvideCredentials<'a>
where
Self: 'a,
{
aws_credential_types::provider::future::ProvideCredentials::new(self.credentials())
}
}
1 change: 1 addition & 0 deletions crates/aws/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Lock client implementation based on DynamoDb.

mod credentials;
pub mod errors;
pub mod logstore;
#[cfg(feature = "native-tls")]
Expand Down
83 changes: 80 additions & 3 deletions crates/aws/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,19 +164,30 @@ impl S3StorageOptions {
let allow_unsafe_rename = str_option(options, s3_constants::AWS_S3_ALLOW_UNSAFE_RENAME)
.map(|val| str_is_truthy(&val))
.unwrap_or(false);

let disable_imds = str_option(options, s3_constants::AWS_EC2_METADATA_DISABLED)
.map(|val| str_is_truthy(&val))
.unwrap_or(false);
let imds_timeout =
Self::u64_or_default(options, s3_constants::AWS_EC2_METADATA_TIMEOUT, 100);
let credentials_provider =
crate::credentials::ConfiguredCredentialChain::new(disable_imds, imds_timeout, None);
#[cfg(feature = "native-tls")]
let sdk_config = execute_sdk_future(
aws_config::ConfigLoader::default()
aws_config::from_env()
.http_client(native::use_native_tls_client(
str_option(options, s3_constants::AWS_ALLOW_HTTP)
.map(|val| str_is_truthy(&val))
.unwrap_or(false),
))
.credentials_provider(credentials_provider)
.load(),
)?;
#[cfg(feature = "rustls")]
let sdk_config = execute_sdk_future(aws_config::load_from_env())?;
let sdk_config = execute_sdk_future(
aws_config::from_env()
.credentials_provider(credentials_provider)
.load(),
)?;

let sdk_config =
if let Some(endpoint_url) = str_option(options, s3_constants::AWS_ENDPOINT_URL) {
Expand Down Expand Up @@ -433,6 +444,14 @@ pub mod s3_constants {
/// Only safe if there is one writer to a given table.
pub const AWS_S3_ALLOW_UNSAFE_RENAME: &str = "AWS_S3_ALLOW_UNSAFE_RENAME";

/// If set to a truthy value (e.g. "true"), the EC2 instance metadata service (IMDS)
/// credentials provider is not added to the credential provider chain.
/// Defaults to "false".
pub const AWS_EC2_METADATA_DISABLED: &str = "AWS_EC2_METADATA_DISABLED";

/// The timeout, in milliseconds, applied to both connecting to and reading from
/// the EC2 instance metadata endpoint.
/// Defaults to 100.
pub const AWS_EC2_METADATA_TIMEOUT: &str = "AWS_EC2_METADATA_TIMEOUT";

/// The list of option keys owned by the S3 module.
/// Option keys not contained in this list will be added to the `extra_opts`
/// field of [crate::storage::s3::S3StorageOptions].
Expand All @@ -452,6 +471,8 @@ pub mod s3_constants {
AWS_S3_POOL_IDLE_TIMEOUT_SECONDS,
AWS_STS_POOL_IDLE_TIMEOUT_SECONDS,
AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES,
AWS_EC2_METADATA_DISABLED,
AWS_EC2_METADATA_TIMEOUT,
];
}

Expand All @@ -462,8 +483,11 @@ pub(crate) fn str_option(map: &HashMap<String, String>, key: &str) -> Option<Str

#[cfg(test)]
mod tests {
use std::time::SystemTime;

use super::*;

use aws_sdk_sts::config::ProvideCredentials;
use maplit::hashmap;
use serial_test::serial;

Expand All @@ -481,6 +505,15 @@ mod tests {
let _env_scope = Self::new();
f()
}

/// Awaits `future` while a freshly created `ScopedEnv` is alive and
/// returns the future's output; the scope is dropped once the future
/// has completed.
pub async fn run_async<F>(future: F) -> F::Output
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let _scope_guard = Self::new();
    let output = future.await;
    output
}
}

impl Drop for ScopedEnv {
Expand Down Expand Up @@ -744,4 +777,48 @@ mod tests {
}
});
}

/// Verifies that a credential lookup with IMDS disabled finishes faster than
/// one with IMDS explicitly enabled or left at its default.
#[tokio::test]
#[serial]
async fn storage_options_toggle_imds() {
    let scenario = async {
        clear_env_of_aws_keys();

        // Measurement order is kept: disabled first, then enabled, then default.
        let without_imds = storage_options_configure_imds(Some("true")).await;
        let with_imds = storage_options_configure_imds(Some("false")).await;
        let with_default = storage_options_configure_imds(None).await;

        println!(
            "enabled_time: {}, disabled_time: {}, default_time: {}",
            with_imds.as_micros(),
            without_imds.as_micros(),
            with_default.as_micros(),
        );

        assert!(without_imds < with_imds);
        assert!(without_imds < with_default);
    };

    ScopedEnv::run_async(scenario).await;
}

/// Builds [`S3StorageOptions`] for region `eu-west-1`, optionally setting
/// `AWS_EC2_METADATA_DISABLED` to `value`, and returns how long a single
/// credential lookup through the configured provider takes.
///
/// Passing `None` leaves `AWS_EC2_METADATA_DISABLED` out of the options map.
///
/// # Panics
///
/// Panics if the options cannot be built, if `AWS_REGION` is not set to
/// `eu-west-1` in the environment afterwards, or if the SDK config carries
/// no credentials provider.
async fn storage_options_configure_imds(value: Option<&str>) -> Duration {
    let mut raw_options = hashmap! {
        s3_constants::AWS_REGION.to_string() => "eu-west-1".to_string(),
    };
    if let Some(value) = value {
        raw_options.insert(
            s3_constants::AWS_EC2_METADATA_DISABLED.to_string(),
            value.to_string(),
        );
    }
    let options = S3StorageOptions::from_map(&raw_options).unwrap();

    assert_eq!(
        "eu-west-1",
        std::env::var(s3_constants::AWS_REGION).unwrap()
    );

    let provider = options.sdk_config.credentials_provider().unwrap();

    // Use the monotonic clock: `SystemTime::elapsed` fails (and the former
    // `unwrap` panicked) whenever the wall clock is adjusted backwards.
    let started = std::time::Instant::now();
    // The outcome of the lookup is irrelevant here; only its duration is measured.
    _ = provider.provide_credentials().await;
    started.elapsed()
}
}
14 changes: 10 additions & 4 deletions crates/core/src/kernel/snapshot/log_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,16 @@ async fn list_log_files_with_checkpoint(
})
.collect_vec();

// TODO raise a proper error
assert_eq!(checkpoint_files.len(), cp.parts.unwrap_or(1) as usize);

Ok((commit_files, checkpoint_files))
if checkpoint_files.len() != cp.parts.unwrap_or(1) as usize {
let msg = format!(
"Number of checkpoint files '{}' is not equal to number of checkpoint metadata parts '{:?}'",
checkpoint_files.len(),
cp.parts
);
Err(DeltaTableError::MetadataError(msg))
} else {
Ok((commit_files, checkpoint_files))
}
}

/// List relevant log files.
Expand Down

0 comments on commit 9fad206

Please sign in to comment.