Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add credential provider utility classes for AWS, GCP #19297

Merged
merged 16 commits into from
Oct 18, 2024
51 changes: 23 additions & 28 deletions crates/polars-io/src/cloud/credential_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ impl Debug for CredentialProviderFunction {
impl Eq for CredentialProviderFunction {}

impl PartialEq for CredentialProviderFunction {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
fn eq(&self, _other: &Self) -> bool {
false
}
}

Expand Down Expand Up @@ -379,8 +379,8 @@ impl<C: Clone> FetchedCredentialsCache<C> {
if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER {
if verbose {
eprintln!(
"[FetchedCredentialsCache]: Call update_func: current_time = {},\
last_fetched_expiry = {}",
"[FetchedCredentialsCache]: Call update_func: current_time = {}\
, last_fetched_expiry = {}",
current_time, *last_fetched_expiry
)
}
Expand All @@ -401,17 +401,24 @@ impl<C: Clone> FetchedCredentialsCache<C> {
}

if verbose {
eprintln!(
"[FetchedCredentialsCache]: Finish update_func: \
new expiry = {} (in {} seconds)",
*last_fetched_expiry,
last_fetched_expiry.saturating_sub(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
),
)
if *last_fetched_expiry == u64::MAX {
eprintln!(
"[FetchedCredentialsCache]: Finish update_func: \
new expiry = (never expires)"
)
} else {
eprintln!(
"[FetchedCredentialsCache]: Finish update_func: \
new expiry = {} (in {} seconds)",
*last_fetched_expiry,
last_fetched_expiry.saturating_sub(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
),
)
}
}
}

Expand All @@ -433,7 +440,7 @@ mod python_impl {

use super::IntoCredentialProvider;

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PythonCredentialProvider(pub(super) Arc<PythonFunction>);

Expand Down Expand Up @@ -625,20 +632,8 @@ mod python_impl {
}
}

impl Eq for PythonCredentialProvider {}

impl PartialEq for PythonCredentialProvider {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl Hash for PythonCredentialProvider {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// # Safety
// * Inner is an `Arc`
// * Visibility is limited to super
// * No code in `mod python_impl` or `super` mutates the Arc inner.
state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
}
}
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-io/src/cloud/object_store_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
use once_cell::sync::Lazy;
use polars_core::config;
use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult};
use polars_utils::aliases::PlHashMap;
use tokio::sync::RwLock;
Expand Down Expand Up @@ -58,6 +59,8 @@ pub async fn build_object_store(
let parsed = parse_url(url).map_err(to_compute_err)?;
let cloud_location = CloudLocation::from_url(&parsed, glob)?;

// FIXME: `credential_provider` is currently serializing the entire Python function here
// into a string with pickle for this cache key because we are using `serde_json::to_string`
let key = url_and_creds_to_key(&parsed, options);
let mut allow_cache = true;

Expand Down Expand Up @@ -124,6 +127,12 @@ pub async fn build_object_store(
let mut cache = OBJECT_STORE_CACHE.write().await;
// Clear the cache if we surpass a certain amount of buckets.
if cache.len() > 8 {
if config::verbose() {
eprintln!(
"build_object_store: clearing store cache (cache.len(): {})",
cache.len()
);
}
cache.clear()
}
cache.insert(key, store.clone());
Expand Down
11 changes: 11 additions & 0 deletions py-polars/docs/source/reference/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,14 @@ Connect to pyarrow datasets.
:toctree: api/

scan_pyarrow_dataset

Cloud Credentials
~~~~~~~~~~~~~~~~~
Configuration for cloud credential provisioning.
Copy link
Collaborator Author

@nameexhaustion nameexhaustion Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also added to the Python docs, but everything has been marked unstable


.. autosummary::
:toctree: api/

CredentialProvider
CredentialProviderAWS
CredentialProviderGCP
13 changes: 13 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@
scan_parquet,
scan_pyarrow_dataset,
)
from polars.io.cloud import (
CredentialProvider,
CredentialProviderAWS,
CredentialProviderFunction,
CredentialProviderFunctionReturn,
CredentialProviderGCP,
)
from polars.lazyframe import GPUEngine, LazyFrame
from polars.meta import (
build_info,
Expand Down Expand Up @@ -266,6 +273,12 @@
"scan_ndjson",
"scan_parquet",
"scan_pyarrow_dataset",
# polars.io.cloud
"CredentialProvider",
"CredentialProviderAWS",
"CredentialProviderFunction",
"CredentialProviderFunctionReturn",
"CredentialProviderGCP",
# polars.stringcache
"StringCache",
"disable_string_cache",
Expand Down
15 changes: 11 additions & 4 deletions py-polars/polars/_typing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from collections.abc import Collection, Iterable, Mapping, Sequence
from pathlib import Path
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Protocol,
TypedDict,
TypeVar,
Expand Down Expand Up @@ -297,6 +297,13 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
# LazyFrame engine selection
EngineType: TypeAlias = Union[Literal["cpu", "gpu"], "GPUEngine"]

CredentialProviderFunction: TypeAlias = Callable[
[], tuple[dict[str, Optional[str]], Optional[int]]
ScanSource: TypeAlias = Union[
str
| Path
| IO[bytes]
| bytes
| list[str]
| list[Path]
| list[IO[bytes]]
| list[bytes]
]
6 changes: 5 additions & 1 deletion py-polars/polars/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, overload

from polars._utils.various import is_int_sequence, is_str_sequence, normalize_filepath
from polars._utils.various import (
is_int_sequence,
is_str_sequence,
normalize_filepath,
)
from polars.dependencies import _FSSPEC_AVAILABLE, fsspec
from polars.exceptions import NoDataError

Expand Down
15 changes: 15 additions & 0 deletions py-polars/polars/io/cloud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from polars.io.cloud.credential_provider import (
CredentialProvider,
CredentialProviderAWS,
CredentialProviderFunction,
CredentialProviderFunctionReturn,
CredentialProviderGCP,
)

__all__ = [
"CredentialProvider",
"CredentialProviderAWS",
"CredentialProviderFunction",
"CredentialProviderFunctionReturn",
"CredentialProviderGCP",
]
55 changes: 55 additions & 0 deletions py-polars/polars/io/cloud/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Literal

from polars._utils.various import is_path_or_str_sequence

if TYPE_CHECKING:
from polars._typing import ScanSource


def _first_scan_path(
source: ScanSource,
) -> str | Path | None:
if isinstance(source, (str, Path)):
return source
elif is_path_or_str_sequence(source) and source:
return source[0]

return None


def _infer_cloud_type(
source: ScanSource,
) -> Literal["aws", "azure", "gcp", "file", "http", "hf"] | None:
if (path := _first_scan_path(source)) is None:
return None

splitted = str(path).split("://", maxsplit=1)

# Fast path - local file
if not splitted:
return None

scheme = splitted[0]

if scheme == "file":
return "file"

if any(scheme == x for x in ["s3", "s3a"]):
return "aws"

if any(scheme == x for x in ["az", "azure", "adl", "abfs", "abfss"]):
return "azure"

if any(scheme == x for x in ["gs", "gcp", "gcs"]):
return "gcp"

if any(scheme == x for x in ["http", "https"]):
return "http"

if scheme == "hf":
return "hf"

return None
Loading
Loading