Add support for IRSA authentication for S3 #694

Merged 2 commits on Jul 27, 2024
Changes from all commits
378 changes: 96 additions & 282 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -51,6 +51,7 @@ deltalake = { version = "0.17.3" }
cornucopia = { version = "0.9.0" }
cornucopia_async = {version = "0.6.0"}
deadpool-postgres = "0.12"

[profile.release]
debug = 1

4 changes: 3 additions & 1 deletion crates/arroyo-server-common/src/lib.rs
@@ -61,7 +61,9 @@ pub fn init_logging_with_filter(_name: &str, filter: EnvFilter) -> WorkerGuard {
eprintln!("Failed to initialize log tracer {:?}", e);
}

let filter = filter.add_directive("refinery_core=warn".parse().unwrap());
let filter = filter
.add_directive("refinery_core=warn".parse().unwrap())
.add_directive("aws_config::profile::credentials=warn".parse().unwrap());

let (nonblocking, guard) = tracing_appender::non_blocking(std::io::stderr());

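Aside: add_directive layers a per-module override on top of the base filter; the new directive presumably quiets log noise from aws-config's profile scanning now that it runs during credential resolution (the motivation is an assumption, not stated in this diff). A minimal standalone sketch of the same EnvFilter pattern:

use tracing_subscriber::EnvFilter;

fn build_filter() -> EnvFilter {
    // Base level is info; the directives lower specific noisy modules to warn.
    EnvFilter::new("info")
        .add_directive("refinery_core=warn".parse().unwrap())
        .add_directive("aws_config::profile::credentials=warn".parse().unwrap())
}
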
27 changes: 26 additions & 1 deletion crates/arroyo-state/src/lib.rs
@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{Context, Result};
use arrow_array::RecordBatch;
use arroyo_rpc::grpc::rpc::{
CheckpointMetadata, ExpiringKeyedTimeTableConfig, GlobalKeyedTableConfig,
@@ -9,12 +9,15 @@ use async_trait::async_trait;
use bincode::config::Configuration;
use bincode::{Decode, Encode};

use arroyo_rpc::config::config;
use arroyo_rpc::df::ArroyoSchema;
use arroyo_storage::StorageProvider;
use prost::Message;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::ops::RangeInclusive;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

pub mod checkpoint_state;
@@ -160,3 +163,25 @@ pub fn hash_key<K: Hash>(key: &K) -> u64 {
key.hash(&mut hasher);
hasher.finish()
}

static STORAGE_PROVIDER: tokio::sync::OnceCell<Arc<StorageProvider>> =
tokio::sync::OnceCell::const_new();

pub(crate) async fn get_storage_provider() -> Result<&'static Arc<StorageProvider>> {
// TODO: this should be encoded in the config so that the controller doesn't need
// to be synchronized with the workers

STORAGE_PROVIDER
.get_or_try_init(|| async {
let storage_url = &config().checkpoint_url;

StorageProvider::for_url(storage_url)
.await
.context(format!(
"failed to construct checkpoint backend for URL {}",
storage_url
))
.map(Arc::new)
})
.await
}
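
The OnceCell above memoizes a single StorageProvider per process: the first caller runs the fallible async initializer, and every later caller awaits and then shares the same Arc. A hedged usage sketch (the caller name is illustrative, not from this diff):

use std::sync::Arc;
use anyhow::Result;

async fn two_callers() -> Result<()> {
    // Only the first call constructs the provider from config().checkpoint_url;
    // the second resolves immediately to the cached Arc.
    let a = get_storage_provider().await?;
    let b = get_storage_provider().await?;
    assert!(Arc::ptr_eq(a, b));
    Ok(())
}
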
20 changes: 4 additions & 16 deletions crates/arroyo-state/src/parquet.rs
@@ -1,12 +1,11 @@
use crate::tables::expiring_time_key_map::ExpiringTimeKeyTable;
use crate::tables::global_keyed_map::GlobalKeyedTable;
use crate::tables::{CompactionConfig, ErasedTable};
use crate::BackingStore;
use anyhow::{bail, Context, Result};
use crate::{get_storage_provider, BackingStore};
use anyhow::{bail, Result};
use arroyo_rpc::grpc::rpc::{
CheckpointMetadata, OperatorCheckpointMetadata, TableCheckpointMetadata,
};
use arroyo_storage::StorageProvider;
use futures::stream::FuturesUnordered;
use futures::StreamExt;

@@ -23,17 +22,6 @@ use tracing::{debug, info};
pub const FULL_KEY_RANGE: RangeInclusive<u64> = 0..=u64::MAX;
pub const GENERATIONS_TO_COMPACT: u32 = 1; // only compact generation 0 files

async fn get_storage_provider() -> anyhow::Result<StorageProvider> {
// TODO: this should be encoded in the config so that the controller doesn't need
// to be synchronized with the workers
let storage_url = &config().checkpoint_url;

StorageProvider::for_url(storage_url).await.context(format!(
"failed to construct checkpoint backend for URL {}",
storage_url
))
}

pub struct ParquetBackend;

fn base_path(job_id: &str, epoch: u32) -> String {
@@ -178,11 +166,11 @@ impl ParquetBackend {
Self::load_operator_metadata(&job_id, &operator_id, epoch)
.await?
.expect("expect operator metadata to still be present");
let storage_provider = Arc::new(get_storage_provider().await?);
let storage_provider = get_storage_provider().await?;
let compaction_config = CompactionConfig {
storage_provider,
compact_generations: vec![0].into_iter().collect(),
min_compaction_epochs: min_files_to_compact,
storage_provider: Arc::clone(storage_provider),
};
let operator_metadata = operator_checkpoint_metadata.operator_metadata.unwrap();

18 changes: 2 additions & 16 deletions crates/arroyo-state/src/tables/table_manager.rs
@@ -21,7 +21,7 @@ use tokio::sync::{
use arroyo_rpc::config::config;
use tracing::{debug, error, info, warn};

use crate::{tables::global_keyed_map::GlobalKeyedTable, StateMessage};
use crate::{get_storage_provider, tables::global_keyed_map::GlobalKeyedTable, StateMessage};
use crate::{CheckpointMessage, TableData};

use super::expiring_time_key_map::{
@@ -225,20 +225,6 @@ impl BackendWriter {
}
}

async fn get_storage_provider() -> anyhow::Result<StorageProviderRef> {
// TODO: this should be encoded in the config so that the controller doesn't need
// to be synchronized with the workers

Ok(Arc::new(
StorageProvider::for_url(&config().checkpoint_url)
.await
.context(format!(
"failed to construct checkpoint backend for URL {}",
config().checkpoint_url
))?,
))
}

impl TableManager {
pub async fn new(
task_info: TaskInfoRef,
@@ -320,7 +306,7 @@ impl TableManager {
tables,
writer,
task_info,
storage,
storage: Arc::clone(storage),
caches: HashMap::new(),
})
}
5 changes: 2 additions & 3 deletions crates/arroyo-storage/Cargo.toml
@@ -12,10 +12,9 @@ arroyo-types = { path = "../arroyo-types" }
arroyo-rpc = { path = "../arroyo-rpc" }
bytes = "1.4.0"
tracing = "0.1"
# used only for getting local AWS credentials; can be removed once we have a
# better way to do this
rusoto_core = "0.48.0"

aws-credential-types = "1.2.0"
aws-config = { version = "1.5.4" }
rand = "0.8"
object_store = {workspace = true, features = ["aws", "gcp"]}
regex = "1.9.5"
62 changes: 34 additions & 28 deletions crates/arroyo-storage/src/aws.rs
@@ -1,14 +1,11 @@
use std::sync::Arc;

use object_store::{aws::AwsCredential, CredentialProvider};
use rusoto_core::credential::{
AutoRefreshingProvider, ChainProvider, ProfileProvider, ProvideAwsCredentials,
};

use crate::StorageError;
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
use object_store::{aws::AwsCredential, CredentialProvider};
use std::sync::Arc;

pub struct ArroyoCredentialProvider {
provider: AutoRefreshingProvider<ChainProvider>,
provider: aws_credential_types::provider::SharedCredentialsProvider,
}

impl std::fmt::Debug for ArroyoCredentialProvider {
@@ -18,38 +15,47 @@ impl std::fmt::Debug for ArroyoCredentialProvider {
}

impl ArroyoCredentialProvider {
pub fn try_new() -> Result<Self, StorageError> {
let inner: AutoRefreshingProvider<ChainProvider> =
AutoRefreshingProvider::new(ChainProvider::new())
.map_err(|e| StorageError::CredentialsError(e.to_string()))?;

Ok(Self { provider: inner })
pub async fn try_new() -> Result<Self, StorageError> {
let config = aws_config::defaults(BehaviorVersion::latest()).load().await;

let credentials = config
.credentials_provider()
.ok_or_else(|| {
StorageError::CredentialsError(
"Unable to load S3 credentials from environment".to_string(),
)
})?
.clone();

Ok(Self {
provider: credentials,
})
}

pub async fn default_region() -> Option<String> {
ProfileProvider::region().ok()?
aws_config::defaults(BehaviorVersion::latest())
.load()
.await
.region()
.map(|r| r.to_string())
}
}

#[async_trait::async_trait]
impl CredentialProvider for ArroyoCredentialProvider {
#[doc = " The type of credential returned by this provider"]
type Credential = AwsCredential;

/// Return a credential
async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
let credentials =
self.provider
.credentials()
.await
.map_err(|err| object_store::Error::Generic {
store: "s3",
source: Box::new(err),
})?;
let creds = self.provider.provide_credentials().await.map_err(|e| {
object_store::Error::Generic {
store: "S3",
source: Box::new(e),
}
})?;
Ok(Arc::new(AwsCredential {
key_id: credentials.aws_access_key_id().to_string(),
secret_key: credentials.aws_secret_access_key().to_string(),
token: credentials.token().clone(),
key_id: creds.access_key_id().to_string(),
secret_key: creds.secret_access_key().to_string(),
token: creds.session_token().map(ToString::to_string),
}))
}
}
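
This file is the heart of the change: the old rusoto ChainProvider did not resolve web-identity credentials, while aws-config's default chain does, and web identity is exactly what IRSA injects into a pod. A sketch of the resolution path, assuming a standard EKS/IRSA setup (variable values are illustrative, not from this diff):

// Under IRSA, the EKS webhook injects into the pod environment:
//   AWS_ROLE_ARN=arn:aws:iam::123456789012:role/arroyo-checkpoints   (illustrative)
//   AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token
// The default provider chain exchanges that token for role credentials via STS.
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
    let provider = config
        .credentials_provider()
        .expect("no credentials provider in loaded config");
    // With IRSA these are temporary STS credentials; the chain refreshes them.
    let creds = provider.provide_credentials().await?;
    println!("resolved key id: {}", creds.access_key_id());
    Ok(())
}
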
67 changes: 4 additions & 63 deletions crates/arroyo-storage/src/lib.rs
@@ -16,12 +16,9 @@ use object_store::multipart::PartId;
use object_store::path::Path;
use object_store::{aws::AmazonS3Builder, local::LocalFileSystem, ObjectStore};
use object_store::{CredentialProvider, MultipartId};
use once_cell::sync::Lazy;
use regex::{Captures, Regex};
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, error, trace};
use tracing::{debug, error};

mod aws;

@@ -296,23 +293,11 @@ fn last<I: Sized, const COUNT: usize>(opts: [Option<I>; COUNT]) -> Option<I> {
}

pub async fn get_current_credentials() -> Result<Arc<AwsCredential>, StorageError> {
let provider = ArroyoCredentialProvider::try_new()?;
let provider = ArroyoCredentialProvider::try_new().await?;
let credentials = provider.get_credential().await?;
Ok(credentials)
}

static OBJECT_STORE_CACHE: Lazy<RwLock<HashMap<String, CacheEntry<Arc<dyn ObjectStore>>>>> =
Lazy::new(Default::default);

struct CacheEntry<T> {
value: T,
inserted_at: Instant,
}

// The bearer token should last for 3600 seconds,
// but regenerating it every 5 minutes to avoid token expiry
const GCS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);

impl StorageProvider {
pub async fn for_url(url: &str) -> Result<Self, StorageError> {
Self::for_url_with_options(url, HashMap::new()).await
@@ -360,11 +345,6 @@ impl StorageProvider {
Ok(key.clone())
}

pub async fn url_exists(url: &str) -> Result<bool, StorageError> {
let provider = Self::for_url(url).await?;
provider.exists("").await
}

async fn construct_s3(
mut config: S3Config,
options: HashMap<String, String>,
@@ -386,7 +366,7 @@

if !aws_key_manually_set {
let credentials: Arc<ArroyoCredentialProvider> =
Arc::new(ArroyoCredentialProvider::try_new()?);
Arc::new(ArroyoCredentialProvider::try_new().await?);
builder = builder.with_credentials(credentials);
}

@@ -444,45 +424,6 @@ impl StorageProvider {
})
}

async fn get_or_create_object_store(
builder: GoogleCloudStorageBuilder,
bucket: &str,
) -> Result<Arc<dyn ObjectStore>, StorageError> {
let mut cache = OBJECT_STORE_CACHE.write().await;

if let Some(entry) = cache.get(bucket) {
if entry.inserted_at.elapsed() < GCS_CACHE_TTL {
trace!(
"Cache hit - using cached object store for bucket {}",
bucket
);
return Ok(entry.value.clone());
} else {
debug!(
"Cache expired - constructing new object store for bucket {}",
bucket
);
}
} else {
debug!(
"Cache miss - constructing new object store for bucket {}",
bucket
);
}

let new_store = Arc::new(builder.build().map_err(Into::<StorageError>::into)?);

cache.insert(
bucket.to_string(),
CacheEntry {
value: new_store.clone(),
inserted_at: Instant::now(),
},
);

Ok(new_store)
}

async fn construct_gcs(config: GCSConfig) -> Result<Self, StorageError> {
let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(&config.bucket);

@@ -498,7 +439,7 @@

let object_store_base_url = format!("https://{}.storage.googleapis.com", config.bucket);

let object_store = Self::get_or_create_object_store(builder, &config.bucket).await?;
let object_store = Arc::new(builder.build()?);

Ok(Self {
config: BackendConfig::GCS(config),
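For completeness, a minimal sketch of how a custom provider such as ArroyoCredentialProvider plugs into object_store's S3 builder, mirroring the construct_s3 wiring above (the bucket name is a placeholder):

use object_store::aws::AmazonS3Builder;
use object_store::ObjectStore;
use std::sync::Arc;

async fn build_s3_store() -> Result<Arc<dyn ObjectStore>, Box<dyn std::error::Error>> {
    // The Arc<ArroyoCredentialProvider> coerces to object_store's
    // Arc<dyn CredentialProvider<Credential = AwsCredential>> at the call site.
    let credentials = Arc::new(ArroyoCredentialProvider::try_new().await?);
    let store = AmazonS3Builder::from_env()
        .with_bucket_name("my-checkpoint-bucket") // placeholder
        .with_credentials(credentials)
        .build()?;
    Ok(Arc::new(store))
}
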
4 changes: 0 additions & 4 deletions crates/arroyo-worker/Cargo.toml
@@ -52,11 +52,7 @@ parquet = { workspace = true, features = ["async"]}
arrow-array = { workspace = true}
arrow-json = { workspace = true }

aws-sdk-kinesis = { version = "0.21", default-features = false, features = ["rt-tokio", "native-tls"] }
aws-config = { version = "0.51", default-features = false, features = ["rt-tokio", "native-tls"] }
uuid = {version = "1.4.1", features = ["v4"]}
rusoto_core = "0.48.0"
rusoto_s3 = "0.48.0"

tonic = { workspace = true }
prost = "0.12"