diff --git a/Cargo.lock b/Cargo.lock index 7bc3445c1e4..86fafc9aa28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5222,6 +5222,7 @@ dependencies = [ "azure_identity", "azure_storage", "azure_storage_blobs", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -5238,6 +5239,7 @@ dependencies = [ "rstest", "serde", "serde_json", + "sha2", "snafu", "tempfile", "time", @@ -5250,9 +5252,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.4.0" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dfe76b82f4167fa1c19d5d8825f8fb7d3831e83fa6e0485b3dd59ef0f7b1685" +checksum = "a2acdba67f84190067532fce07b51a435dd390d7cdc1129a05003e5cb3274cf0" dependencies = [ "reqwest", "serde", diff --git a/Cargo.toml b/Cargo.toml index 826d1430d35..9cc3af5be85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,7 @@ lance-io = { version = "=2.0.0-beta.4", path = "./rust/lance-io", default-featur lance-linalg = { version = "=2.0.0-beta.4", path = "./rust/lance-linalg" } lance-namespace = { version = "=2.0.0-beta.4", path = "./rust/lance-namespace" } lance-namespace-impls = { version = "=2.0.0-beta.4", path = "./rust/lance-namespace-impls" } -lance-namespace-reqwest-client = "=0.4.0" +lance-namespace-reqwest-client = { version = "=0.4.5" } lance-table = { version = "=2.0.0-beta.4", path = "./rust/lance-table" } lance-test-macros = { version = "=2.0.0-beta.4", path = "./rust/lance-test-macros" } lance-testing = { version = "=2.0.0-beta.4", path = "./rust/lance-testing" } diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index f20a87b446b..865b9d564c4 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -4287,6 +4287,7 @@ dependencies = [ "azure_identity", "azure_storage", "azure_storage_blobs", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -4302,6 +4303,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "sha2", "snafu", "time", "tokio", @@ -4312,9 +4314,9 @@ 
dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.4.0" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dfe76b82f4167fa1c19d5d8825f8fb7d3831e83fa6e0485b3dd59ef0f7b1685" +checksum = "a2acdba67f84190067532fce07b51a435dd390d7cdc1129a05003e5cb3274cf0" dependencies = [ "reqwest", "serde", diff --git a/java/pom.xml b/java/pom.xml index cf212b22f20..fa64e1a5f99 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -108,12 +108,12 @@ org.lance lance-namespace-core - 0.4.0 + 0.4.5 org.lance lance-namespace-apache-client - 0.4.0 + 0.4.5 com.fasterxml.jackson.core diff --git a/python/Cargo.lock b/python/Cargo.lock index cf63ab702b5..886115f9303 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -4625,6 +4625,7 @@ dependencies = [ "azure_identity", "azure_storage", "azure_storage_blobs", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -4640,6 +4641,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "sha2", "snafu", "time", "tokio", @@ -4650,9 +4652,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.4.0" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dfe76b82f4167fa1c19d5d8825f8fb7d3831e83fa6e0485b3dd59ef0f7b1685" +checksum = "a2acdba67f84190067532fce07b51a435dd390d7cdc1129a05003e5cb3274cf0" dependencies = [ "reqwest", "serde", diff --git a/python/pyproject.toml b/python/pyproject.toml index 0a6dd542222..815e28c6666 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "pylance" dynamic = ["version"] -dependencies = ["pyarrow>=14", "numpy>=1.22", "lance-namespace>=0.4.0"] +dependencies = ["pyarrow>=14", "numpy>=1.22", "lance-namespace>=0.4.5"] description = "python wrapper for Lance columnar format" authors = [{ name = "Lance Devs", email = "dev@lance.org" }] license = { file = "LICENSE" } diff --git a/rust/lance-io/src/object_store/storage_options.rs 
b/rust/lance-io/src/object_store/storage_options.rs index f809df8d1d3..22854e8fd53 100644 --- a/rust/lance-io/src/object_store/storage_options.rs +++ b/rust/lance-io/src/object_store/storage_options.rs @@ -113,8 +113,7 @@ impl StorageOptionsProvider for LanceNamespaceStorageOptionsProvider { async fn fetch_storage_options(&self) -> Result>> { let request = DescribeTableRequest { id: Some(self.table_id.clone()), - version: None, - with_table_uri: None, + ..Default::default() }; let response = self diff --git a/rust/lance-namespace-impls/Cargo.toml b/rust/lance-namespace-impls/Cargo.toml index cb0ff52d1e0..85ee4a6989f 100644 --- a/rust/lance-namespace-impls/Cargo.toml +++ b/rust/lance-namespace-impls/Cargo.toml @@ -22,9 +22,9 @@ dir-azure = ["lance-io/azure", "lance/azure"] dir-oss = ["lance-io/oss", "lance/oss"] dir-huggingface = ["lance-io/huggingface", "lance/huggingface"] # Credential vending features -credential-vendor-aws = ["dep:aws-sdk-sts", "dep:aws-config", "dep:aws-credential-types"] -credential-vendor-gcp = ["dep:google-cloud-auth", "dep:reqwest", "dep:serde"] -credential-vendor-azure = ["dep:azure_core", "dep:azure_identity", "dep:azure_storage", "dep:azure_storage_blobs", "dep:time"] +credential-vendor-aws = ["dep:aws-sdk-sts", "dep:aws-config", "dep:aws-credential-types", "dep:sha2", "dep:base64"] +credential-vendor-gcp = ["dep:google-cloud-auth", "dep:reqwest", "dep:serde", "dep:sha2", "dep:base64"] +credential-vendor-azure = ["dep:azure_core", "dep:azure_identity", "dep:azure_storage", "dep:azure_storage_blobs", "dep:time", "dep:sha2", "dep:base64", "dep:reqwest"] [dependencies] lance-namespace.workspace = true @@ -66,10 +66,12 @@ log.workspace = true rand.workspace = true chrono.workspace = true -# AWS credential vending dependencies (optional, enabled by "dir-aws" feature) +# AWS credential vending dependencies (optional, enabled by "credential-vendor-aws" feature) aws-sdk-sts = { version = "1.38.0", optional = true } aws-config = { workspace = 
true, optional = true } aws-credential-types = { workspace = true, optional = true } +sha2 = { version = "0.10", optional = true } +base64 = { version = "0.22", optional = true } # GCP credential vending dependencies (optional, enabled by "dir-gcp" feature) google-cloud-auth = { version = "0.18", optional = true } diff --git a/rust/lance-namespace-impls/src/credentials.rs b/rust/lance-namespace-impls/src/credentials.rs index 6be4f1e38a4..f9f7ecc7950 100644 --- a/rust/lance-namespace-impls/src/credentials.rs +++ b/rust/lance-namespace-impls/src/credentials.rs @@ -68,12 +68,22 @@ pub mod azure; #[cfg(feature = "credential-vendor-gcp")] pub mod gcp; +/// Credential caching module. +/// Available when any credential vendor feature is enabled. +#[cfg(any( + feature = "credential-vendor-aws", + feature = "credential-vendor-azure", + feature = "credential-vendor-gcp" +))] +pub mod cache; + use std::collections::HashMap; use std::str::FromStr; use async_trait::async_trait; use lance_core::Result; use lance_io::object_store::uri_to_url; +use lance_namespace::models::Identity; /// Default credential duration: 1 hour (3600000 milliseconds) pub const DEFAULT_CREDENTIAL_DURATION_MILLIS: u64 = 3600 * 1000; @@ -188,6 +198,18 @@ pub const ENABLED: &str = "enabled"; /// Common property key for permission level (short form). pub const PERMISSION: &str = "permission"; +/// Common property key to enable credential caching (short form). +/// Default: true. Set to "false" to disable caching. +pub const CACHE_ENABLED: &str = "cache_enabled"; + +/// Common property key for API key salt (short form). +/// Used to hash API keys before comparison: SHA256(api_key + ":" + salt) +pub const API_KEY_SALT: &str = "api_key_salt"; + +/// Property key prefix for API key hash to permission mappings (short form). +/// Format: `api_key_hash. 
= ""` +pub const API_KEY_HASH_PREFIX: &str = "api_key_hash."; + /// AWS-specific property keys (short form, without prefix) #[cfg(feature = "credential-vendor-aws")] pub mod aws_props { @@ -204,6 +226,14 @@ pub mod aws_props { #[cfg(feature = "credential-vendor-gcp")] pub mod gcp_props { pub const SERVICE_ACCOUNT: &str = "gcp_service_account"; + + /// Workload Identity Provider resource name for OIDC token exchange. + /// Format: //iam.googleapis.com/projects/{project}/locations/global/workloadIdentityPools/{pool}/providers/{provider} + pub const WORKLOAD_IDENTITY_PROVIDER: &str = "gcp_workload_identity_provider"; + + /// Service account to impersonate after Workload Identity Federation (optional). + /// If not set, uses the federated identity directly. + pub const IMPERSONATION_SERVICE_ACCOUNT: &str = "gcp_impersonation_service_account"; } /// Azure-specific property keys (short form, without prefix) @@ -215,6 +245,10 @@ pub mod azure_props { /// Azure credential duration in milliseconds. /// Default: 3600000 (1 hour). Azure SAS tokens can be valid up to 7 days. pub const DURATION_MILLIS: &str = "azure_duration_millis"; + + /// Client ID of the Azure AD App Registration for Workload Identity Federation. + /// Required when using auth_token identity for OIDC token exchange. + pub const FEDERATED_CLIENT_ID: &str = "azure_federated_client_id"; } /// Vended credentials with expiration information. @@ -271,16 +305,30 @@ pub trait CredentialVendor: Send + Sync + std::fmt::Debug { /// Vend credentials for accessing the specified table location. /// /// The permission level (read/write/admin) is determined by the vendor's - /// configuration, not per-request. + /// configuration, not per-request. 
When identity is provided, the vendor + /// may use different authentication flows: + /// + /// - `auth_token`: Use AssumeRoleWithWebIdentity (AWS validates the token) + /// - `api_key`: Validate against configured API key hashes and use AssumeRole + /// - `None`: Use static configuration with AssumeRole /// /// # Arguments /// /// * `table_location` - The table URI to vend credentials for + /// * `identity` - Optional identity from the request (api_key OR auth_token, mutually exclusive) /// /// # Returns /// /// Returns vended credentials with expiration information. - async fn vend_credentials(&self, table_location: &str) -> Result; + /// + /// # Errors + /// + /// Returns error if identity validation fails (no fallback to static config). + async fn vend_credentials( + &self, + table_location: &str, + identity: Option<&Identity>, + ) -> Result; /// Returns the cloud provider name (e.g., "aws", "gcp", "azure"). fn provider_name(&self) -> &'static str; @@ -349,21 +397,50 @@ pub async fn create_credential_vendor_for_location( ) -> Result>> { let provider = detect_provider_from_uri(table_location); - match provider { + let vendor: Option> = match provider { #[cfg(feature = "credential-vendor-aws")] - "aws" => create_aws_vendor(properties).await, + "aws" => create_aws_vendor(properties).await?, #[cfg(feature = "credential-vendor-gcp")] - "gcp" => create_gcp_vendor(properties).await, + "gcp" => create_gcp_vendor(properties).await?, #[cfg(feature = "credential-vendor-azure")] - "azure" => create_azure_vendor(properties), + "azure" => create_azure_vendor(properties)?, + + _ => None, + }; - _ => Ok(None), + // Wrap with caching if enabled (default: true) + #[cfg(any( + feature = "credential-vendor-aws", + feature = "credential-vendor-azure", + feature = "credential-vendor-gcp" + ))] + if let Some(v) = vendor { + let cache_enabled = properties + .get(CACHE_ENABLED) + .map(|s| !s.eq_ignore_ascii_case("false")) + .unwrap_or(true); + + if cache_enabled { + return 
Ok(Some(Box::new(cache::CachingCredentialVendor::new(v)))); + } else { + return Ok(Some(v)); + } } + + #[cfg(not(any( + feature = "credential-vendor-aws", + feature = "credential-vendor-azure", + feature = "credential-vendor-gcp" + )))] + let _ = vendor; + + Ok(None) } /// Parse permission from properties, defaulting to Read +#[allow(dead_code)] fn parse_permission(properties: &HashMap) -> VendedPermission { properties .get(PERMISSION) @@ -372,6 +449,7 @@ fn parse_permission(properties: &HashMap) -> VendedPermission { } /// Parse duration from properties using a vendor-specific key, defaulting to DEFAULT_CREDENTIAL_DURATION_MILLIS +#[allow(dead_code)] fn parse_duration_millis(properties: &HashMap, key: &str) -> u64 { properties .get(key) diff --git a/rust/lance-namespace-impls/src/credentials/aws.rs b/rust/lance-namespace-impls/src/credentials/aws.rs index 96e0e8a2a80..d9b363e37e0 100644 --- a/rust/lance-namespace-impls/src/credentials/aws.rs +++ b/rust/lance-namespace-impls/src/credentials/aws.rs @@ -11,9 +11,12 @@ use std::collections::HashMap; use async_trait::async_trait; use aws_config::BehaviorVersion; use aws_sdk_sts::Client as StsClient; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use lance_core::{Error, Result}; use lance_io::object_store::uri_to_url; -use log::{debug, info}; +use lance_namespace::models::Identity; +use log::{debug, info, warn}; +use sha2::{Digest, Sha256}; use super::{ redact_credential, CredentialVendor, VendedCredentials, VendedPermission, @@ -24,6 +27,7 @@ use super::{ #[derive(Debug, Clone)] pub struct AwsCredentialVendorConfig { /// The IAM role ARN to assume. + /// Used for both AssumeRole (static/api_key) and AssumeRoleWithWebIdentity (auth_token). pub role_arn: String, /// Optional external ID for the assume role request. @@ -43,7 +47,18 @@ pub struct AwsCredentialVendorConfig { /// Permission level for vended credentials. 
/// Default: Read (full read access) + /// Used to generate scoped IAM policy for all credential flows. pub permission: VendedPermission, + + /// Salt for API key hashing. + /// Required when using API key authentication. + /// API keys are hashed as: SHA256(api_key + ":" + salt) + pub api_key_salt: Option, + + /// Map of SHA256(api_key + ":" + salt) -> permission level. + /// When an API key is provided, its hash is looked up in this map. + /// If found, the mapped permission is used instead of the default permission. + pub api_key_hash_permissions: HashMap, } impl AwsCredentialVendorConfig { @@ -56,6 +71,8 @@ impl AwsCredentialVendorConfig { role_session_name: None, region: None, permission: VendedPermission::default(), + api_key_salt: None, + api_key_hash_permissions: HashMap::new(), } } @@ -88,6 +105,32 @@ impl AwsCredentialVendorConfig { self.permission = permission; self } + + /// Set the API key salt for hashing. + pub fn with_api_key_salt(mut self, salt: impl Into) -> Self { + self.api_key_salt = Some(salt.into()); + self + } + + /// Add an API key hash to permission mapping. + pub fn with_api_key_hash_permission( + mut self, + key_hash: impl Into, + permission: VendedPermission, + ) -> Self { + self.api_key_hash_permissions + .insert(key_hash.into(), permission); + self + } + + /// Set the entire API key hash permissions map. + pub fn with_api_key_hash_permissions( + mut self, + permissions: HashMap, + ) -> Self { + self.api_key_hash_permissions = permissions; + self + } } /// AWS credential vendor that uses STS AssumeRole. @@ -206,60 +249,84 @@ impl AwsCredentialVendor { policy.to_string() } -} -#[async_trait] -impl CredentialVendor for AwsCredentialVendor { - async fn vend_credentials(&self, table_location: &str) -> Result { - debug!( - "AWS credential vending: location={}, permission={}", - table_location, self.config.permission - ); + /// Hash an API key using SHA-256 with salt (Polaris pattern). 
+ /// Format: SHA256(api_key + ":" + salt) as hex string. + pub fn hash_api_key(api_key: &str, salt: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(format!("{}:{}", api_key, salt)); + format!("{:x}", hasher.finalize()) + } - let (bucket, prefix) = Self::parse_s3_uri(table_location)?; - let policy = Self::build_policy(&bucket, &prefix, self.config.permission); + /// Extract a session name from a JWT token (best effort, no validation). + /// Decodes the payload and extracts 'sub' or 'email' claim. + /// Falls back to "lance-web-identity" if parsing fails. + fn derive_session_name_from_token(token: &str) -> String { + // JWT format: header.payload.signature + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return "lance-web-identity".to_string(); + } - let role_session_name = self - .config - .role_session_name - .clone() - .unwrap_or_else(|| "lance-credential-vending".to_string()); + // Decode the payload (second part) + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(bytes) => bytes, + Err(_) => { + // Try standard base64 as fallback + match base64::engine::general_purpose::STANDARD_NO_PAD.decode(parts[1]) { + Ok(bytes) => bytes, + Err(_) => return "lance-web-identity".to_string(), + } + } + }; - // Cap session name to 64 chars (AWS limit) - let role_session_name = if role_session_name.len() > 64 { - role_session_name[..64].to_string() - } else { - role_session_name + // Parse as JSON and extract 'sub' or 'email' + let json: serde_json::Value = match serde_json::from_slice(&payload) { + Ok(v) => v, + Err(_) => return "lance-web-identity".to_string(), }; - // Convert millis to seconds for AWS API (rounding up to ensure at least the requested duration) - // AWS STS allows 900-43200 seconds (15 min - 12 hours), clamp to valid range - let duration_secs = self.config.duration_millis.div_ceil(1000).clamp(900, 43200) as i32; + let subject = json + .get("sub") + .or_else(|| json.get("email")) + .and_then(|v| 
v.as_str()) + .unwrap_or("unknown"); - let mut request = self - .sts_client - .assume_role() - .role_arn(&self.config.role_arn) - .role_session_name(&role_session_name) - .policy(&policy) - .duration_seconds(duration_secs); + // Sanitize for role session name (alphanumeric, =, @, -, .) + let sanitized: String = subject + .chars() + .filter(|c| c.is_alphanumeric() || *c == '=' || *c == '@' || *c == '-' || *c == '.') + .collect(); - if let Some(ref external_id) = self.config.external_id { - request = request.external_id(external_id); + let session_name = format!("lance-{}", sanitized); + + // Cap to 64 chars (AWS limit) + if session_name.len() > 64 { + session_name[..64].to_string() + } else { + session_name } + } - let response = request.send().await.map_err(|e| Error::IO { - source: Box::new(std::io::Error::other(format!( - "Failed to assume role '{}': {}", - self.config.role_arn, e - ))), - location: snafu::location!(), - })?; + /// Cap a session name to 64 characters (AWS limit). + fn cap_session_name(name: &str) -> String { + if name.len() > 64 { + name[..64].to_string() + } else { + name.to_string() + } + } - let credentials = response.credentials().ok_or_else(|| Error::IO { - source: Box::new(std::io::Error::other( - "AssumeRole response missing credentials", - )), + /// Extract credentials from an STS Credentials response. 
+ fn extract_credentials( + &self, + credentials: Option<&aws_sdk_sts::types::Credentials>, + bucket: &str, + prefix: &str, + permission: VendedPermission, + ) -> Result { + let credentials = credentials.ok_or_else(|| Error::IO { + source: Box::new(std::io::Error::other("STS response missing credentials")), location: snafu::location!(), })?; @@ -273,7 +340,7 @@ impl CredentialVendor for AwsCredentialVendor { info!( "AWS credentials vended: bucket={}, prefix={}, permission={}, expires_at={}, access_key_id={}", - bucket, prefix, self.config.permission, expires_at_millis, redact_credential(&access_key_id) + bucket, prefix, permission, expires_at_millis, redact_credential(&access_key_id) ); let mut storage_options = HashMap::new(); @@ -293,6 +360,211 @@ impl CredentialVendor for AwsCredentialVendor { Ok(VendedCredentials::new(storage_options, expires_at_millis)) } + /// Vend credentials using AssumeRoleWithWebIdentity (for auth_token). + async fn vend_with_web_identity( + &self, + bucket: &str, + prefix: &str, + auth_token: &str, + policy: &str, + ) -> Result { + let session_name = Self::derive_session_name_from_token(auth_token); + let duration_secs = self.config.duration_millis.div_ceil(1000).clamp(900, 43200) as i32; + + debug!( + "AWS AssumeRoleWithWebIdentity: role={}, session={}, permission={}", + self.config.role_arn, session_name, self.config.permission + ); + + let response = self + .sts_client + .assume_role_with_web_identity() + .role_arn(&self.config.role_arn) + .web_identity_token(auth_token) + .role_session_name(&session_name) + .policy(policy) + .duration_seconds(duration_secs) + .send() + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "AssumeRoleWithWebIdentity failed for role '{}': {}", + self.config.role_arn, e + ))), + location: snafu::location!(), + })?; + + self.extract_credentials( + response.credentials(), + bucket, + prefix, + self.config.permission, + ) + } + + /// Vend credentials using AssumeRole with 
API key validation. + async fn vend_with_api_key( + &self, + bucket: &str, + prefix: &str, + api_key: &str, + ) -> Result { + let salt = self + .config + .api_key_salt + .as_ref() + .ok_or_else(|| Error::InvalidInput { + source: "api_key_salt must be configured to use API key authentication".into(), + location: snafu::location!(), + })?; + + let key_hash = Self::hash_api_key(api_key, salt); + + // Look up permission from hash mapping + let permission = self + .config + .api_key_hash_permissions + .get(&key_hash) + .copied() + .ok_or_else(|| { + warn!( + "Invalid API key: hash {} not found in permissions map", + &key_hash[..8] + ); + Error::InvalidInput { + source: "Invalid API key".into(), + location: snafu::location!(), + } + })?; + + let policy = Self::build_policy(bucket, prefix, permission); + let session_name = Self::cap_session_name(&format!("lance-api-{}", &key_hash[..16])); + let duration_secs = self.config.duration_millis.div_ceil(1000).clamp(900, 43200) as i32; + + debug!( + "AWS AssumeRole with API key: role={}, session={}, permission={}", + self.config.role_arn, session_name, permission + ); + + let request = self + .sts_client + .assume_role() + .role_arn(&self.config.role_arn) + .role_session_name(&session_name) + .policy(&policy) + .duration_seconds(duration_secs) + .external_id(&key_hash); // Use hash as external_id + + let response = request.send().await.map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "AssumeRole with API key failed for role '{}': {}", + self.config.role_arn, e + ))), + location: snafu::location!(), + })?; + + self.extract_credentials(response.credentials(), bucket, prefix, permission) + } + + /// Vend credentials using AssumeRole with static configuration. 
+ async fn vend_with_static_config( + &self, + bucket: &str, + prefix: &str, + policy: &str, + ) -> Result { + let role_session_name = self + .config + .role_session_name + .clone() + .unwrap_or_else(|| "lance-credential-vending".to_string()); + let role_session_name = Self::cap_session_name(&role_session_name); + + let duration_secs = self.config.duration_millis.div_ceil(1000).clamp(900, 43200) as i32; + + debug!( + "AWS AssumeRole (static): role={}, session={}, permission={}", + self.config.role_arn, role_session_name, self.config.permission + ); + + let mut request = self + .sts_client + .assume_role() + .role_arn(&self.config.role_arn) + .role_session_name(&role_session_name) + .policy(policy) + .duration_seconds(duration_secs); + + if let Some(ref external_id) = self.config.external_id { + request = request.external_id(external_id); + } + + let response = request.send().await.map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "AssumeRole failed for role '{}': {}", + self.config.role_arn, e + ))), + location: snafu::location!(), + })?; + + self.extract_credentials( + response.credentials(), + bucket, + prefix, + self.config.permission, + ) + } +} + +#[async_trait] +impl CredentialVendor for AwsCredentialVendor { + async fn vend_credentials( + &self, + table_location: &str, + identity: Option<&Identity>, + ) -> Result { + debug!( + "AWS credential vending: location={}, permission={}, has_identity={}", + table_location, + self.config.permission, + identity.is_some() + ); + + let (bucket, prefix) = Self::parse_s3_uri(table_location)?; + + match identity { + Some(id) if id.auth_token.is_some() => { + // Use AssumeRoleWithWebIdentity with configured permission + let policy = Self::build_policy(&bucket, &prefix, self.config.permission); + self.vend_with_web_identity( + &bucket, + &prefix, + id.auth_token.as_ref().unwrap(), + &policy, + ) + .await + } + Some(id) if id.api_key.is_some() => { + // Use AssumeRole with API key validation and 
mapped permission + self.vend_with_api_key(&bucket, &prefix, id.api_key.as_ref().unwrap()) + .await + } + Some(_) => { + // Identity provided but neither api_key nor auth_token set + Err(Error::InvalidInput { + source: "Identity provided but neither api_key nor auth_token is set".into(), + location: snafu::location!(), + }) + } + None => { + // Use AssumeRole with static configuration + let policy = Self::build_policy(&bucket, &prefix, self.config.permission); + self.vend_with_static_config(&bucket, &prefix, &policy) + .await + } + } + } + fn provider_name(&self) -> &'static str { "aws" } @@ -543,7 +815,7 @@ mod tests { .expect("should create read vendor"); let read_creds = read_vendor - .vend_credentials(&table_location) + .vend_credentials(&table_location, None) .await .expect("should vend read credentials"); @@ -582,7 +854,7 @@ mod tests { .expect("should create admin vendor"); let admin_creds = admin_vendor - .vend_credentials(&table_location) + .vend_credentials(&table_location, None) .await .expect("should vend admin credentials"); @@ -627,8 +899,7 @@ mod tests { // Create a child namespace let create_ns_req = CreateNamespaceRequest { id: Some(vec!["test_ns".to_string()]), - properties: None, - mode: None, + ..Default::default() }; namespace .create_namespace(create_ns_req) @@ -640,6 +911,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_ns".to_string(), "test_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; let create_response = namespace .create_table(create_table_req, table_data) @@ -704,8 +976,7 @@ mod tests { // List tables to verify the table was created let list_req = ListTablesRequest { id: Some(vec!["test_ns".to_string()]), - page_token: None, - limit: None, + ..Default::default() }; let list_response = namespace .list_tables(list_req) @@ -719,6 +990,7 @@ mod tests { // Clean up: drop the table let drop_req = DropTableRequest { id: Some(vec!["test_ns".to_string(), 
"test_table".to_string()]), + ..Default::default() }; namespace .drop_table(drop_req) @@ -755,12 +1027,12 @@ mod tests { // Vend credentials multiple times to verify consistent behavior let creds1 = vendor - .vend_credentials(&table_location) + .vend_credentials(&table_location, None) .await .expect("should vend credentials first time"); let creds2 = vendor - .vend_credentials(&table_location) + .vend_credentials(&table_location, None) .await .expect("should vend credentials second time"); @@ -802,13 +1074,13 @@ mod tests { // Vend credentials for table1 let creds1 = vendor - .vend_credentials(&table1_location) + .vend_credentials(&table1_location, None) .await .expect("should vend credentials for table1"); // Vend credentials for table2 let creds2 = vendor - .vend_credentials(&table2_location) + .vend_credentials(&table2_location, None) .await .expect("should vend credentials for table2"); @@ -861,8 +1133,7 @@ mod tests { // Verify namespace works let create_ns_req = CreateNamespaceRequest { id: Some(vec!["props_test".to_string()]), - properties: None, - mode: None, + ..Default::default() }; namespace .create_namespace(create_ns_req) diff --git a/rust/lance-namespace-impls/src/credentials/azure.rs b/rust/lance-namespace-impls/src/credentials/azure.rs index 1d4e4ded081..75a711b7448 100644 --- a/rust/lance-namespace-impls/src/credentials/azure.rs +++ b/rust/lance-namespace-impls/src/credentials/azure.rs @@ -13,10 +13,14 @@ use async_trait::async_trait; use azure_core::auth::TokenCredential; use azure_identity::DefaultAzureCredential; use azure_storage::prelude::*; +use azure_storage::shared_access_signature::service_sas::{BlobSharedAccessSignature, SasKey}; use azure_storage_blobs::prelude::*; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use lance_core::{Error, Result}; use lance_io::object_store::uri_to_url; +use lance_namespace::models::Identity; use log::{debug, info, warn}; +use sha2::{Digest, Sha256}; use super::{ redact_credential, 
CredentialVendor, VendedCredentials, VendedPermission, @@ -38,7 +42,22 @@ pub struct AzureCredentialVendorConfig { /// Permission level for vended credentials. /// Default: Read (full read access) + /// Used to generate SAS permissions for all credential flows. pub permission: VendedPermission, + + /// Client ID of the Azure AD App Registration for Workload Identity Federation. + /// Required when using auth_token identity for OIDC token exchange. + pub federated_client_id: Option, + + /// Salt for API key hashing. + /// Required when using API key authentication. + /// API keys are hashed as: SHA256(api_key + ":" + salt) + pub api_key_salt: Option, + + /// Map of SHA256(api_key + ":" + salt) -> permission level. + /// When an API key is provided, its hash is looked up in this map. + /// If found, the mapped permission is used instead of the default permission. + pub api_key_hash_permissions: HashMap, } impl Default for AzureCredentialVendorConfig { @@ -48,6 +67,9 @@ impl Default for AzureCredentialVendorConfig { account_name: None, duration_millis: DEFAULT_CREDENTIAL_DURATION_MILLIS, permission: VendedPermission::default(), + federated_client_id: None, + api_key_salt: None, + api_key_hash_permissions: HashMap::new(), } } } @@ -81,18 +103,105 @@ impl AzureCredentialVendorConfig { self.permission = permission; self } + + /// Set the federated client ID for Workload Identity Federation. + pub fn with_federated_client_id(mut self, client_id: impl Into) -> Self { + self.federated_client_id = Some(client_id.into()); + self + } + + /// Set the API key salt for hashing. + pub fn with_api_key_salt(mut self, salt: impl Into) -> Self { + self.api_key_salt = Some(salt.into()); + self + } + + /// Add an API key hash to permission mapping. 
+ pub fn with_api_key_hash_permission( + mut self, + key_hash: impl Into, + permission: VendedPermission, + ) -> Self { + self.api_key_hash_permissions + .insert(key_hash.into(), permission); + self + } + + /// Set the entire API key hash permissions map. + pub fn with_api_key_hash_permissions( + mut self, + permissions: HashMap, + ) -> Self { + self.api_key_hash_permissions = permissions; + self + } } /// Azure credential vendor that generates SAS tokens. #[derive(Debug)] pub struct AzureCredentialVendor { config: AzureCredentialVendorConfig, + http_client: reqwest::Client, } impl AzureCredentialVendor { /// Create a new Azure credential vendor with the specified configuration. pub fn new(config: AzureCredentialVendorConfig) -> Self { - Self { config } + Self { + config, + http_client: reqwest::Client::new(), + } + } + + /// Hash an API key using SHA-256 with salt (Polaris pattern). + /// Format: SHA256(api_key + ":" + salt) as hex string. + pub fn hash_api_key(api_key: &str, salt: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(format!("{}:{}", api_key, salt)); + format!("{:x}", hasher.finalize()) + } + + /// Extract a session name from a JWT token (best effort, no validation). + /// Decodes the payload and extracts 'sub' or 'email' claim. + /// Falls back to "lance-azure-identity" if parsing fails. 
+ fn derive_session_name_from_token(token: &str) -> String { + // JWT format: header.payload.signature + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return "lance-azure-identity".to_string(); + } + + // Decode the payload (second part) + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(bytes) => bytes, + Err(_) => { + // Try standard base64 as fallback + match base64::engine::general_purpose::STANDARD_NO_PAD.decode(parts[1]) { + Ok(bytes) => bytes, + Err(_) => return "lance-azure-identity".to_string(), + } + } + }; + + // Parse as JSON and extract 'sub' or 'email' + let json: serde_json::Value = match serde_json::from_slice(&payload) { + Ok(v) => v, + Err(_) => return "lance-azure-identity".to_string(), + }; + + let subject = json + .get("sub") + .or_else(|| json.get("email")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + + // Sanitize: keep only alphanumeric, @, -, . + let sanitized: String = subject + .chars() + .filter(|c| c.is_alphanumeric() || *c == '@' || *c == '-' || *c == '.') + .collect(); + + format!("lance-{}", sanitized) } /// Build SAS permissions based on the VendedPermission level. @@ -196,61 +305,596 @@ impl AzureCredentialVendor { Ok((token, expires_at_millis)) } -} -#[async_trait] -impl CredentialVendor for AzureCredentialVendor { - async fn vend_credentials(&self, table_location: &str) -> Result { - debug!( - "Azure credential vending: location={}, permission={}", - table_location, self.config.permission - ); + /// Generate a SAS token with a specific permission level. 
+ async fn generate_sas_token_with_permission( + &self, + account: &str, + container: &str, + permission: VendedPermission, + ) -> Result<(String, u64)> { + let credential = + DefaultAzureCredential::create(azure_identity::TokenCredentialOptions::default()) + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to create Azure credentials: {}", + e + ))), + location: snafu::location!(), + })?; - let url = uri_to_url(table_location)?; + let credential: Arc = Arc::new(credential); + let blob_service_client = BlobServiceClient::new(account, credential.clone()); - let container = url.host_str().ok_or_else(|| Error::InvalidInput { - source: format!("Azure URI '{}' missing container", table_location).into(), + let now = time::OffsetDateTime::now_utc(); + let duration_millis = self.config.duration_millis as i64; + let end_time = now + time::Duration::milliseconds(duration_millis); + + let max_key_end = now + time::Duration::days(7) - time::Duration::seconds(60); + let key_end_time = if end_time > max_key_end { + max_key_end + } else { + end_time + }; + + let user_delegation_key = blob_service_client + .get_user_deligation_key(now, key_end_time) + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to get user delegation key for account '{}': {}", + account, e + ))), + location: snafu::location!(), + })?; + + let permissions = Self::build_sas_permissions(permission); + let container_client = blob_service_client.container_client(container); + + let sas_token = container_client + .user_delegation_shared_access_signature( + permissions, + &user_delegation_key.user_deligation_key, + ) + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to generate SAS token for container '{}': {}", + container, e + ))), + location: snafu::location!(), + })?; + + let expires_at_millis = + (end_time.unix_timestamp() * 1000 + end_time.millisecond() as i64) as u64; + + let token = 
sas_token.token().map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to get SAS token: {}", + e + ))), location: snafu::location!(), })?; - // Check if path extends beyond container level - let path = url.path().trim_start_matches('/'); - if !path.is_empty() { - warn!( - "Azure SAS tokens are scoped to container level only. \ - Credentials for '{}' will have access to entire container '{}', not just path '{}'", - table_location, container, path - ); - } + Ok((token, expires_at_millis)) + } - let account = + /// Generate a directory-scoped SAS token. + /// + /// Unlike container-level SAS tokens, this restricts access to a specific directory + /// path within the container. This is more secure for multi-tenant scenarios. + /// + /// # Arguments + /// * `account` - Storage account name + /// * `container` - Container name + /// * `path` - Directory path within the container (e.g., "tenant-a/tables/my-table") + /// * `permission` - Permission level for the SAS token + async fn generate_directory_sas_token( + &self, + account: &str, + container: &str, + path: &str, + permission: VendedPermission, + ) -> Result<(String, u64)> { + let credential = + DefaultAzureCredential::create(azure_identity::TokenCredentialOptions::default()) + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to create Azure credentials: {}", + e + ))), + location: snafu::location!(), + })?; + + let credential: Arc = Arc::new(credential); + let blob_service_client = BlobServiceClient::new(account, credential.clone()); + + let now = time::OffsetDateTime::now_utc(); + let duration_millis = self.config.duration_millis as i64; + let end_time = now + time::Duration::milliseconds(duration_millis); + + let max_key_end = now + time::Duration::days(7) - time::Duration::seconds(60); + let key_end_time = if end_time > max_key_end { + max_key_end + } else { + end_time + }; + + let user_delegation_key = blob_service_client + 
.get_user_deligation_key(now, key_end_time) + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to get user delegation key for account '{}': {}", + account, e + ))), + location: snafu::location!(), + })?; + + // Normalize path: remove leading/trailing slashes + let normalized_path = path.trim_matches('/'); + let depth = if normalized_path.is_empty() { + 0 + } else { + normalized_path.split('/').count() + }; + + // Build canonical resource path for directory-level SAS + let canonical_resource = format!("/blob/{}/{}/{}", account, container, normalized_path); + + // Convert user delegation key to SasKey + let sas_key = SasKey::UserDelegationKey(user_delegation_key.user_deligation_key); + + let permissions = Self::build_sas_permissions(permission); + + // Create directory-scoped SAS signature + let sas = BlobSharedAccessSignature::new( + sas_key, + canonical_resource, + permissions, + end_time, + BlobSignedResource::Directory, + ) + .signed_directory_depth(depth as u8); + + let token = sas.token().map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to generate directory SAS token: {}", + e + ))), + location: snafu::location!(), + })?; + + let expires_at_millis = + (end_time.unix_timestamp() * 1000 + end_time.millisecond() as i64) as u64; + + info!( + "Azure directory-scoped SAS generated: account={}, container={}, path={}, depth={}, permission={}", + account, container, normalized_path, depth, permission + ); + + Ok((token, expires_at_millis)) + } + + /// Exchange an OIDC token for Azure AD access token using Workload Identity Federation. + /// + /// This requires: + /// 1. An Azure AD App Registration with Federated Credentials configured + /// 2. 
The OIDC token's issuer and subject to match the Federated Credential configuration + async fn exchange_oidc_for_azure_token(&self, oidc_token: &str) -> Result { + let tenant_id = self + .config + .tenant_id + .as_ref() + .ok_or_else(|| Error::InvalidInput { + source: "azure_tenant_id must be configured for OIDC token exchange".into(), + location: snafu::location!(), + })?; + + let client_id = self.config - .account_name + .federated_client_id .as_ref() .ok_or_else(|| Error::InvalidInput { - source: "Azure credential vending requires 'credential_vendor.azure_account_name' to be set in configuration".into(), + source: "azure_federated_client_id must be configured for OIDC token exchange" + .into(), + location: snafu::location!(), + })?; + + let token_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + tenant_id + ); + + let params = [ + ("grant_type", "client_credentials"), + ( + "client_assertion_type", + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + ), + ("client_assertion", oidc_token), + ("client_id", client_id), + ("scope", "https://storage.azure.com/.default"), + ]; + + let response = self + .http_client + .post(&token_url) + .form(¶ms) + .send() + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to exchange OIDC token for Azure AD token: {}", + e + ))), + location: snafu::location!(), + })?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(Error::IO { + source: Box::new(std::io::Error::other(format!( + "Azure AD token exchange failed with status {}: {}", + status, body + ))), + location: snafu::location!(), + }); + } + + let token_response: serde_json::Value = response.json().await.map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to parse Azure AD token response: {}", + e + ))), + location: snafu::location!(), + })?; + + token_response + 
.get("access_token") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| Error::IO { + source: Box::new(std::io::Error::other( + "Azure AD token response missing access_token", + )), + location: snafu::location!(), + }) + } + + /// Generate a SAS token using a federated Azure AD token. + /// + /// Uses directory-scoped SAS when path is provided, container-level otherwise. + async fn generate_sas_with_azure_token( + &self, + azure_token: &str, + account: &str, + container: &str, + path: &str, + permission: VendedPermission, + ) -> Result<(String, u64)> { + // Create a custom TokenCredential that uses our Azure AD token + let credential = FederatedTokenCredential::new(azure_token.to_string()); + let credential: Arc = Arc::new(credential); + + let blob_service_client = BlobServiceClient::new(account, credential.clone()); + + let now = time::OffsetDateTime::now_utc(); + let duration_millis = self.config.duration_millis as i64; + let end_time = now + time::Duration::milliseconds(duration_millis); + + let max_key_end = now + time::Duration::days(7) - time::Duration::seconds(60); + let key_end_time = if end_time > max_key_end { + max_key_end + } else { + end_time + }; + + let user_delegation_key = blob_service_client + .get_user_deligation_key(now, key_end_time) + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to get user delegation key with federated token: {}", + e + ))), + location: snafu::location!(), + })?; + + let permissions = Self::build_sas_permissions(permission); + + let expires_at_millis = + (end_time.unix_timestamp() * 1000 + end_time.millisecond() as i64) as u64; + + // Use directory-scoped SAS when path is provided + let normalized_path = path.trim_matches('/'); + let token = if normalized_path.is_empty() { + // Container-level SAS + let container_client = blob_service_client.container_client(container); + let sas_token = container_client + .user_delegation_shared_access_signature( + 
permissions, + &user_delegation_key.user_deligation_key, + ) + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to generate SAS token with federated token: {}", + e + ))), location: snafu::location!(), })?; - let (sas_token, expires_at_millis) = self.generate_sas_token(account, container).await?; + sas_token.token().map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to get SAS token: {}", + e + ))), + location: snafu::location!(), + })? + } else { + // Directory-scoped SAS + let depth = normalized_path.split('/').count(); + let canonical_resource = format!("/blob/{}/{}/{}", account, container, normalized_path); + let sas_key = SasKey::UserDelegationKey(user_delegation_key.user_deligation_key); + + let sas = BlobSharedAccessSignature::new( + sas_key, + canonical_resource, + permissions, + end_time, + BlobSignedResource::Directory, + ) + .signed_directory_depth(depth as u8); + + sas.token().map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to generate directory SAS token with federated token: {}", + e + ))), + location: snafu::location!(), + })? + }; + + Ok((token, expires_at_millis)) + } + + /// Vend credentials using Workload Identity Federation (for auth_token). 
+ async fn vend_with_web_identity( + &self, + account: &str, + container: &str, + path: &str, + auth_token: &str, + ) -> Result { + let session_name = Self::derive_session_name_from_token(auth_token); + debug!( + "Azure vend_with_web_identity: account={}, container={}, path={}, session={}", + account, container, path, session_name + ); + + // Exchange OIDC token for Azure AD token + let azure_token = self.exchange_oidc_for_azure_token(auth_token).await?; + + // Generate SAS token using the Azure AD token + // Use directory-scoped SAS when path is provided + let (sas_token, expires_at_millis) = self + .generate_sas_with_azure_token( + &azure_token, + account, + container, + path, + self.config.permission, + ) + .await?; let mut storage_options = HashMap::new(); - // Use the standard key that object_store/lance-io expects storage_options.insert("azure_storage_sas_token".to_string(), sas_token.clone()); - storage_options.insert("azure_storage_account_name".to_string(), account.clone()); + storage_options.insert( + "azure_storage_account_name".to_string(), + account.to_string(), + ); storage_options.insert( "expires_at_millis".to_string(), expires_at_millis.to_string(), ); info!( - "Azure credentials vended: account={}, container={}, permission={}, expires_at={}, sas_token={}", - account, container, self.config.permission, expires_at_millis, redact_credential(&sas_token) + "Azure credentials vended (web identity): account={}, container={}, path={}, permission={}, expires_at={}, sas_token={}", + account, container, path, self.config.permission, expires_at_millis, redact_credential(&sas_token) ); Ok(VendedCredentials::new(storage_options, expires_at_millis)) } + /// Vend credentials using API key validation. 
+ async fn vend_with_api_key( + &self, + account: &str, + container: &str, + path: &str, + api_key: &str, + ) -> Result { + let salt = self + .config + .api_key_salt + .as_ref() + .ok_or_else(|| Error::InvalidInput { + source: "api_key_salt must be configured to use API key authentication".into(), + location: snafu::location!(), + })?; + + let key_hash = Self::hash_api_key(api_key, salt); + + // Look up permission from hash mapping + let permission = self + .config + .api_key_hash_permissions + .get(&key_hash) + .copied() + .ok_or_else(|| { + warn!( + "Invalid API key: hash {} not found in permissions map", + &key_hash[..8] + ); + Error::InvalidInput { + source: "Invalid API key".into(), + location: snafu::location!(), + } + })?; + + debug!( + "Azure vend_with_api_key: account={}, container={}, path={}, permission={}", + account, container, path, permission + ); + + // Use directory-scoped SAS when path is provided, container-level otherwise + let (sas_token, expires_at_millis) = if path.is_empty() { + self.generate_sas_token_with_permission(account, container, permission) + .await? + } else { + self.generate_directory_sas_token(account, container, path, permission) + .await? + }; + + let mut storage_options = HashMap::new(); + storage_options.insert("azure_storage_sas_token".to_string(), sas_token.clone()); + storage_options.insert( + "azure_storage_account_name".to_string(), + account.to_string(), + ); + storage_options.insert( + "expires_at_millis".to_string(), + expires_at_millis.to_string(), + ); + + info!( + "Azure credentials vended (api_key): account={}, container={}, path={}, permission={}, expires_at={}, sas_token={}", + account, container, path, permission, expires_at_millis, redact_credential(&sas_token) + ); + + Ok(VendedCredentials::new(storage_options, expires_at_millis)) + } +} + +/// A custom TokenCredential that wraps a pre-obtained Azure AD access token. 
+#[derive(Debug)] +struct FederatedTokenCredential { + token: String, +} + +impl FederatedTokenCredential { + fn new(token: String) -> Self { + Self { token } + } +} + +#[async_trait] +impl TokenCredential for FederatedTokenCredential { + async fn get_token( + &self, + _scopes: &[&str], + ) -> std::result::Result { + // Return the pre-obtained token with a 1-hour expiry (conservative estimate) + let expires_on = time::OffsetDateTime::now_utc() + time::Duration::hours(1); + Ok(azure_core::auth::AccessToken::new( + azure_core::auth::Secret::new(self.token.clone()), + expires_on, + )) + } + + async fn clear_cache(&self) -> std::result::Result<(), azure_core::Error> { + Ok(()) + } +} + +#[async_trait] +impl CredentialVendor for AzureCredentialVendor { + async fn vend_credentials( + &self, + table_location: &str, + identity: Option<&Identity>, + ) -> Result { + debug!( + "Azure credential vending: location={}, permission={}, identity={:?}", + table_location, + self.config.permission, + identity.map(|i| format!( + "api_key={}, auth_token={}", + i.api_key.is_some(), + i.auth_token.is_some() + )) + ); + + let url = uri_to_url(table_location)?; + + let container = url.host_str().ok_or_else(|| Error::InvalidInput { + source: format!("Azure URI '{}' missing container", table_location).into(), + location: snafu::location!(), + })?; + + // Extract path for directory-scoped SAS + let path = url.path().trim_start_matches('/'); + + let account = + self.config + .account_name + .as_ref() + .ok_or_else(|| Error::InvalidInput { + source: "Azure credential vending requires 'credential_vendor.azure_account_name' to be set in configuration".into(), + location: snafu::location!(), + })?; + + // Dispatch based on identity + match identity { + Some(id) if id.auth_token.is_some() => { + let auth_token = id.auth_token.as_ref().unwrap(); + self.vend_with_web_identity(account, container, path, auth_token) + .await + } + Some(id) if id.api_key.is_some() => { + let api_key = 
id.api_key.as_ref().unwrap(); + self.vend_with_api_key(account, container, path, api_key) + .await + } + Some(_) => Err(Error::InvalidInput { + source: "Identity provided but neither auth_token nor api_key is set".into(), + location: snafu::location!(), + }), + None => { + // Static credential vending using DefaultAzureCredential + // Use directory-scoped SAS when path is provided, container-level otherwise + let (sas_token, expires_at_millis) = if path.is_empty() { + self.generate_sas_token(account, container).await? + } else { + self.generate_directory_sas_token( + account, + container, + path, + self.config.permission, + ) + .await? + }; + + let mut storage_options = HashMap::new(); + storage_options.insert("azure_storage_sas_token".to_string(), sas_token.clone()); + storage_options.insert("azure_storage_account_name".to_string(), account.clone()); + storage_options.insert( + "expires_at_millis".to_string(), + expires_at_millis.to_string(), + ); + + info!( + "Azure credentials vended (static): account={}, container={}, path={}, permission={}, expires_at={}, sas_token={}", + account, container, path, self.config.permission, expires_at_millis, redact_credential(&sas_token) + ); + + Ok(VendedCredentials::new(storage_options, expires_at_millis)) + } + } + } + fn provider_name(&self) -> &'static str { "azure" } diff --git a/rust/lance-namespace-impls/src/credentials/cache.rs b/rust/lance-namespace-impls/src/credentials/cache.rs new file mode 100644 index 00000000000..6e7c6c4dcf7 --- /dev/null +++ b/rust/lance-namespace-impls/src/credentials/cache.rs @@ -0,0 +1,438 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Credential caching for cloud storage access. +//! +//! This module provides a caching wrapper for credential vendors that reduces +//! the number of credential vending requests (e.g., STS calls) by caching +//! credentials until they are close to expiration. +//! +//! ## Caching Strategy +//! +//! 
- **Cache Key**: Table location + identity hash (api_key hash or auth_token hash) +//! - **TTL**: Half of the credential's remaining lifetime, capped at 30 minutes +//! - **Eviction**: Credentials are evicted when TTL expires or when explicitly cleared +//! +//! ## Example +//! +//! ```ignore +//! use lance_namespace_impls::credentials::cache::CachingCredentialVendor; +//! +//! let vendor = AwsCredentialVendor::new(config).await?; +//! let cached_vendor = CachingCredentialVendor::new(Box::new(vendor)); +//! +//! // First call hits the underlying vendor +//! let creds1 = cached_vendor.vend_credentials("s3://bucket/table", None).await?; +//! +//! // Subsequent calls within TTL return cached credentials +//! let creds2 = cached_vendor.vend_credentials("s3://bucket/table", None).await?; +//! ``` + +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use async_trait::async_trait; +use lance_core::Result; +use lance_namespace::models::Identity; +use log::debug; +use tokio::sync::RwLock; + +use super::{CredentialVendor, VendedCredentials, VendedPermission}; + +/// Maximum cache TTL: 30 minutes. +/// Even if credentials are valid for longer, we refresh more frequently +/// to handle clock skew and ensure freshness. +const MAX_CACHE_TTL_SECS: u64 = 30 * 60; + +/// Minimum cache TTL: 1 minute. +/// If credentials expire sooner than this, we don't cache them. +const MIN_CACHE_TTL_SECS: u64 = 60; + +/// A cached credential entry with expiration tracking. 
#[derive(Clone)]
struct CacheEntry {
    /// The vended credentials returned to callers on a cache hit.
    credentials: VendedCredentials,
    /// When this cache entry should be considered stale
    cached_until: Instant,
}

// Implemented by hand rather than derived so that the credential payload is never
// printed: the `credentials` field is always rendered as "[redacted]".
impl std::fmt::Debug for CacheEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CacheEntry")
            .field("credentials", &"[redacted]")
            .field("cached_until", &self.cached_until)
            .finish()
    }
}

impl CacheEntry {
    /// Returns `true` once the current instant has reached `cached_until`.
    fn is_stale(&self) -> bool {
        Instant::now() >= self.cached_until
    }
}

/// A caching wrapper for credential vendors.
///
/// This wrapper caches vended credentials to reduce the number of underlying
/// credential vending operations (e.g., STS calls). Credentials are cached
/// until half their lifetime has passed, capped at 30 minutes.
#[derive(Debug)]
pub struct CachingCredentialVendor {
    /// The wrapped vendor that is asked for credentials on a cache miss.
    inner: Box<dyn CredentialVendor>,
    /// Cached entries, keyed by the string produced by `build_cache_key`.
    cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
}

impl CachingCredentialVendor {
    /// Create a new caching credential vendor wrapping the given vendor.
    ///
    /// The cache starts out empty.
    pub fn new(inner: Box<dyn CredentialVendor>) -> Self {
        Self {
            inner,
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Build a cache key from the table location and identity.
    ///
    /// The key is a hash of the location and identity fields to ensure
    /// different identities get different cached credentials.
+ fn build_cache_key(table_location: &str, identity: Option<&Identity>) -> String { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + + table_location.hash(&mut hasher); + + if let Some(id) = identity { + if let Some(ref api_key) = id.api_key { + ":api_key:".hash(&mut hasher); + api_key.hash(&mut hasher); + } + if let Some(ref auth_token) = id.auth_token { + ":auth_token:".hash(&mut hasher); + // Only hash first 64 chars of token to avoid memory issues with large tokens + let token_prefix = if auth_token.len() > 64 { + &auth_token[..64] + } else { + auth_token.as_str() + }; + token_prefix.hash(&mut hasher); + } + } else { + ":no_identity".hash(&mut hasher); + } + + format!("{:016x}", hasher.finish()) + } + + /// Calculate the cache TTL for the given credentials. + /// + /// Returns the TTL as a Duration, or None if the credentials should not be cached. + fn calculate_cache_ttl(credentials: &VendedCredentials) -> Option { + let now_millis = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time went backwards") + .as_millis() as u64; + + if credentials.expires_at_millis <= now_millis { + // Already expired + return None; + } + + let remaining_millis = credentials.expires_at_millis - now_millis; + let remaining_secs = remaining_millis / 1000; + + // TTL is half the remaining lifetime + let ttl_secs = remaining_secs / 2; + + // Cap between MIN and MAX + if ttl_secs < MIN_CACHE_TTL_SECS { + None // Don't cache if TTL is too short + } else { + Some(Duration::from_secs(ttl_secs.min(MAX_CACHE_TTL_SECS))) + } + } + + /// Clear all cached credentials. + pub async fn clear_cache(&self) { + let mut cache = self.cache.write().await; + cache.clear(); + debug!("Credential cache cleared"); + } + + /// Get the number of cached entries. + pub async fn cache_size(&self) -> usize { + let cache = self.cache.read().await; + cache.len() + } + + /// Remove stale entries from the cache. 
+ pub async fn evict_stale(&self) -> usize { + let mut cache = self.cache.write().await; + let before = cache.len(); + cache.retain(|_, entry| !entry.is_stale()); + let evicted = before - cache.len(); + if evicted > 0 { + debug!("Evicted {} stale credential cache entries", evicted); + } + evicted + } +} + +#[async_trait] +impl CredentialVendor for CachingCredentialVendor { + async fn vend_credentials( + &self, + table_location: &str, + identity: Option<&Identity>, + ) -> Result { + let cache_key = Self::build_cache_key(table_location, identity); + + // Try to get from cache first + { + let cache = self.cache.read().await; + if let Some(entry) = cache.get(&cache_key) { + if !entry.is_stale() && !entry.credentials.is_expired() { + debug!( + "Credential cache hit for location={}, provider={}", + table_location, + self.inner.provider_name() + ); + return Ok(entry.credentials.clone()); + } + } + } + + // Cache miss or stale - vend new credentials + debug!( + "Credential cache miss for location={}, provider={}", + table_location, + self.inner.provider_name() + ); + + let credentials = self + .inner + .vend_credentials(table_location, identity) + .await?; + + // Cache the new credentials if TTL is sufficient + if let Some(ttl) = Self::calculate_cache_ttl(&credentials) { + let entry = CacheEntry { + credentials: credentials.clone(), + cached_until: Instant::now() + ttl, + }; + + let mut cache = self.cache.write().await; + cache.insert(cache_key, entry); + + debug!( + "Cached credentials for location={}, ttl={}s", + table_location, + ttl.as_secs() + ); + } + + Ok(credentials) + } + + fn provider_name(&self) -> &'static str { + self.inner.provider_name() + } + + fn permission(&self) -> VendedPermission { + self.inner.permission() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicU32, Ordering}; + + /// A mock credential vendor for testing. 
+ #[derive(Debug)] + struct MockVendor { + call_count: AtomicU32, + duration_millis: u64, + } + + impl MockVendor { + fn new(duration_millis: u64) -> Self { + Self { + call_count: AtomicU32::new(0), + duration_millis, + } + } + } + + #[async_trait] + impl CredentialVendor for MockVendor { + async fn vend_credentials( + &self, + _table_location: &str, + _identity: Option<&Identity>, + ) -> Result { + self.call_count.fetch_add(1, Ordering::SeqCst); + + let now_millis = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + let mut storage_options = HashMap::new(); + storage_options.insert("test_key".to_string(), "test_value".to_string()); + + Ok(VendedCredentials::new( + storage_options, + now_millis + self.duration_millis, + )) + } + + fn provider_name(&self) -> &'static str { + "mock" + } + + fn permission(&self) -> VendedPermission { + VendedPermission::Read + } + } + + #[test] + fn test_build_cache_key_no_identity() { + let key1 = CachingCredentialVendor::build_cache_key("s3://bucket/table1", None); + let key2 = CachingCredentialVendor::build_cache_key("s3://bucket/table2", None); + let key3 = CachingCredentialVendor::build_cache_key("s3://bucket/table1", None); + + assert_ne!(key1, key2, "Different locations should have different keys"); + assert_eq!(key1, key3, "Same location should have same key"); + } + + #[test] + fn test_build_cache_key_with_identity() { + let identity_api = Identity { + api_key: Some("my-api-key".to_string()), + auth_token: None, + }; + let identity_token = Identity { + api_key: None, + auth_token: Some("my-token".to_string()), + }; + + let key_no_id = CachingCredentialVendor::build_cache_key("s3://bucket/table", None); + let key_api = + CachingCredentialVendor::build_cache_key("s3://bucket/table", Some(&identity_api)); + let key_token = + CachingCredentialVendor::build_cache_key("s3://bucket/table", Some(&identity_token)); + + assert_ne!(key_no_id, key_api, "Identity should change 
key"); + assert_ne!(key_no_id, key_token, "Identity should change key"); + assert_ne!( + key_api, key_token, + "Different identity types should have different keys" + ); + } + + #[test] + fn test_calculate_cache_ttl() { + let now_millis = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + // Credentials with 1 hour remaining -> TTL should be 30 minutes (capped) + let creds_1h = VendedCredentials::new(HashMap::new(), now_millis + 3600 * 1000); + let ttl = CachingCredentialVendor::calculate_cache_ttl(&creds_1h); + assert_eq!(ttl, Some(Duration::from_secs(MAX_CACHE_TTL_SECS))); + + // Credentials with 10 minutes remaining -> TTL should be 5 minutes + let creds_10m = VendedCredentials::new(HashMap::new(), now_millis + 10 * 60 * 1000); + let ttl = CachingCredentialVendor::calculate_cache_ttl(&creds_10m); + assert_eq!(ttl, Some(Duration::from_secs(5 * 60))); + + // Credentials with 1 minute remaining -> TTL should be None (too short) + let creds_1m = VendedCredentials::new(HashMap::new(), now_millis + 60 * 1000); + let ttl = CachingCredentialVendor::calculate_cache_ttl(&creds_1m); + assert!(ttl.is_none(), "Should not cache short-lived credentials"); + + // Already expired credentials -> None + let creds_expired = VendedCredentials::new(HashMap::new(), now_millis - 1000); + let ttl = CachingCredentialVendor::calculate_cache_ttl(&creds_expired); + assert!(ttl.is_none(), "Should not cache expired credentials"); + } + + #[tokio::test] + async fn test_caching_reduces_calls() { + // Create a mock vendor with 1 hour credentials + let mock = MockVendor::new(3600 * 1000); + let cached = CachingCredentialVendor::new(Box::new(mock)); + + // First call should hit the underlying vendor + let _ = cached + .vend_credentials("s3://bucket/table", None) + .await + .unwrap(); + assert_eq!(cached.cache_size().await, 1); + + // Get reference to inner mock for call count + // We can't easily get the call count from the boxed trait, so 
we'll check cache size + + // Second call should use cache (cache size stays at 1) + let _ = cached + .vend_credentials("s3://bucket/table", None) + .await + .unwrap(); + assert_eq!(cached.cache_size().await, 1); + + // Different location should create new cache entry + let _ = cached + .vend_credentials("s3://bucket/table2", None) + .await + .unwrap(); + assert_eq!(cached.cache_size().await, 2); + } + + #[tokio::test] + async fn test_clear_cache() { + let mock = MockVendor::new(3600 * 1000); + let cached = CachingCredentialVendor::new(Box::new(mock)); + + let _ = cached + .vend_credentials("s3://bucket/table", None) + .await + .unwrap(); + assert_eq!(cached.cache_size().await, 1); + + cached.clear_cache().await; + assert_eq!(cached.cache_size().await, 0); + } + + #[tokio::test] + async fn test_different_identities_cached_separately() { + let mock = MockVendor::new(3600 * 1000); + let cached = CachingCredentialVendor::new(Box::new(mock)); + + let identity1 = Identity { + api_key: Some("key1".to_string()), + auth_token: None, + }; + let identity2 = Identity { + api_key: Some("key2".to_string()), + auth_token: None, + }; + + // Same location with different identities should cache separately + let _ = cached + .vend_credentials("s3://bucket/table", Some(&identity1)) + .await + .unwrap(); + let _ = cached + .vend_credentials("s3://bucket/table", Some(&identity2)) + .await + .unwrap(); + let _ = cached + .vend_credentials("s3://bucket/table", None) + .await + .unwrap(); + + assert_eq!(cached.cache_size().await, 3); + } +} diff --git a/rust/lance-namespace-impls/src/credentials/gcp.rs b/rust/lance-namespace-impls/src/credentials/gcp.rs index ce4bac40fa1..0749bdb1b97 100644 --- a/rust/lance-namespace-impls/src/credentials/gcp.rs +++ b/rust/lance-namespace-impls/src/credentials/gcp.rs @@ -44,12 +44,15 @@ use std::collections::HashMap; use async_trait::async_trait; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use google_cloud_auth::credentials; use 
lance_core::{Error, Result}; use lance_io::object_store::uri_to_url; -use log::{debug, info}; +use lance_namespace::models::Identity; +use log::{debug, info, warn}; use reqwest::Client; use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; use super::{redact_credential, CredentialVendor, VendedCredentials, VendedPermission}; @@ -79,6 +82,31 @@ pub struct GcpCredentialVendorConfig { /// Note: GCP token duration cannot be configured; the token lifetime /// is determined by the STS endpoint (typically 1 hour). pub permission: VendedPermission, + + /// Workload Identity Provider resource name for OIDC token exchange. + /// Required when using auth_token identity for Workload Identity Federation. + /// + /// Format: `projects/{project_number}/locations/global/workloadIdentityPools/{pool_id}/providers/{provider_id}` + /// + /// The OIDC token's issuer must match the provider's configuration. + pub workload_identity_provider: Option, + + /// Service account to impersonate after Workload Identity Federation. + /// Optional - if set, the exchanged token will be used to generate an + /// access token for this service account. + /// + /// Format: `my-sa@project.iam.gserviceaccount.com` + pub impersonation_service_account: Option, + + /// Salt for API key hashing. + /// Required when using API key authentication. + /// API keys are hashed as: SHA256(api_key + ":" + salt) + pub api_key_salt: Option, + + /// Map of SHA256(api_key + ":" + salt) -> permission level. + /// When an API key is provided, its hash is looked up in this map. + /// If found, the mapped permission is used instead of the default permission. + pub api_key_hash_permissions: HashMap, } impl GcpCredentialVendorConfig { @@ -104,6 +132,47 @@ impl GcpCredentialVendorConfig { self.permission = permission; self } + + /// Set the Workload Identity Provider for OIDC token exchange. 
+ pub fn with_workload_identity_provider(mut self, provider: impl Into) -> Self { + self.workload_identity_provider = Some(provider.into()); + self + } + + /// Set the service account to impersonate after Workload Identity Federation. + pub fn with_impersonation_service_account( + mut self, + service_account: impl Into, + ) -> Self { + self.impersonation_service_account = Some(service_account.into()); + self + } + + /// Set the API key salt for hashing. + pub fn with_api_key_salt(mut self, salt: impl Into) -> Self { + self.api_key_salt = Some(salt.into()); + self + } + + /// Add an API key hash to permission mapping. + pub fn with_api_key_hash_permission( + mut self, + key_hash: impl Into, + permission: VendedPermission, + ) -> Self { + self.api_key_hash_permissions + .insert(key_hash.into(), permission); + self + } + + /// Set the entire API key hash permissions map. + pub fn with_api_key_hash_permissions( + mut self, + permissions: HashMap, + ) -> Self { + self.api_key_hash_permissions = permissions; + self + } } /// Access boundary rule for a single resource. @@ -459,25 +528,237 @@ impl GcpCredentialVendor { Ok((token_response.access_token, expires_at_millis)) } -} -#[async_trait] -impl CredentialVendor for GcpCredentialVendor { - async fn vend_credentials(&self, table_location: &str) -> Result { + /// Hash an API key using SHA-256 with salt (Polaris pattern). + /// Format: SHA256(api_key + ":" + salt) as hex string. + pub fn hash_api_key(api_key: &str, salt: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(format!("{}:{}", api_key, salt)); + format!("{:x}", hasher.finalize()) + } + + /// Extract a session name from a JWT token (best effort, no validation). + /// Decodes the payload and extracts 'sub' or 'email' claim. + /// Falls back to "lance-gcp-identity" if parsing fails. 
+ fn derive_session_name_from_token(token: &str) -> String { + // JWT format: header.payload.signature + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return "lance-gcp-identity".to_string(); + } + + // Decode the payload (second part) + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(bytes) => bytes, + Err(_) => { + // Try standard base64 as fallback + match base64::engine::general_purpose::STANDARD_NO_PAD.decode(parts[1]) { + Ok(bytes) => bytes, + Err(_) => return "lance-gcp-identity".to_string(), + } + } + }; + + // Parse as JSON and extract 'sub' or 'email' + let json: serde_json::Value = match serde_json::from_slice(&payload) { + Ok(v) => v, + Err(_) => return "lance-gcp-identity".to_string(), + }; + + let subject = json + .get("sub") + .or_else(|| json.get("email")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + + // Sanitize: keep only alphanumeric, @, -, . + let sanitized: String = subject + .chars() + .filter(|c| c.is_alphanumeric() || *c == '@' || *c == '-' || *c == '.') + .collect(); + + format!("lance-{}", sanitized) + } + + /// Normalize the Workload Identity Provider to the full audience format expected by GCP STS. + /// + /// GCP STS expects audience in the format: + /// `//iam.googleapis.com/projects/{project}/locations/global/workloadIdentityPools/{pool}/providers/{provider}` + /// + /// This function accepts either: + /// - Full format: `//iam.googleapis.com/projects/...` + /// - Short format: `projects/...` (will be prefixed with `//iam.googleapis.com/`) + fn normalize_workload_identity_audience(provider: &str) -> String { + const IAM_PREFIX: &str = "//iam.googleapis.com/"; + if provider.starts_with(IAM_PREFIX) { + provider.to_string() + } else { + format!("{}{}", IAM_PREFIX, provider) + } + } + + /// Exchange an OIDC token for GCP access token using Workload Identity Federation. + /// + /// This requires: + /// 1. A Workload Identity Pool and Provider configured in GCP + /// 2. 
The OIDC token's issuer to match the provider's configuration + /// 3. Optionally, a service account to impersonate after token exchange + async fn exchange_oidc_for_gcp_token(&self, oidc_token: &str) -> Result { + let workload_identity_provider = self + .config + .workload_identity_provider + .as_ref() + .ok_or_else(|| Error::InvalidInput { + source: "gcp_workload_identity_provider must be configured for OIDC token exchange" + .into(), + location: snafu::location!(), + })?; + + // Normalize audience to full format expected by GCP STS + let audience = Self::normalize_workload_identity_audience(workload_identity_provider); + + // Exchange OIDC token for GCP federated token via STS + let params = [ + ( + "grant_type", + "urn:ietf:params:oauth:grant-type:token-exchange", + ), + ("subject_token_type", "urn:ietf:params:oauth:token-type:jwt"), + ( + "requested_token_type", + "urn:ietf:params:oauth:token-type:access_token", + ), + ("subject_token", oidc_token), + ("audience", audience.as_str()), + ("scope", "https://www.googleapis.com/auth/cloud-platform"), + ]; + + let response = self + .http_client + .post(STS_TOKEN_EXCHANGE_URL) + .form(¶ms) + .send() + .await + .map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to exchange OIDC token for GCP token: {}", + e + ))), + location: snafu::location!(), + })?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(Error::IO { + source: Box::new(std::io::Error::other(format!( + "GCP STS token exchange failed with status {}: {}", + status, body + ))), + location: snafu::location!(), + }); + } + + let token_response: TokenExchangeResponse = + response.json().await.map_err(|e| Error::IO { + source: Box::new(std::io::Error::other(format!( + "Failed to parse GCP STS token response: {}", + e + ))), + location: snafu::location!(), + })?; + + let federated_token = token_response.access_token; + + // If 
impersonation is configured, use the federated token to get an impersonated token + if let Some(ref service_account) = self.config.impersonation_service_account { + return self + .impersonate_service_account(&federated_token, service_account) + .await; + } + + Ok(federated_token) + } + + /// Vend credentials using Workload Identity Federation (for auth_token). + async fn vend_with_web_identity( + &self, + bucket: &str, + prefix: &str, + auth_token: &str, + ) -> Result { + let session_name = Self::derive_session_name_from_token(auth_token); debug!( - "GCP credential vending: location={}, permission={}", - table_location, self.config.permission + "GCP vend_with_web_identity: bucket={}, prefix={}, session={}", + bucket, prefix, session_name ); - let (bucket, prefix) = Self::parse_gcs_uri(table_location)?; + // Exchange OIDC token for GCP token + let gcp_token = self.exchange_oidc_for_gcp_token(auth_token).await?; - // Get source token from default credentials - let source_token = self.get_source_token().await?; + // Build access boundary and downscope + let access_boundary = Self::build_access_boundary(bucket, prefix, self.config.permission); + let (downscoped_token, expires_at_millis) = + self.downscope_token(&gcp_token, &access_boundary).await?; + + let mut storage_options = HashMap::new(); + storage_options.insert("google_storage_token".to_string(), downscoped_token.clone()); + storage_options.insert( + "expires_at_millis".to_string(), + expires_at_millis.to_string(), + ); + + info!( + "GCP credentials vended (web identity): bucket={}, prefix={}, permission={}, expires_at={}, token={}", + bucket, prefix, self.config.permission, expires_at_millis, redact_credential(&downscoped_token) + ); - // Build access boundary for this location and permission - let access_boundary = Self::build_access_boundary(&bucket, &prefix, self.config.permission); + Ok(VendedCredentials::new(storage_options, expires_at_millis)) + } + + /// Vend credentials using API key validation. 
+ async fn vend_with_api_key( + &self, + bucket: &str, + prefix: &str, + api_key: &str, + ) -> Result { + let salt = self + .config + .api_key_salt + .as_ref() + .ok_or_else(|| Error::InvalidInput { + source: "api_key_salt must be configured to use API key authentication".into(), + location: snafu::location!(), + })?; - // Exchange for downscoped token + let key_hash = Self::hash_api_key(api_key, salt); + + // Look up permission from hash mapping + let permission = self + .config + .api_key_hash_permissions + .get(&key_hash) + .copied() + .ok_or_else(|| { + warn!( + "Invalid API key: hash {} not found in permissions map", + &key_hash[..8] + ); + Error::InvalidInput { + source: "Invalid API key".into(), + location: snafu::location!(), + } + })?; + + debug!( + "GCP vend_with_api_key: bucket={}, prefix={}, permission={}", + bucket, prefix, permission + ); + + // Get source token using ADC and downscope with the API key's permission + let source_token = self.get_source_token().await?; + let access_boundary = Self::build_access_boundary(bucket, prefix, permission); let (downscoped_token, expires_at_millis) = self .downscope_token(&source_token, &access_boundary) .await?; @@ -490,16 +771,75 @@ impl CredentialVendor for GcpCredentialVendor { ); info!( - "GCP credentials vended: bucket={}, prefix={}, permission={}, expires_at={}, token={}", - bucket, - prefix, - self.config.permission, - expires_at_millis, - redact_credential(&downscoped_token) + "GCP credentials vended (api_key): bucket={}, prefix={}, permission={}, expires_at={}, token={}", + bucket, prefix, permission, expires_at_millis, redact_credential(&downscoped_token) ); Ok(VendedCredentials::new(storage_options, expires_at_millis)) } +} + +#[async_trait] +impl CredentialVendor for GcpCredentialVendor { + async fn vend_credentials( + &self, + table_location: &str, + identity: Option<&Identity>, + ) -> Result { + debug!( + "GCP credential vending: location={}, permission={}, identity={:?}", + table_location, + 
self.config.permission, + identity.map(|i| format!( + "api_key={}, auth_token={}", + i.api_key.is_some(), + i.auth_token.is_some() + )) + ); + + let (bucket, prefix) = Self::parse_gcs_uri(table_location)?; + + // Dispatch based on identity + match identity { + Some(id) if id.auth_token.is_some() => { + let auth_token = id.auth_token.as_ref().unwrap(); + self.vend_with_web_identity(&bucket, &prefix, auth_token) + .await + } + Some(id) if id.api_key.is_some() => { + let api_key = id.api_key.as_ref().unwrap(); + self.vend_with_api_key(&bucket, &prefix, api_key).await + } + Some(_) => Err(Error::InvalidInput { + source: "Identity provided but neither auth_token nor api_key is set".into(), + location: snafu::location!(), + }), + None => { + // Static credential vending using ADC + let source_token = self.get_source_token().await?; + let access_boundary = + Self::build_access_boundary(&bucket, &prefix, self.config.permission); + let (downscoped_token, expires_at_millis) = self + .downscope_token(&source_token, &access_boundary) + .await?; + + let mut storage_options = HashMap::new(); + storage_options + .insert("google_storage_token".to_string(), downscoped_token.clone()); + storage_options.insert( + "expires_at_millis".to_string(), + expires_at_millis.to_string(), + ); + + info!( + "GCP credentials vended (static): bucket={}, prefix={}, permission={}, expires_at={}, token={}", + bucket, prefix, self.config.permission, expires_at_millis, redact_credential(&downscoped_token) + ); + + Ok(VendedCredentials::new(storage_options, expires_at_millis)) + } + } + } fn provider_name(&self) -> &'static str { "gcp" @@ -634,4 +974,26 @@ mod tests { // No condition when prefix is empty (full bucket access) assert!(rules[0].availability_condition.is_none()); } + + #[test] + fn test_normalize_workload_identity_audience() { + // Short format should be prefixed + let short = + "projects/123456/locations/global/workloadIdentityPools/my-pool/providers/my-provider"; + let normalized = 
GcpCredentialVendor::normalize_workload_identity_audience(short); + assert_eq!( + normalized, + "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/my-pool/providers/my-provider" + ); + + // Full format should be unchanged + let full = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/my-pool/providers/my-provider"; + let normalized = GcpCredentialVendor::normalize_workload_identity_audience(full); + assert_eq!(normalized, full); + + // Edge case: already has prefix (idempotent) + let normalized_again = + GcpCredentialVendor::normalize_workload_identity_audience(&normalized); + assert_eq!(normalized_again, full); + } } diff --git a/rust/lance-namespace-impls/src/dir.rs b/rust/lance-namespace-impls/src/dir.rs index b0170023e3c..2168324a308 100644 --- a/rust/lance-namespace-impls/src/dir.rs +++ b/rust/lance-namespace-impls/src/dir.rs @@ -26,7 +26,7 @@ use lance_namespace::models::{ CreateNamespaceResponse, CreateTableRequest, CreateTableResponse, DeclareTableRequest, DeclareTableResponse, DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest, DescribeTableResponse, DropNamespaceRequest, DropNamespaceResponse, - DropTableRequest, DropTableResponse, ListNamespacesRequest, ListNamespacesResponse, + DropTableRequest, DropTableResponse, Identity, ListNamespacesRequest, ListNamespacesResponse, ListTablesRequest, ListTablesResponse, NamespaceExistsRequest, TableExistsRequest, }; @@ -712,12 +712,14 @@ impl DirectoryNamespace { /// # Arguments /// /// * `table_uri` - The full URI of the table + /// * `identity` - Optional identity from the request for identity-based credential vending async fn get_storage_options_for_table( &self, table_uri: &str, + identity: Option<&Identity>, ) -> Result>> { if let Some(ref vendor) = self.credential_vendor { - let vended = vendor.vend_credentials(table_uri).await?; + let vended = vendor.vend_credentials(table_uri, identity).await?; return 
Ok(Some(vended.storage_options)); } Ok(self.storage_options.clone()) @@ -828,8 +830,10 @@ impl LanceNamespace for DirectoryNamespace { } Self::validate_root_namespace_id(&request.id)?; + #[allow(clippy::needless_update)] Ok(DescribeNamespaceResponse { properties: Some(HashMap::new()), + ..Default::default() }) } @@ -962,7 +966,20 @@ impl LanceNamespace for DirectoryNamespace { async fn describe_table(&self, request: DescribeTableRequest) -> Result { if let Some(ref manifest_ns) = self.manifest_ns { match manifest_ns.describe_table(request.clone()).await { - Ok(response) => return Ok(response), + Ok(mut response) => { + // Only apply identity-based credential vending when explicitly requested + if request.vend_credentials == Some(true) && self.credential_vendor.is_some() { + if let Some(ref table_uri) = response.table_uri { + let identity = request.identity.as_deref(); + response.storage_options = self + .get_storage_options_for_table(table_uri, identity) + .await?; + } + } else if request.vend_credentials == Some(false) { + response.storage_options = None; + } + return Ok(response); + } Err(_) if self.dir_listing_enabled && request.id.as_ref().is_some_and(|id| id.len() == 1) => @@ -993,6 +1010,35 @@ impl LanceNamespace for DirectoryNamespace { }); } + let load_detailed_metadata = request.load_detailed_metadata.unwrap_or(false); + // For backwards compatibility, only skip vending credentials when explicitly set to false + let vend_credentials = request.vend_credentials.unwrap_or(true); + let identity = request.identity.as_deref(); + + // If not loading detailed metadata, return minimal response with just location + if !load_detailed_metadata { + let storage_options = if vend_credentials { + self.get_storage_options_for_table(&table_uri, identity) + .await? 
+ } else { + None + }; + return Ok(DescribeTableResponse { + table: Some(table_name), + namespace: request.id.as_ref().map(|id| { + if id.len() > 1 { + id[..id.len() - 1].to_vec() + } else { + vec![] + } + }), + location: Some(table_uri.clone()), + table_uri: Some(table_uri), + storage_options, + ..Default::default() + }); + } + // Try to load the dataset to get real information match Dataset::open(&table_uri).await { Ok(mut dataset) => { @@ -1001,11 +1047,20 @@ impl LanceNamespace for DirectoryNamespace { dataset = dataset.checkout_version(requested_version as u64).await?; } - let version = dataset.version().version; + let version_info = dataset.version(); let lance_schema = dataset.schema(); let arrow_schema: arrow_schema::Schema = lance_schema.into(); let json_schema = arrow_schema_to_json(&arrow_schema)?; - let storage_options = self.get_storage_options_for_table(&table_uri).await?; + let storage_options = if vend_credentials { + self.get_storage_options_for_table(&table_uri, identity) + .await? + } else { + None + }; + + // Convert BTreeMap to HashMap for the response + let metadata: std::collections::HashMap = + version_info.metadata.into_iter().collect(); Ok(DescribeTableResponse { table: Some(table_name), @@ -1016,18 +1071,24 @@ impl LanceNamespace for DirectoryNamespace { vec![] } }), - version: Some(version as i64), + version: Some(version_info.version as i64), location: Some(table_uri.clone()), table_uri: Some(table_uri), schema: Some(Box::new(json_schema)), storage_options, - stats: None, + metadata: Some(metadata), + ..Default::default() }) } Err(err) => { // Use the reserved file status from the atomic check if status.has_reserved_file { - let storage_options = self.get_storage_options_for_table(&table_uri).await?; + let storage_options = if vend_credentials { + self.get_storage_options_for_table(&table_uri, identity) + .await? 
+ } else { + None + }; Ok(DescribeTableResponse { table: Some(table_name), namespace: request.id.as_ref().map(|id| { @@ -1037,12 +1098,10 @@ impl LanceNamespace for DirectoryNamespace { vec![] } }), - version: None, location: Some(table_uri.clone()), table_uri: Some(table_uri), - schema: None, storage_options, - stats: None, + ..Default::default() }) } else { Err(Error::Namespace { @@ -1111,8 +1170,7 @@ impl LanceNamespace for DirectoryNamespace { Ok(DropTableResponse { id: request.id, location: Some(table_uri), - properties: None, - transaction_id: None, + ..Default::default() }) } @@ -1181,10 +1239,10 @@ impl LanceNamespace for DirectoryNamespace { })?; Ok(CreateTableResponse { - transaction_id: None, version: Some(1), location: Some(table_uri), storage_options: self.storage_options.clone(), + ..Default::default() }) } @@ -1194,7 +1252,19 @@ impl LanceNamespace for DirectoryNamespace { ) -> Result { if let Some(ref manifest_ns) = self.manifest_ns { #[allow(deprecated)] - return manifest_ns.create_empty_table(request).await; + let mut response = manifest_ns.create_empty_table(request.clone()).await?; + // Only apply identity-based credential vending when explicitly requested + if request.vend_credentials == Some(true) && self.credential_vendor.is_some() { + if let Some(ref location) = response.location { + let identity = request.identity.as_deref(); + response.storage_options = self + .get_storage_options_for_table(location, identity) + .await?; + } + } else if request.vend_credentials == Some(false) { + response.storage_options = None; + } + return Ok(response); } let table_name = Self::table_name_from_id(&request.id)?; @@ -1226,16 +1296,38 @@ impl LanceNamespace for DirectoryNamespace { location: snafu::location!(), })?; + // For backwards compatibility, only skip vending credentials when explicitly set to false + let vend_credentials = request.vend_credentials.unwrap_or(true); + let identity = request.identity.as_deref(); + let storage_options = if 
vend_credentials { + self.get_storage_options_for_table(&table_uri, identity) + .await? + } else { + None + }; + Ok(CreateEmptyTableResponse { - transaction_id: None, location: Some(table_uri), - storage_options: self.storage_options.clone(), + storage_options, + ..Default::default() }) } async fn declare_table(&self, request: DeclareTableRequest) -> Result { if let Some(ref manifest_ns) = self.manifest_ns { - return manifest_ns.declare_table(request).await; + let mut response = manifest_ns.declare_table(request.clone()).await?; + // Only apply identity-based credential vending when explicitly requested + if request.vend_credentials == Some(true) && self.credential_vendor.is_some() { + if let Some(ref location) = response.location { + let identity = request.identity.as_deref(); + response.storage_options = self + .get_storage_options_for_table(location, identity) + .await?; + } + } else if request.vend_credentials == Some(false) { + response.storage_options = None; + } + return Ok(response); } let table_name = Self::table_name_from_id(&request.id)?; @@ -1280,10 +1372,20 @@ impl LanceNamespace for DirectoryNamespace { location: snafu::location!(), })?; + // For backwards compatibility, only skip vending credentials when explicitly set to false + let vend_credentials = request.vend_credentials.unwrap_or(true); + let identity = request.identity.as_deref(); + let storage_options = if vend_credentials { + self.get_storage_options_for_table(&table_uri, identity) + .await? 
+ } else { + None + }; + Ok(DeclareTableResponse { - transaction_id: None, location: Some(table_uri), - storage_options: self.storage_options.clone(), + storage_options, + ..Default::default() }) } @@ -1361,8 +1463,7 @@ impl LanceNamespace for DirectoryNamespace { Ok(lance_namespace::models::DeregisterTableResponse { id: request.id, location: Some(table_uri), - properties: None, - transaction_id: None, + ..Default::default() }) } @@ -2161,8 +2262,7 @@ mod tests { // List child namespaces let list_req = ListNamespacesRequest { id: Some(vec![]), - page_token: None, - limit: None, + ..Default::default() }; let result = namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2194,8 +2294,7 @@ mod tests { // List children of parent let list_req = ListNamespacesRequest { id: Some(vec!["parent".to_string()]), - page_token: None, - limit: None, + ..Default::default() }; let result = namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2207,8 +2306,7 @@ mod tests { // List root should only show parent let list_req = ListNamespacesRequest { id: Some(vec![]), - page_token: None, - limit: None, + ..Default::default() }; let result = namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2239,8 +2337,7 @@ mod tests { // List tables in child namespace let list_req = ListTablesRequest { id: Some(vec!["test_ns".to_string()]), - page_token: None, - limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await; assert!(result.is_ok()); @@ -2287,8 +2384,7 @@ mod tests { // List tables let list_req = ListTablesRequest { id: Some(vec!["test_ns".to_string()]), - page_token: None, - limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await; assert!(result.is_ok()); @@ -2425,6 +2521,7 @@ mod tests { // Describe namespace and verify properties let describe_req = DescribeNamespaceRequest { id: Some(vec!["test_ns".to_string()]), + ..Default::default() }; let result = 
namespace.describe_namespace(describe_req).await; assert!(result.is_ok()); @@ -2503,6 +2600,7 @@ mod tests { id: Some(vec!["ns1".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await.unwrap(); assert_eq!(result.tables.len(), 1); @@ -2512,6 +2610,7 @@ mod tests { id: Some(vec!["ns2".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await.unwrap(); assert_eq!(result.tables.len(), 1); diff --git a/rust/lance-namespace-impls/src/dir/manifest.rs b/rust/lance-namespace-impls/src/dir/manifest.rs index 8dca60b83d8..bfcb9602b9a 100644 --- a/rust/lance-namespace-impls/src/dir/manifest.rs +++ b/rust/lance-namespace-impls/src/dir/manifest.rs @@ -1087,11 +1087,33 @@ impl LanceNamespace for ManifestNamespace { vec![] }; + let load_detailed_metadata = request.load_detailed_metadata.unwrap_or(false); + // For backwards compatibility, only skip vending credentials when explicitly set to false + let vend_credentials = request.vend_credentials.unwrap_or(true); + match table_info { Some(info) => { // Construct full URI from relative location let table_uri = Self::construct_full_uri(&self.root, &info.location)?; + let storage_options = if vend_credentials { + self.storage_options.clone() + } else { + None + }; + + // If not loading detailed metadata, return minimal response with just location + if !load_detailed_metadata { + return Ok(DescribeTableResponse { + table: Some(table_name), + namespace: Some(namespace_id), + location: Some(table_uri.clone()), + table_uri: Some(table_uri), + storage_options, + ..Default::default() + }); + } + // Try to open the dataset to get version and schema match Dataset::open(&table_uri).await { Ok(mut dataset) => { @@ -1112,8 +1134,8 @@ impl LanceNamespace for ManifestNamespace { location: Some(table_uri.clone()), table_uri: Some(table_uri), schema: Some(Box::new(json_schema)), - storage_options: 
self.storage_options.clone(), - stats: None, + storage_options, + ..Default::default() }) } Err(_) => { @@ -1121,12 +1143,10 @@ impl LanceNamespace for ManifestNamespace { Ok(DescribeTableResponse { table: Some(table_name), namespace: Some(namespace_id), - version: None, location: Some(table_uri.clone()), table_uri: Some(table_uri), - schema: None, - storage_options: self.storage_options.clone(), - stats: None, + storage_options, + ..Default::default() }) } } @@ -1250,10 +1270,10 @@ impl LanceNamespace for ManifestNamespace { .await?; Ok(CreateTableResponse { - transaction_id: None, version: Some(1), location: Some(table_uri), storage_options: self.storage_options.clone(), + ..Default::default() }) } @@ -1297,8 +1317,7 @@ impl LanceNamespace for ManifestNamespace { Ok(DropTableResponse { id: request.id.clone(), location: Some(table_uri), - properties: None, - transaction_id: None, + ..Default::default() }) } None => Err(Error::Namespace { @@ -1370,8 +1389,10 @@ impl LanceNamespace for ManifestNamespace { // Root namespace always exists if namespace_id.is_empty() { + #[allow(clippy::needless_update)] return Ok(DescribeNamespaceResponse { properties: Some(HashMap::new()), + ..Default::default() }); } @@ -1380,8 +1401,10 @@ impl LanceNamespace for ManifestNamespace { let namespace_info = self.query_manifest_for_namespace(&object_id).await?; match namespace_info { + #[allow(clippy::needless_update)] Some(info) => Ok(DescribeNamespaceResponse { properties: info.metadata, + ..Default::default() }), None => Err(Error::Namespace { source: format!("Namespace '{}' not found", object_id).into(), @@ -1440,8 +1463,8 @@ impl LanceNamespace for ManifestNamespace { .await?; Ok(CreateNamespaceResponse { - transaction_id: None, properties: request.properties, + ..Default::default() }) } @@ -1503,10 +1526,7 @@ impl LanceNamespace for ManifestNamespace { self.delete_from_manifest(&object_id).await?; - Ok(DropNamespaceResponse { - properties: None, - transaction_id: None, - }) + 
Ok(DropNamespaceResponse::default()) } async fn namespace_exists(&self, request: NamespaceExistsRequest) -> Result<()> { @@ -1622,10 +1642,18 @@ impl LanceNamespace for ManifestNamespace { table_uri ); + // For backwards compatibility, only skip vending credentials when explicitly set to false + let vend_credentials = request.vend_credentials.unwrap_or(true); + let storage_options = if vend_credentials { + self.storage_options.clone() + } else { + None + }; + Ok(CreateEmptyTableResponse { - transaction_id: None, location: Some(table_uri), - storage_options: self.storage_options.clone(), + storage_options, + ..Default::default() }) } @@ -1717,10 +1745,18 @@ impl LanceNamespace for ManifestNamespace { table_uri ); + // For backwards compatibility, only skip vending credentials when explicitly set to false + let vend_credentials = request.vend_credentials.unwrap_or(true); + let storage_options = if vend_credentials { + self.storage_options.clone() + } else { + None + }; + Ok(DeclareTableResponse { - transaction_id: None, location: Some(table_uri), - storage_options: self.storage_options.clone(), + storage_options, + ..Default::default() }) } @@ -1793,9 +1829,8 @@ impl LanceNamespace for ManifestNamespace { .await?; Ok(RegisterTableResponse { - transaction_id: None, location: Some(location), - properties: None, + ..Default::default() }) } @@ -1836,10 +1871,9 @@ impl LanceNamespace for ManifestNamespace { }; Ok(DeregisterTableResponse { - transaction_id: None, id: request.id.clone(), location: Some(table_uri), - properties: None, + ..Default::default() }) } } @@ -2267,6 +2301,7 @@ mod tests { // Verify namespace exists let exists_req = NamespaceExistsRequest { id: Some(vec!["ns1".to_string()]), + ..Default::default() }; let result = dir_namespace.namespace_exists(exists_req).await; assert!(result.is_ok(), "Namespace should exist"); @@ -2276,6 +2311,7 @@ mod tests { id: Some(vec![]), page_token: None, limit: None, + ..Default::default() }; let result = 
dir_namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2320,6 +2356,7 @@ mod tests { // Verify nested namespace exists let exists_req = NamespaceExistsRequest { id: Some(vec!["parent".to_string(), "child".to_string()]), + ..Default::default() }; let result = dir_namespace.namespace_exists(exists_req).await; assert!(result.is_ok(), "Nested namespace should exist"); @@ -2329,6 +2366,7 @@ mod tests { id: Some(vec!["parent".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = dir_namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2396,6 +2434,7 @@ mod tests { // Verify namespace no longer exists let exists_req = NamespaceExistsRequest { id: Some(vec!["ns1".to_string()]), + ..Default::default() }; let result = dir_namespace.namespace_exists(exists_req).await; assert!(result.is_err(), "Namespace should not exist after drop"); @@ -2474,6 +2513,7 @@ mod tests { id: Some(vec!["ns1".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = dir_namespace.list_tables(list_req).await; assert!(result.is_ok()); @@ -2510,6 +2550,7 @@ mod tests { // Describe the namespace let describe_req = DescribeNamespaceRequest { id: Some(vec!["ns1".to_string()]), + ..Default::default() }; let result = dir_namespace.describe_namespace(describe_req).await; assert!( diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs index f92d44cd305..020746487a4 100644 --- a/rust/lance-namespace-impls/src/rest.rs +++ b/rust/lance-namespace-impls/src/rest.rs @@ -472,6 +472,7 @@ impl LanceNamespace for RestNamespace { request.clone(), Some(&self.delimiter), request.with_table_uri, + request.load_detailed_metadata, ) .await .map_err(convert_api_error) @@ -1037,8 +1038,7 @@ mod tests { let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), - page_token: None, - limit: None, + ..Default::default() }; let result = namespace.list_namespaces(request).await; @@ 
-1160,8 +1160,8 @@ mod tests { let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), - page_token: None, limit: Some(10), + ..Default::default() }; let result = namespace.list_namespaces(request).await; @@ -1199,8 +1199,8 @@ mod tests { let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), - page_token: None, limit: Some(10), + ..Default::default() }; let result = namespace.list_namespaces(request).await; @@ -1235,8 +1235,7 @@ mod tests { let request = CreateNamespaceRequest { id: Some(vec!["test".to_string(), "newnamespace".to_string()]), - properties: None, - mode: None, + ..Default::default() }; let result = namespace.create_namespace(request).await; @@ -1277,6 +1276,7 @@ mod tests { "table".to_string(), ]), mode: Some("Create".to_string()), + ..Default::default() }; let data = Bytes::from("arrow data here"); @@ -1314,6 +1314,7 @@ mod tests { "table".to_string(), ]), mode: Some("Append".to_string()), + ..Default::default() }; let data = Bytes::from("arrow data here"); diff --git a/rust/lance-namespace-impls/src/rest_adapter.rs b/rust/lance-namespace-impls/src/rest_adapter.rs index f0d1c3ac60d..4a12b92838a 100644 --- a/rust/lance-namespace-impls/src/rest_adapter.rs +++ b/rust/lance-namespace-impls/src/rest_adapter.rs @@ -12,7 +12,7 @@ use std::sync::Arc; use axum::{ body::Bytes, extract::{Path, Query, Request, State}, - http::StatusCode, + http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, ServiceExt, @@ -312,11 +312,13 @@ fn error_to_response(err: Error) -> Response { async fn create_namespace( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.create_namespace(request).await { Ok(response) => (StatusCode::CREATED, Json(response)).into_response(), @@ -326,6 +328,7 @@ async fn 
create_namespace( async fn list_namespaces( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { @@ -333,6 +336,8 @@ async fn list_namespaces( id: Some(parse_id(&id, params.delimiter.as_deref())), page_token: params.page_token, limit: params.limit, + identity: extract_identity(&headers), + ..Default::default() }; match backend.list_namespaces(request).await { @@ -343,11 +348,13 @@ async fn list_namespaces( async fn describe_namespace( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.describe_namespace(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -357,11 +364,13 @@ async fn describe_namespace( async fn drop_namespace( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.drop_namespace(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -371,11 +380,13 @@ async fn drop_namespace( async fn namespace_exists( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.namespace_exists(request).await { Ok(_) => StatusCode::NO_CONTENT.into_response(), @@ -389,6 +400,7 @@ async fn namespace_exists( async fn list_tables( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { @@ -396,6 +408,8 @@ async fn list_tables( id: Some(parse_id(&id, params.delimiter.as_deref())), page_token: params.page_token, limit: params.limit, + 
identity: extract_identity(&headers), + ..Default::default() }; match backend.list_tables(request).await { @@ -406,11 +420,13 @@ async fn list_tables( async fn register_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.register_table(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -420,11 +436,13 @@ async fn register_table( async fn describe_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.describe_table(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -434,11 +452,13 @@ async fn describe_table( async fn table_exists( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.table_exists(request).await { Ok(_) => StatusCode::NO_CONTENT.into_response(), @@ -448,11 +468,14 @@ async fn table_exists( async fn drop_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { let request = DropTableRequest { id: Some(parse_id(&id, params.delimiter.as_deref())), + identity: extract_identity(&headers), + ..Default::default() }; match backend.drop_table(request).await { @@ -463,11 +486,13 @@ async fn drop_table( async fn deregister_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = 
extract_identity(&headers); match backend.deregister_table(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -487,6 +512,7 @@ struct CreateTableQuery { async fn create_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, body: Bytes, @@ -494,6 +520,8 @@ async fn create_table( let request = CreateTableRequest { id: Some(parse_id(&id, params.delimiter.as_deref())), mode: params.mode.clone(), + identity: extract_identity(&headers), + ..Default::default() }; match backend.create_table(request, body).await { @@ -505,11 +533,13 @@ async fn create_table( #[allow(deprecated)] async fn create_empty_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.create_empty_table(request).await { Ok(response) => (StatusCode::CREATED, Json(response)).into_response(), @@ -519,11 +549,13 @@ async fn create_empty_table( async fn declare_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.declare_table(request).await { Ok(response) => (StatusCode::CREATED, Json(response)).into_response(), @@ -539,6 +571,7 @@ struct InsertQuery { async fn insert_into_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, body: Bytes, @@ -546,6 +579,8 @@ async fn insert_into_table( let request = InsertIntoTableRequest { id: Some(parse_id(&id, params.delimiter.as_deref())), mode: params.mode.clone(), + identity: extract_identity(&headers), + ..Default::default() }; match backend.insert_into_table(request, body).await { @@ -569,6 +604,7 @@ struct MergeInsertQuery { async fn merge_insert_into_table( 
State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, body: Bytes, @@ -583,6 +619,8 @@ async fn merge_insert_into_table( when_not_matched_by_source_delete_filt: params.when_not_matched_by_source_delete_filt, timeout: params.timeout, use_index: params.use_index, + identity: extract_identity(&headers), + ..Default::default() }; match backend.merge_insert_into_table(request, body).await { @@ -593,11 +631,13 @@ async fn merge_insert_into_table( async fn update_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.update_table(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -607,11 +647,13 @@ async fn update_table( async fn delete_from_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.delete_from_table(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -621,11 +663,13 @@ async fn delete_from_table( async fn query_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.query_table(request).await { Ok(bytes) => (StatusCode::OK, bytes).into_response(), @@ -635,6 +679,7 @@ async fn query_table( async fn count_table_rows( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { @@ -642,6 +687,8 @@ async fn count_table_rows( id: Some(parse_id(&id, params.delimiter.as_deref())), version: None, predicate: None, + identity: 
extract_identity(&headers), + ..Default::default() }; match backend.count_table_rows(request).await { @@ -656,11 +703,13 @@ async fn count_table_rows( async fn rename_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.rename_table(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -670,11 +719,13 @@ async fn rename_table( async fn restore_table( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.restore_table(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -684,6 +735,7 @@ async fn restore_table( async fn list_table_versions( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { @@ -691,6 +743,8 @@ async fn list_table_versions( id: Some(parse_id(&id, params.delimiter.as_deref())), page_token: params.page_token, limit: params.limit, + identity: extract_identity(&headers), + ..Default::default() }; match backend.list_table_versions(request).await { @@ -701,11 +755,14 @@ async fn list_table_versions( async fn get_table_stats( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { let request = GetTableStatsRequest { id: Some(parse_id(&id, params.delimiter.as_deref())), + identity: extract_identity(&headers), + ..Default::default() }; match backend.get_table_stats(request).await { @@ -716,12 +773,15 @@ async fn get_table_stats( async fn list_all_tables( State(backend): State>, + headers: HeaderMap, Query(params): Query, ) -> Response { let request = ListTablesRequest { id: None, page_token: params.page_token, 
limit: params.limit, + identity: extract_identity(&headers), + ..Default::default() }; match backend.list_all_tables(request).await { @@ -736,11 +796,13 @@ async fn list_all_tables( async fn create_table_index( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.create_table_index(request).await { Ok(response) => (StatusCode::CREATED, Json(response)).into_response(), @@ -750,11 +812,13 @@ async fn create_table_index( async fn create_table_scalar_index( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.create_table_scalar_index(request).await { Ok(response) => (StatusCode::CREATED, Json(response)).into_response(), @@ -764,6 +828,7 @@ async fn create_table_scalar_index( async fn list_table_indices( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { @@ -772,6 +837,8 @@ async fn list_table_indices( version: None, page_token: None, limit: None, + identity: extract_identity(&headers), + ..Default::default() }; match backend.list_table_indices(request).await { @@ -788,6 +855,7 @@ struct IndexPathParams { async fn describe_table_index_stats( State(backend): State>, + headers: HeaderMap, Path(params): Path, Query(query): Query, ) -> Response { @@ -795,6 +863,8 @@ async fn describe_table_index_stats( id: Some(parse_id(¶ms.id, query.delimiter.as_deref())), version: None, index_name: Some(params.index_name), + identity: extract_identity(&headers), + ..Default::default() }; match backend.describe_table_index_stats(request).await { @@ -805,12 +875,15 @@ async fn describe_table_index_stats( async fn drop_table_index( State(backend): 
State>, + headers: HeaderMap, Path(params): Path, Query(query): Query, ) -> Response { let request = DropTableIndexRequest { id: Some(parse_id(¶ms.id, query.delimiter.as_deref())), index_name: Some(params.index_name), + identity: extract_identity(&headers), + ..Default::default() }; match backend.drop_table_index(request).await { @@ -825,11 +898,13 @@ async fn drop_table_index( async fn alter_table_add_columns( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.alter_table_add_columns(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -839,11 +914,13 @@ async fn alter_table_add_columns( async fn alter_table_alter_columns( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.alter_table_alter_columns(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -853,11 +930,13 @@ async fn alter_table_alter_columns( async fn alter_table_drop_columns( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.alter_table_drop_columns(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -867,11 +946,13 @@ async fn alter_table_drop_columns( async fn update_table_schema_metadata( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = 
extract_identity(&headers); match backend.update_table_schema_metadata(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -885,6 +966,7 @@ async fn update_table_schema_metadata( async fn list_table_tags( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, ) -> Response { @@ -892,6 +974,8 @@ async fn list_table_tags( id: Some(parse_id(&id, params.delimiter.as_deref())), page_token: params.page_token, limit: params.limit, + identity: extract_identity(&headers), + ..Default::default() }; match backend.list_table_tags(request).await { @@ -902,11 +986,13 @@ async fn list_table_tags( async fn get_table_tag_version( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.get_table_tag_version(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -916,11 +1002,13 @@ async fn get_table_tag_version( async fn create_table_tag( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.create_table_tag(request).await { Ok(response) => (StatusCode::CREATED, Json(response)).into_response(), @@ -930,11 +1018,13 @@ async fn create_table_tag( async fn delete_table_tag( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.delete_table_tag(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -944,11 +1034,13 @@ async fn delete_table_tag( async fn update_table_tag( State(backend): 
State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.update_table_tag(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -962,11 +1054,13 @@ async fn update_table_tag( async fn explain_table_query_plan( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.explain_table_query_plan(request).await { Ok(plan) => (StatusCode::OK, plan).into_response(), @@ -976,11 +1070,13 @@ async fn explain_table_query_plan( async fn analyze_table_query_plan( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(params): Query, Json(mut request): Json, ) -> Response { request.id = Some(parse_id(&id, params.delimiter.as_deref())); + request.identity = extract_identity(&headers); match backend.analyze_table_query_plan(request).await { Ok(plan) => (StatusCode::OK, plan).into_response(), @@ -994,6 +1090,7 @@ async fn analyze_table_query_plan( async fn describe_transaction( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(_params): Query, Json(mut request): Json, @@ -1007,6 +1104,7 @@ async fn describe_transaction( } else { request.id = Some(vec![id]); } + request.identity = extract_identity(&headers); match backend.describe_transaction(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -1016,6 +1114,7 @@ async fn describe_transaction( async fn alter_transaction( State(backend): State>, + headers: HeaderMap, Path(id): Path, Query(_params): Query, Json(mut request): Json, @@ -1027,6 +1126,7 @@ async fn alter_transaction( } else { request.id = Some(vec![id]); } + request.identity = extract_identity(&headers); match 
backend.alter_transaction(request).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -1054,6 +1154,36 @@ fn parse_id(id_str: &str, delimiter: Option<&str>) -> Vec { .collect() } +/// Extract identity information from HTTP headers +/// +/// Extracts `x-api-key` and `Authorization` (Bearer token) headers and returns +/// an Identity object if either is present. +fn extract_identity(headers: &HeaderMap) -> Option> { + let api_key = headers + .get("x-api-key") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + let auth_token = headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|s| { + // Extract token from "Bearer " format + s.strip_prefix("Bearer ") + .or_else(|| s.strip_prefix("bearer ")) + .map(|t| t.to_string()) + }); + + if api_key.is_some() || auth_token.is_some() { + Some(Box::new(Identity { + api_key, + auth_token, + })) + } else { + None + } +} + #[cfg(test)] mod tests { use super::*; @@ -1197,6 +1327,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1251,6 +1382,7 @@ mod tests { id: Some(vec![format!("namespace{}", i)]), properties: None, mode: None, + ..Default::default() }; let result = fixture.namespace.create_namespace(create_req).await; assert!(result.is_ok(), "Failed to create namespace{}", i); @@ -1261,6 +1393,7 @@ mod tests { id: Some(vec![]), page_token: None, limit: None, + ..Default::default() }; let result = fixture.namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -1280,6 +1413,7 @@ mod tests { id: Some(vec!["parent".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1292,6 +1426,7 @@ mod tests { id: Some(vec!["parent".to_string(), "child1".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1303,6 +1438,7 @@ mod tests { id: Some(vec!["parent".to_string(), "child2".to_string()]), 
properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1315,6 +1451,7 @@ mod tests { id: Some(vec!["parent".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = fixture.namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -1334,6 +1471,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1345,6 +1483,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; let result = fixture @@ -1385,6 +1524,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1397,6 +1537,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_namespace".to_string(), format!("table{}", i)]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -1410,6 +1551,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = fixture.namespace.list_tables(list_req).await; assert!(result.is_ok()); @@ -1430,6 +1572,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1441,6 +1584,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -1465,6 +1609,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1501,6 +1646,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1512,6 +1658,7 @@ 
mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -1589,6 +1736,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1600,6 +1748,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -1610,6 +1759,7 @@ mod tests { // Drop the table let drop_req = DropTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), + ..Default::default() }; let result = fixture.namespace.drop_table(drop_req).await; assert!( @@ -1637,6 +1787,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1692,6 +1843,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1748,6 +1900,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1767,6 +1920,7 @@ mod tests { // Drop the empty table let drop_req = DropTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), + ..Default::default() }; let result = fixture.namespace.drop_table(drop_req).await; assert!( @@ -1794,6 +1948,7 @@ mod tests { id: Some(vec!["level1".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1805,6 +1960,7 @@ mod tests { id: Some(vec!["level1".to_string(), "level2".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1820,6 +1976,7 @@ mod tests { ]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ 
-1868,6 +2025,7 @@ mod tests { id: Some(vec!["level1".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1879,6 +2037,7 @@ mod tests { id: Some(vec!["level1".to_string(), "level2".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1894,6 +2053,7 @@ mod tests { ]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1910,6 +2070,7 @@ mod tests { "deep_table".to_string(), ]), mode: Some("Create".to_string()), + ..Default::default() }; let result = fixture @@ -1947,6 +2108,7 @@ mod tests { id: Some(vec!["namespace1".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1958,6 +2120,7 @@ mod tests { id: Some(vec!["namespace2".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -1969,6 +2132,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["namespace1".to_string(), "shared_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -1979,6 +2143,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["namespace2".to_string(), "shared_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -1989,6 +2154,7 @@ mod tests { // Drop table in namespace1 let drop_req = DropTableRequest { id: Some(vec!["namespace1".to_string(), "shared_table".to_string()]), + ..Default::default() }; fixture.namespace.drop_table(drop_req).await.unwrap(); @@ -2018,6 +2184,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -2029,6 +2196,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -2061,6 +2229,7 @@ mod tests { id: 
Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -2080,6 +2249,7 @@ mod tests { // Verify namespace no longer exists let exists_req = NamespaceExistsRequest { id: Some(vec!["test_namespace".to_string()]), + ..Default::default() }; let result = fixture.namespace.namespace_exists(exists_req).await; assert!(result.is_err(), "Namespace should not exist after drop"); @@ -2100,6 +2270,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: Some(properties.clone()), mode: None, + ..Default::default() }; fixture .namespace @@ -2110,6 +2281,7 @@ mod tests { // Describe namespace and verify properties let describe_req = DescribeNamespaceRequest { id: Some(vec!["test_namespace".to_string()]), + ..Default::default() }; let result = fixture.namespace.describe_namespace(describe_req).await; assert!(result.is_ok()); @@ -2125,7 +2297,10 @@ mod tests { let fixture = RestServerFixture::new().await; // Root namespace should always exist - let exists_req = NamespaceExistsRequest { id: Some(vec![]) }; + let exists_req = NamespaceExistsRequest { + id: Some(vec![]), + ..Default::default() + }; let result = fixture.namespace.namespace_exists(exists_req).await; assert!(result.is_ok(), "Root namespace should exist"); @@ -2134,6 +2309,7 @@ mod tests { id: Some(vec![]), properties: None, mode: None, + ..Default::default() }; let result = fixture.namespace.create_namespace(create_req).await; assert!(result.is_err(), "Cannot create root namespace"); @@ -2167,6 +2343,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -2181,6 +2358,7 @@ mod tests { "physical_table".to_string(), ]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -2197,6 +2375,7 @@ mod tests { location: "test_namespace$physical_table.lance".to_string(), mode: None, properties: None, + ..Default::default() }; let result = 
fixture.namespace.register_table(register_req).await; @@ -2231,6 +2410,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -2244,6 +2424,7 @@ mod tests { location: "s3://bucket/table.lance".to_string(), mode: None, properties: None, + ..Default::default() }; let result = fixture.namespace.register_table(register_req).await; @@ -2265,6 +2446,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -2278,6 +2460,7 @@ mod tests { location: "../outside/table.lance".to_string(), mode: None, properties: None, + ..Default::default() }; let result = fixture.namespace.register_table(register_req).await; @@ -2300,6 +2483,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -2311,6 +2495,7 @@ mod tests { let create_table_req = CreateTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), mode: Some("Create".to_string()), + ..Default::default() }; fixture .namespace @@ -2330,6 +2515,7 @@ mod tests { // Deregister the table let deregister_req = DeregisterTableRequest { id: Some(vec!["test_namespace".to_string(), "test_table".to_string()]), + ..Default::default() }; let result = fixture.namespace.deregister_table(deregister_req).await; assert!( @@ -2375,6 +2561,7 @@ mod tests { id: Some(vec!["test_namespace".to_string()]), properties: None, mode: None, + ..Default::default() }; fixture .namespace @@ -2389,6 +2576,7 @@ mod tests { "original_table".to_string(), ]), mode: Some("Create".to_string()), + ..Default::default() }; let create_response = fixture .namespace @@ -2402,6 +2590,7 @@ mod tests { "test_namespace".to_string(), "original_table".to_string(), ]), + ..Default::default() }; fixture .namespace @@ -2431,6 +2620,7 @@ mod tests { location: relative_location.clone(), mode: None, 
properties: None, + ..Default::default() }; let register_response = fixture diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 0ff3cb6873a..7d1f4dc8395 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -824,7 +824,7 @@ impl Dataset { WriteMode::Create => { let declare_request = DeclareTableRequest { id: Some(table_id.clone()), - location: None, + ..Default::default() }; // Try declare_table first, fall back to deprecated create_empty_table // for backward compatibility with older namespace implementations. @@ -835,7 +835,7 @@ impl Dataset { Err(Error::NotSupported { .. }) => { let fallback_request = CreateEmptyTableRequest { id: Some(table_id.clone()), - location: None, + ..Default::default() }; let fallback_resp = namespace .create_empty_table(fallback_request) @@ -894,8 +894,7 @@ impl Dataset { WriteMode::Append | WriteMode::Overwrite => { let request = DescribeTableRequest { id: Some(table_id.clone()), - version: None, - with_table_uri: None, + ..Default::default() }; let response = namespace diff --git a/rust/lance/src/dataset/builder.rs b/rust/lance/src/dataset/builder.rs index 3d463ce6ca4..639fda28bca 100644 --- a/rust/lance/src/dataset/builder.rs +++ b/rust/lance/src/dataset/builder.rs @@ -136,8 +136,7 @@ impl DatasetBuilder { ) -> Result { let request = DescribeTableRequest { id: Some(table_id.clone()), - version: None, - with_table_uri: None, + ..Default::default() }; let response = namespace