diff --git a/providers/microsoft/azure/docs/connections/adls_v2.rst b/providers/microsoft/azure/docs/connections/adls_v2.rst index ec60952ca0298..28156c7477337 100644 --- a/providers/microsoft/azure/docs/connections/adls_v2.rst +++ b/providers/microsoft/azure/docs/connections/adls_v2.rst @@ -67,6 +67,8 @@ Extra (optional) * ``tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication. It can be left out to fall back on DefaultAzureCredential_. * ``connection_string``: Connection string for use with connection string authentication. It can be left out to fall back on DefaultAzureCredential_. + * ``account_host``: Override for the default Azure Blob endpoint domain. Use this to specify a custom or private domain (e.g., ``myaccount.blob.core.customdomain.io``) instead of the default ``core.windows.net``. + When specifying the connection in environment variable you should specify it using URI syntax. diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py index 299d9f775df76..c42304d696a4c 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py @@ -73,8 +73,11 @@ def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) - if tenant_id is None and password: options["account_key"] = password - # now take any fields from extras and overlay on these - # add empty field to remove defaults + # Now take any fields from extras and overlay them on top of existing options. + # Add empty field to remove defaults. + # 'account_host' is included to allow overriding the default Azure Blob endpoint domain + # (e.g., to use a private endpoint or custom domain instead of core.windows.net). 
+ fields = [ "account_name", "account_key", @@ -84,6 +87,7 @@ def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) - "workload_identity_client_id", "workload_identity_tenant_id", "anon", + "account_host", ] for field in fields: value = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=field) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/fs/test_adls.py b/providers/microsoft/azure/tests/unit/microsoft/azure/fs/test_adls.py index 24d6b862c4166..5bfa893a678dc 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/fs/test_adls.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/fs/test_adls.py @@ -124,6 +124,23 @@ def mocked_blob_file_system(): "account_key": "p", }, ), + ( + Connection( + conn_id="testconn", + conn_type="wasb", + login="testaccountname", + password="p", + host="testaccountID", + extra={ + "account_host": "mystorageaccount.blob.core.mydomain.io", + }, + ), + { + "account_url": "https://testaccountname.blob.core.windows.net/", + "account_host": "mystorageaccount.blob.core.mydomain.io", + "account_key": "p", + }, + ), ], indirect=["mocked_connection"], )