Skip to content

Commit

Permalink
Object Store (AWS): Support region configured via named profile (#4161)
Browse files Browse the repository at this point in the history
* feat(aws_profile): use profile region as fallback

* moved ProfileProvider to aws::profile module
* added aws::region::RegionProvider
* lazy-init profile credential provider
* support overriding profile region
* tests

* fix(aws_profile): clippy & RAT errors

* fix(aws_profile): make RegionProvider async

* test(aws_profile): use fake config for testing

* refactor(aws_profile): remove unnecessary module

aws::profile::region -> aws::profile

* refactor(aws_profile): tests w/ profile files

* fix(object_store): rat + clippy warnings

* Don't spawn thread

---------

Co-authored-by: Raphael Taylor-Davies <r.taylordavies@googlemail.com>
  • Loading branch information
mr-brobot and tustvold authored May 16, 2023
1 parent 108b7a8 commit 4714b21
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 64 deletions.
62 changes: 0 additions & 62 deletions object_store/src/aws/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,68 +515,6 @@ async fn web_identity(
})
}

#[cfg(feature = "aws_profile")]
mod profile {
use super::*;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::provider_config::ProviderConfig;
use aws_credential_types::provider::ProvideCredentials;
use aws_types::region::Region;
use std::time::SystemTime;

#[derive(Debug)]
pub struct ProfileProvider {
cache: TokenCache<Arc<AwsCredential>>,
credentials: ProfileFileCredentialsProvider,
}

impl ProfileProvider {
pub fn new(name: String, region: String) -> Self {
let config = ProviderConfig::default().with_region(Some(Region::new(region)));

Self {
cache: Default::default(),
credentials: ProfileFileCredentialsProvider::builder()
.configure(&config)
.profile_name(name)
.build(),
}
}
}

impl CredentialProvider for ProfileProvider {
fn get_credential(&self) -> BoxFuture<'_, Result<Arc<AwsCredential>>> {
Box::pin(self.cache.get_or_insert_with(move || async move {
let c =
self.credentials
.provide_credentials()
.await
.map_err(|source| crate::Error::Generic {
store: STORE,
source: Box::new(source),
})?;
let t_now = SystemTime::now();
let expiry = c
.expiry()
.and_then(|e| e.duration_since(t_now).ok())
.map(|ttl| Instant::now() + ttl);

Ok(TemporaryToken {
token: Arc::new(AwsCredential {
key_id: c.access_key_id().to_string(),
secret_key: c.secret_access_key().to_string(),
token: c.session_token().map(ToString::to_string),
}),
expiry,
})
}))
}
}
}

#[cfg(feature = "aws_profile")]
pub use profile::ProfileProvider;

#[cfg(test)]
mod tests {
use super::*;
Expand Down
78 changes: 76 additions & 2 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ mod checksum;
mod client;
mod credential;

#[cfg(feature = "aws_profile")]
mod profile;

// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
//
// Do not URI-encode any of the unreserved characters that RFC 3986 defines:
Expand Down Expand Up @@ -985,8 +988,14 @@ impl AmazonS3Builder {
self.parse_url(&url)?;
}

let region = match (self.region.clone(), self.profile.clone()) {
(Some(region), _) => Some(region),
(None, Some(profile)) => profile_region(profile),
(None, None) => None,
};

let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
let region = self.region.context(MissingRegionSnafu)?;
let region = region.context(MissingRegionSnafu)?;
let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;

let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
Expand Down Expand Up @@ -1094,12 +1103,30 @@ impl AmazonS3Builder {
}
}

#[cfg(feature = "aws_profile")]
fn profile_region(profile: String) -> Option<String> {
use tokio::runtime::Handle;

let handle = Handle::current();
let provider = profile::ProfileProvider::new(profile, None);

handle.block_on(provider.get_region())
}

#[cfg(feature = "aws_profile")]
fn profile_credentials(
profile: String,
region: String,
) -> Result<Box<dyn CredentialProvider>> {
Ok(Box::new(credential::ProfileProvider::new(profile, region)))
Ok(Box::new(profile::ProfileProvider::new(
profile,
Some(region),
)))
}

#[cfg(not(feature = "aws_profile"))]
fn profile_region(_profile: String) -> Option<String> {
None
}

#[cfg(not(feature = "aws_profile"))]
Expand Down Expand Up @@ -1594,3 +1621,50 @@ mod s3_resolve_bucket_region_tests {
assert!(result.is_err());
}
}

#[cfg(all(test, feature = "aws_profile"))]
mod profile_tests {
use super::*;
use std::env;

use super::profile::{TEST_PROFILE_NAME, TEST_PROFILE_REGION};

#[tokio::test]
async fn s3_test_region_from_profile() {
let s3_url = "s3://bucket/prefix".to_owned();

let s3 = AmazonS3Builder::new()
.with_url(s3_url)
.with_profile(TEST_PROFILE_NAME)
.build()
.unwrap();

let region = &s3.client.config().region;

assert_eq!(region, TEST_PROFILE_REGION);
}

#[test]
fn s3_test_region_override() {
let s3_url = "s3://bucket/prefix".to_owned();

let aws_profile =
env::var("AWS_PROFILE").unwrap_or_else(|_| TEST_PROFILE_NAME.into());

let aws_region =
env::var("AWS_REGION").unwrap_or_else(|_| "object_store:fake_region".into());

env::set_var("AWS_PROFILE", aws_profile);

let s3 = AmazonS3Builder::from_env()
.with_url(s3_url)
.with_region(aws_region.clone())
.build()
.unwrap();

let actual = &s3.client.config().region;
let expected = &aws_region;

assert_eq!(actual, expected);
}
}
128 changes: 128 additions & 0 deletions object_store/src/aws/profile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#![cfg(feature = "aws_profile")]

use aws_config::meta::region::ProvideRegion;
use aws_config::profile::profile_file::ProfileFiles;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::profile::ProfileFileRegionProvider;
use aws_config::provider_config::ProviderConfig;
use aws_credential_types::provider::ProvideCredentials;
use aws_types::region::Region;
use futures::future::BoxFuture;
use std::sync::Arc;
use std::time::Instant;
use std::time::SystemTime;

use crate::aws::credential::CredentialProvider;
use crate::aws::AwsCredential;
use crate::client::token::{TemporaryToken, TokenCache};
use crate::Result;

#[cfg(test)]
pub static TEST_PROFILE_NAME: &str = "object_store:fake_profile";

#[cfg(test)]
pub static TEST_PROFILE_REGION: &str = "object_store:fake_region_from_profile";

#[derive(Debug)]
pub struct ProfileProvider {
name: String,
region: Option<String>,
cache: TokenCache<Arc<AwsCredential>>,
}

impl ProfileProvider {
pub fn new(name: String, region: Option<String>) -> Self {
Self {
name,
region,
cache: Default::default(),
}
}

#[cfg(test)]
fn profile_files(&self) -> ProfileFiles {
use aws_config::profile::profile_file::ProfileFileKind;

let config = format!(
"[profile {}]\nregion = {}",
TEST_PROFILE_NAME, TEST_PROFILE_REGION
);

ProfileFiles::builder()
.with_contents(ProfileFileKind::Config, config)
.build()
}

#[cfg(not(test))]
fn profile_files(&self) -> ProfileFiles {
ProfileFiles::default()
}

pub async fn get_region(&self) -> Option<String> {
if let Some(region) = self.region.clone() {
return Some(region);
}

let provider = ProfileFileRegionProvider::builder()
.profile_files(self.profile_files())
.profile_name(&self.name)
.build();

let region = provider.region().await;

region.map(|r| r.as_ref().to_owned())
}
}

impl CredentialProvider for ProfileProvider {
fn get_credential(&self) -> BoxFuture<'_, Result<Arc<AwsCredential>>> {
Box::pin(self.cache.get_or_insert_with(move || async move {
let region = self.region.clone().map(Region::new);

let config = ProviderConfig::default().with_region(region);

let credentials = ProfileFileCredentialsProvider::builder()
.configure(&config)
.profile_name(&self.name)
.build();

let c = credentials.provide_credentials().await.map_err(|source| {
crate::Error::Generic {
store: "S3",
source: Box::new(source),
}
})?;
let t_now = SystemTime::now();
let expiry = c
.expiry()
.and_then(|e| e.duration_since(t_now).ok())
.map(|ttl| Instant::now() + ttl);

Ok(TemporaryToken {
token: Arc::new(AwsCredential {
key_id: c.access_key_id().to_string(),
secret_key: c.secret_access_key().to_string(),
token: c.session_token().map(ToString::to_string),
}),
expiry,
})
}))
}
}

0 comments on commit 4714b21

Please sign in to comment.