diff --git a/crates/aws/src/credentials.rs b/crates/aws/src/credentials.rs new file mode 100644 index 0000000000..4f3d425f3c --- /dev/null +++ b/crates/aws/src/credentials.rs @@ -0,0 +1,67 @@ +use std::time::Duration; + +use aws_config::{ + ecs::EcsCredentialsProvider, environment::EnvironmentVariableCredentialsProvider, + imds::credentials::ImdsCredentialsProvider, meta::credentials::CredentialsProviderChain, + profile::ProfileFileCredentialsProvider, provider_config::ProviderConfig, + web_identity_token::WebIdentityTokenCredentialsProvider, +}; +use aws_credential_types::provider::{self, ProvideCredentials}; +use tracing::Instrument; + +#[derive(Debug)] +pub struct ConfiguredCredentialChain { + provider_chain: CredentialsProviderChain, +} + +impl ConfiguredCredentialChain { + pub fn new(disable_imds: bool, imds_timeout: u64, config: Option) -> Self { + let conf = config.unwrap_or_default(); + + let env_provider = EnvironmentVariableCredentialsProvider::default(); + let profile_provider = ProfileFileCredentialsProvider::builder() + .configure(&conf) + .build(); + let web_identity_token_provider = WebIdentityTokenCredentialsProvider::builder() + .configure(&conf) + .build(); + let ecs_provider = EcsCredentialsProvider::builder().configure(&conf).build(); + + let mut provider_chain = CredentialsProviderChain::first_try("Environment", env_provider) + .or_else("Profile", profile_provider) + .or_else("WebIdentityToken", web_identity_token_provider) + .or_else("EcsContainer", ecs_provider); + if !disable_imds { + let imds_provider = ImdsCredentialsProvider::builder() + .configure(&conf) + .imds_client( + aws_config::imds::Client::builder() + .connect_timeout(Duration::from_millis(imds_timeout)) + .read_timeout(Duration::from_millis(imds_timeout)) + .build(), + ) + .build(); + provider_chain = provider_chain.or_else("Ec2InstanceMetadata", imds_provider); + } + + Self { provider_chain } + } + + async fn credentials(&self) -> provider::Result { + self.provider_chain + .provide_credentials() + .instrument(tracing::debug_span!("provide_credentials", provider = %"default_chain")) + .await + } +} + +impl ProvideCredentials for ConfiguredCredentialChain { + fn provide_credentials<'a>( + &'a self, + ) -> aws_credential_types::provider::future::ProvideCredentials<'a> + where + Self: 'a, + { + aws_credential_types::provider::future::ProvideCredentials::new(self.credentials()) + } +} diff --git a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index 026f0e0df9..82eff07c1b 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -1,5 +1,6 @@ //! Lock client implementation based on DynamoDb. +mod credentials; pub mod errors; pub mod logstore; #[cfg(feature = "native-tls")] diff --git a/crates/aws/src/storage.rs b/crates/aws/src/storage.rs index ffd7bb6996..19c18ff3ea 100644 --- a/crates/aws/src/storage.rs +++ b/crates/aws/src/storage.rs @@ -164,19 +164,30 @@ impl S3StorageOptions { let allow_unsafe_rename = str_option(options, s3_constants::AWS_S3_ALLOW_UNSAFE_RENAME) .map(|val| str_is_truthy(&val)) .unwrap_or(false); - + let disable_imds = str_option(options, s3_constants::AWS_EC2_METADATA_DISABLED) + .map(|val| str_is_truthy(&val)) + .unwrap_or(false); + let imds_timeout = + Self::u64_or_default(options, s3_constants::AWS_EC2_METADATA_TIMEOUT, 100); + let credentials_provider = + crate::credentials::ConfiguredCredentialChain::new(disable_imds, imds_timeout, None); #[cfg(feature = "native-tls")] let sdk_config = execute_sdk_future( - aws_config::ConfigLoader::default() + aws_config::from_env() .http_client(native::use_native_tls_client( str_option(options, s3_constants::AWS_ALLOW_HTTP) .map(|val| str_is_truthy(&val)) .unwrap_or(false), )) + .credentials_provider(credentials_provider) .load(), )?; #[cfg(feature = "rustls")] - let sdk_config = execute_sdk_future(aws_config::load_from_env())?; + let sdk_config = execute_sdk_future( + aws_config::from_env() + .credentials_provider(credentials_provider) + .load(), + )?; let sdk_config = if let Some(endpoint_url) = str_option(options, s3_constants::AWS_ENDPOINT_URL) { @@ -433,6 +444,14 @@ pub mod s3_constants { /// Only safe if there is one writer to a given table. pub const AWS_S3_ALLOW_UNSAFE_RENAME: &str = "AWS_S3_ALLOW_UNSAFE_RENAME"; + /// If set to "true", disables the imds client + /// Defaults to "false" + pub const AWS_EC2_METADATA_DISABLED: &str = "AWS_EC2_METADATA_DISABLED"; + + /// The timeout in milliseconds for the EC2 metadata endpoint + /// Defaults to 100 + pub const AWS_EC2_METADATA_TIMEOUT: &str = "AWS_EC2_METADATA_TIMEOUT"; + /// The list of option keys owned by the S3 module. /// Option keys not contained in this list will be added to the `extra_opts` /// field of [crate::storage::s3::S3StorageOptions]. @@ -452,6 +471,8 @@ pub mod s3_constants { AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, AWS_STS_POOL_IDLE_TIMEOUT_SECONDS, AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES, + AWS_EC2_METADATA_DISABLED, + AWS_EC2_METADATA_TIMEOUT, ]; } @@ -462,8 +483,11 @@ pub(crate) fn str_option(map: &HashMap, key: &str) -> Option(future: F) -> F::Output + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let _env_scope = Self::new(); + future.await + } } impl Drop for ScopedEnv { @@ -744,4 +777,48 @@ mod tests { } }); } + + #[tokio::test] + #[serial] + async fn storage_options_toggle_imds() { + ScopedEnv::run_async(async { + clear_env_of_aws_keys(); + let disabled_time = storage_options_configure_imds(Some("true")).await; + let enabled_time = storage_options_configure_imds(Some("false")).await; + let default_time = storage_options_configure_imds(None).await; + println!( + "enabled_time: {}, disabled_time: {}, default_time: {}", + enabled_time.as_micros(), + disabled_time.as_micros(), + default_time.as_micros(), + ); + assert!(disabled_time < enabled_time); + assert!(disabled_time < default_time); + }) + .await; + } + + async fn storage_options_configure_imds(value: Option<&str>) -> Duration { + let _options = match value { + Some(value) => S3StorageOptions::from_map(&hashmap! { + s3_constants::AWS_REGION.to_string() => "eu-west-1".to_string(), + s3_constants::AWS_EC2_METADATA_DISABLED.to_string() => value.to_string(), + }) + .unwrap(), + None => S3StorageOptions::from_map(&hashmap! { + s3_constants::AWS_REGION.to_string() => "eu-west-1".to_string(), + }) + .unwrap(), + }; + + assert_eq!( + "eu-west-1", + std::env::var(s3_constants::AWS_REGION).unwrap() + ); + + let provider = _options.sdk_config.credentials_provider().unwrap(); + let now = SystemTime::now(); + _ = provider.provide_credentials().await; + now.elapsed().unwrap() + } } diff --git a/crates/core/src/kernel/snapshot/log_segment.rs b/crates/core/src/kernel/snapshot/log_segment.rs index 0b82231ee8..e8d727b7da 100644 --- a/crates/core/src/kernel/snapshot/log_segment.rs +++ b/crates/core/src/kernel/snapshot/log_segment.rs @@ -453,10 +453,16 @@ async fn list_log_files_with_checkpoint( }) .collect_vec(); - // TODO raise a proper error - assert_eq!(checkpoint_files.len(), cp.parts.unwrap_or(1) as usize); - - Ok((commit_files, checkpoint_files)) + if checkpoint_files.len() != cp.parts.unwrap_or(1) as usize { + let msg = format!( + "Number of checkpoint files '{}' is not equal to number of checkpoint metadata parts '{:?}'", + checkpoint_files.len(), + cp.parts + ); + Err(DeltaTableError::MetadataError(msg)) + } else { + Ok((commit_files, checkpoint_files)) + } } /// List relevant log files.