diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py
index a5c7b94b29..f1357c2fa7 100644
--- a/daft/delta_lake/delta_lake_scan.py
+++ b/daft/delta_lake/delta_lake_scan.py
@@ -17,6 +17,7 @@
     ScanTask,
     StorageConfig,
 )
+from daft.io.aws_config import boto3_client_from_s3_config
 from daft.io.object_store_options import io_config_to_storage_options
 from daft.io.scan import PartitionField, ScanOperator
 from daft.logical.schema import Schema
@@ -43,6 +44,24 @@ def __init__(
         deltalake_sdk_io_config = storage_config.config.io_config
         scheme = urlparse(table_uri).scheme
         if scheme == "s3" or scheme == "s3a":
+            # Try to get region from boto3
+            if deltalake_sdk_io_config.s3.region_name is None:
+                from botocore.exceptions import BotoCoreError
+
+                try:
+                    client = boto3_client_from_s3_config("s3", deltalake_sdk_io_config.s3)
+                    response = client.get_bucket_location(Bucket=urlparse(table_uri).netloc)
+                except BotoCoreError as e:
+                    logger.warning(
+                        "Failed to get the S3 bucket region using existing storage config, will attempt to get it from the environment instead. Error from boto3: %s",
+                        e,
+                    )
+                else:
+                    deltalake_sdk_io_config = deltalake_sdk_io_config.replace(
+                        s3=deltalake_sdk_io_config.s3.replace(region_name=response["LocationConstraint"])
+                    )
+
+            # Try to get config from the environment
             if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]):
                 try:
                     s3_config_from_env = S3Config.from_env()
diff --git a/daft/io/aws_config.py b/daft/io/aws_config.py
new file mode 100644
index 0000000000..7f0e9e3dff
--- /dev/null
+++ b/daft/io/aws_config.py
@@ -0,0 +1,21 @@
+from typing import TYPE_CHECKING
+
+from daft.daft import S3Config
+
+if TYPE_CHECKING:
+    import boto3
+
+
+def boto3_client_from_s3_config(service: str, s3_config: S3Config) -> "boto3.client":
+    import boto3
+
+    return boto3.client(
+        service,
+        region_name=s3_config.region_name,
+        use_ssl=s3_config.use_ssl,
+        verify=s3_config.verify_ssl,
+        endpoint_url=s3_config.endpoint_url,
+        aws_access_key_id=s3_config.key_id,
+        aws_secret_access_key=s3_config.access_key,
+        aws_session_token=s3_config.session_token,
+    )
diff --git a/daft/io/catalog.py b/daft/io/catalog.py
index 1183caa8ab..62cb16e672 100644
--- a/daft/io/catalog.py
+++ b/daft/io/catalog.py
@@ -5,6 +5,7 @@
 from typing import Optional
 
 from daft.daft import IOConfig
+from daft.io.aws_config import boto3_client_from_s3_config
 
 
 class DataCatalogType(Enum):
@@ -42,20 +43,8 @@ def table_uri(self, io_config: IOConfig) -> str:
         """
         if self.catalog == DataCatalogType.GLUE:
             # Use boto3 to get the table from AWS Glue Data Catalog.
-            import boto3
+            glue = boto3_client_from_s3_config("glue", io_config.s3)
 
-            s3_config = io_config.s3
-
-            glue = boto3.client(
-                "glue",
-                region_name=s3_config.region_name,
-                use_ssl=s3_config.use_ssl,
-                verify=s3_config.verify_ssl,
-                endpoint_url=s3_config.endpoint_url,
-                aws_access_key_id=s3_config.key_id,
-                aws_secret_access_key=s3_config.access_key,
-                aws_session_token=s3_config.session_token,
-            )
             if self.catalog_id is not None:
                 # Allow cross account access, table.catalog_id should be the target account id
                 glue_table = glue.get_table(
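
Note (not part of the patch): a minimal sketch of how the new helper performs the region lookup that the scan operator now does, using only names from the diff above; the bucket name "my-bucket" and the placeholder credentials are assumptions for illustration.

    from daft.daft import S3Config
    from daft.io.aws_config import boto3_client_from_s3_config

    # An S3Config with credentials but no region_name set.
    s3_config = S3Config(key_id="...", access_key="...")

    # Build a boto3 S3 client from the existing Daft S3 config.
    client = boto3_client_from_s3_config("s3", s3_config)

    # The same lookup delta_lake_scan.py performs to fill in region_name.
    # Caveat: GetBucketLocation returns a null LocationConstraint for
    # buckets in us-east-1.
    region = client.get_bucket_location(Bucket="my-bucket")["LocationConstraint"]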