Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions crates/catalog/rest/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ impl RestCatalog {
}
}

/// Set a custom storage credentials loader.
///
/// This is intended to be called after catalog construction, so the loader
/// can hold a reference to the catalog (e.g., `Arc<RestCatalog>`) and call
/// catalog-specific methods like `load_table_with_credentials`.
pub fn set_storage_credentials_loader(&mut self, loader: Arc<dyn StorageCredentialsLoader>) {
self.user_config.storage_credentials_loader = Some(loader);
}

/// Add an extension to the file IO builder.
pub fn with_file_io_extension<T: Any + Send + Sync>(mut self, ext: T) -> Self {
self.file_io_extensions.add(ext);
Expand Down Expand Up @@ -430,6 +439,7 @@ impl RestCatalog {
metadata_location: Option<&str>,
extra_config: Option<HashMap<String, String>>,
storage_credential: Option<StorageCredential>,
table_ident: Option<&TableIdent>,
) -> Result<FileIO> {
let mut props = self.context().await?.config.props.clone();
if let Some(config) = extra_config {
Expand Down Expand Up @@ -458,6 +468,9 @@ impl RestCatalog {
file_io_builder =
file_io_builder.with_extension(MetadataLocation(loc.to_string()));
}
if let Some(ident) = table_ident {
file_io_builder = file_io_builder.with_extension(ident.clone());
}
file_io_builder = file_io_builder.with_extension(loader.clone());
}

Expand Down Expand Up @@ -570,7 +583,10 @@ impl RestCatalog {
&self.user_config.storage_credentials_loader
{
let credential = storage_credentials_loader
.load_credentials(response.metadata_location.as_deref().unwrap_or(""))
.load_credentials(
table_ident,
response.metadata_location.as_deref().unwrap_or(""),
)
.await?;
config.extend(credential.config.clone());
Some(credential)
Expand All @@ -583,6 +599,7 @@ impl RestCatalog {
response.metadata_location.as_deref(),
Some(config),
final_credential,
Some(table_ident),
)
.await?;

Expand Down Expand Up @@ -891,7 +908,7 @@ impl Catalog for RestCatalog {

// TODO: Support vended credentials here.
let file_io = self
.load_file_io(Some(metadata_location), Some(config), None)
.load_file_io(Some(metadata_location), Some(config), None, None)
.await?;

let table_builder = Table::builder()
Expand Down Expand Up @@ -1033,7 +1050,7 @@ impl Catalog for RestCatalog {

// TODO: Support vended credentials here.
let file_io = self
.load_file_io(Some(metadata_location), None, None)
.load_file_io(Some(metadata_location), None, None, None)
.await?;

Table::builder()
Expand Down Expand Up @@ -1100,7 +1117,7 @@ impl Catalog for RestCatalog {

// TODO: Support vended credentials here.
let file_io = self
.load_file_io(Some(&response.metadata_location), None, None)
.load_file_io(Some(&response.metadata_location), None, None, None)
.await?;

Table::builder()
Expand Down Expand Up @@ -3034,7 +3051,11 @@ mod tests {

#[async_trait::async_trait]
impl StorageCredentialsLoader for DummyCredentialLoader {
async fn load_credentials(&self, _location: &str) -> Result<StorageCredential> {
async fn load_credentials(
&self,
_table_ident: &TableIdent,
_location: &str,
) -> Result<StorageCredential> {
self.was_called.store(true, Ordering::SeqCst);
let mut config = HashMap::new();
config.insert("custom.key".to_string(), "custom.value".to_string());
Expand Down
17 changes: 16 additions & 1 deletion crates/iceberg/src/io/opendal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use super::{
FileIOBuilder, FileMetadata, FileRead, FileWrite, InputFile, MetadataLocation, OutputFile,
Storage, StorageConfig, StorageCredential, StorageCredentialsLoader, StorageFactory,
};
use crate::catalog::TableIdent;
use crate::{Error, ErrorKind, Result};

#[cfg(feature = "storage-azdls")]
Expand Down Expand Up @@ -233,12 +234,22 @@ impl OpenDalStorage {
.get::<MetadataLocation>()
.map(|l| l.0.clone())
.unwrap_or_default();
let table_ident = extensions
.get::<TableIdent>()
.map(|arc| (*arc).clone())
.unwrap_or_else(|| {
TableIdent::new(
crate::NamespaceIdent::new("unknown".to_string()),
"unknown".to_string(),
)
});
let backend = RefreshableOpenDalStorageBuilder::new()
.scheme(scheme_str)
.base_props(props)
.credentials_loader(Arc::clone(&loader))
.initial_credentials(initial_creds)
.location(location)
.table_ident(table_ident)
.extensions(extensions)
.build()?;
return Ok(Self::Refreshable {
Expand Down Expand Up @@ -534,7 +545,11 @@ mod tests {

#[async_trait::async_trait]
impl StorageCredentialsLoader for TestCredentialLoader {
async fn load_credentials(&self, _location: &str) -> crate::Result<StorageCredential> {
async fn load_credentials(
&self,
_table_ident: &TableIdent,
_location: &str,
) -> crate::Result<StorageCredential> {
Ok(StorageCredential {
prefix: "s3://test/".to_string(),
config: HashMap::new(),
Expand Down
16 changes: 15 additions & 1 deletion crates/iceberg/src/io/refreshable_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};

use super::*;
use crate::NamespaceIdent;
use crate::catalog::TableIdent;
use crate::io::refreshable_storage::RefreshableOpenDalStorageBuilder;
use crate::io::{StorageCredential, StorageCredentialsLoader};

Expand Down Expand Up @@ -259,7 +261,11 @@ mod tests {

#[async_trait::async_trait]
impl StorageCredentialsLoader for SequenceLoader {
async fn load_credentials(&self, _location: &str) -> crate::Result<StorageCredential> {
async fn load_credentials(
&self,
_table_ident: &TableIdent,
_location: &str,
) -> crate::Result<StorageCredential> {
self.call_count.fetch_add(1, Ordering::SeqCst);
let mut responses = self.responses.lock().unwrap();
Ok(responses.pop_front().unwrap_or_else(dummy_credential))
Expand Down Expand Up @@ -382,6 +388,10 @@ mod tests {
.scheme("memory".to_string())
.base_props(HashMap::new())
.credentials_loader(Arc::clone(&loader))
.table_ident(TableIdent::new(
NamespaceIdent::new("test_ns".to_string()),
"test_table".to_string(),
))
.build()
.expect("Failed to build storage");

Expand Down Expand Up @@ -477,6 +487,10 @@ mod tests {
.scheme("memory".to_string())
.base_props(HashMap::new())
.credentials_loader(Arc::clone(&loader) as _)
.table_ident(TableIdent::new(
NamespaceIdent::new("test_ns".to_string()),
"test_table".to_string(),
))
.build()
.expect("Failed to build storage");

Expand Down
40 changes: 36 additions & 4 deletions crates/iceberg/src/io/refreshable_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use tokio::sync::Mutex as AsyncMutex;

use super::opendal::OpenDalStorage;
use super::refreshable_accessor::RefreshableAccessor;
use crate::catalog::TableIdent;
use crate::io::file_io::Extensions;
use crate::io::{StorageCredential, StorageCredentialsLoader};
use crate::{Error, ErrorKind, Result};
Expand Down Expand Up @@ -52,6 +53,9 @@ pub struct RefreshableOpenDalStorage {
/// Metadata location passed to `load_credentials`
location: String,

/// Table identifier passed to `load_credentials`
table_ident: TableIdent,

/// Cached AccessorInfo (created lazily from first operator)
cached_info: Mutex<Option<Arc<AccessorInfo>>>,

Expand Down Expand Up @@ -88,6 +92,7 @@ impl RefreshableOpenDalStorage {
credentials_loader: Arc<dyn StorageCredentialsLoader>,
initial_credentials: Option<StorageCredential>,
location: String,
table_ident: TableIdent,
extensions: Extensions,
) -> Result<Self> {
// Build initial inner_storage from base_props + initial_credentials
Expand All @@ -104,6 +109,7 @@ impl RefreshableOpenDalStorage {
credentials_loader,
extensions,
location,
table_ident,
cached_info: Mutex::new(None),
credential_version: AtomicU64::new(0),
refresh_lock: AsyncMutex::new(()),
Expand Down Expand Up @@ -197,7 +203,7 @@ impl RefreshableOpenDalStorage {
// We are the one who should call the loader
let new_creds = self
.credentials_loader
.load_credentials(&self.location)
.load_credentials(&self.table_ident, &self.location)
.await?;
self.do_refresh(new_creds)?;
Ok(self.credential_version.load(Ordering::Acquire))
Expand All @@ -212,6 +218,7 @@ pub struct RefreshableOpenDalStorageBuilder {
credentials_loader: Option<Arc<dyn StorageCredentialsLoader>>,
initial_credentials: Option<StorageCredential>,
location: String,
table_ident: Option<TableIdent>,
extensions: Extensions,
}

Expand Down Expand Up @@ -251,6 +258,12 @@ impl RefreshableOpenDalStorageBuilder {
self
}

/// Set the table identifier passed to `load_credentials`
pub fn table_ident(mut self, table_ident: TableIdent) -> Self {
self.table_ident = Some(table_ident);
self
}

/// Set the extensions
pub fn extensions(mut self, extensions: Extensions) -> Self {
self.extensions = extensions;
Expand All @@ -268,6 +281,8 @@ impl RefreshableOpenDalStorageBuilder {
})?,
self.initial_credentials,
self.location,
self.table_ident
.ok_or_else(|| Error::new(ErrorKind::DataInvalid, "table_ident is required"))?,
self.extensions,
)?))
}
Expand All @@ -278,6 +293,7 @@ mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};

use super::*;
use crate::NamespaceIdent;
use crate::io::StorageCredential;

// --- Test helpers ---
Expand All @@ -288,7 +304,11 @@ mod tests {

#[async_trait::async_trait]
impl StorageCredentialsLoader for SimpleLoader {
async fn load_credentials(&self, _location: &str) -> Result<StorageCredential> {
async fn load_credentials(
&self,
_table_ident: &TableIdent,
_location: &str,
) -> Result<StorageCredential> {
Ok(StorageCredential {
prefix: "memory:/refreshed/".to_string(),
config: HashMap::from([("refreshed_key".to_string(), "refreshed_val".to_string())]),
Expand Down Expand Up @@ -322,7 +342,11 @@ mod tests {

#[async_trait::async_trait]
impl StorageCredentialsLoader for TrackingRefreshLoader {
async fn load_credentials(&self, _location: &str) -> Result<StorageCredential> {
async fn load_credentials(
&self,
_table_ident: &TableIdent,
_location: &str,
) -> Result<StorageCredential> {
let n = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
Ok(StorageCredential {
prefix: format!("memory:/refresh-{n}/"),
Expand All @@ -331,13 +355,21 @@ mod tests {
}
}

fn test_table_ident() -> TableIdent {
TableIdent::new(
NamespaceIdent::new("test_ns".to_string()),
"test_table".to_string(),
)
}

fn build_memory_refreshable(
loader: Arc<dyn StorageCredentialsLoader>,
) -> Arc<RefreshableOpenDalStorage> {
RefreshableOpenDalStorageBuilder::new()
.scheme("memory".to_string())
.base_props(HashMap::new())
.credentials_loader(loader)
.table_ident(test_table_ident())
.build()
.expect("Failed to build RefreshableOpenDalStorage for memory")
}
Expand All @@ -346,7 +378,7 @@ mod tests {
async fn refresh(storage: &RefreshableOpenDalStorage) -> Result<()> {
let new_creds = storage
.credentials_loader
.load_credentials(&storage.location)
.load_credentials(&storage.table_ident, &storage.location)
.await?;
storage.do_refresh(new_creds)
}
Expand Down
14 changes: 12 additions & 2 deletions crates/iceberg/src/io/storage_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::collections::HashMap;
use std::fmt::Debug;

use crate::Result;
use crate::catalog::TableIdent;

/// Storage credentials for accessing cloud storage.
///
Expand Down Expand Up @@ -54,7 +55,11 @@ pub struct MetadataLocation(pub String);
///
/// #[async_trait::async_trait]
/// impl StorageCredentialsLoader for MyCredentialLoader {
/// async fn load_credentials(&self, location: &str) -> iceberg::Result<StorageCredential> {
/// async fn load_credentials(
/// &self,
/// _table_ident: &iceberg::TableIdent,
/// location: &str,
/// ) -> iceberg::Result<StorageCredential> {
/// // Fetch fresh credentials from your credential service
/// let mut config = HashMap::new();
/// config.insert("access_key_id".to_string(), "fresh-key".to_string());
Expand Down Expand Up @@ -85,6 +90,11 @@ pub trait StorageCredentialsLoader: Send + Sync + Debug {
/// Load storage credentials using custom user-defined logic.
///
/// # Arguments
/// * `table_ident` - The table identifier for which credentials are being loaded
/// * `location` - The full path being accessed (e.g., "s3://bucket/path/file.parquet")
async fn load_credentials(&self, location: &str) -> Result<StorageCredential>;
async fn load_credentials(
&self,
table_ident: &TableIdent,
location: &str,
) -> Result<StorageCredential>;
}
Loading