diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index 2281e184c22bd..6c4b76cdc47bd 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -210,13 +210,19 @@ def get_conn(self) -> BlobServiceClient: if sas_token.startswith("https"): return BlobServiceClient(account_url=sas_token, **extra) else: - return BlobServiceClient(account_url=f"{account_url}/{sas_token}", **extra) + if not account_url.startswith("https://"): + # TODO: require url in the host field in the next major version? + account_url = f"https://{conn.login}.blob.core.windows.net" + return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra) # Fall back to old auth (password) or use managed identity if not provided. credential = conn.password if not credential: credential = DefaultAzureCredential() self.log.info("Using DefaultAzureCredential as credential") + if not account_url.startswith("https://"): + # TODO: require url in the host field in the next major version? + account_url = f"https://{conn.login}.blob.core.windows.net/" return BlobServiceClient( account_url=account_url, credential=credential, diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 672d20d3a65eb..9e9ce9a0d06e8 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -200,6 +200,42 @@ def test_azure_directory_connection(self): assert isinstance(hook.get_conn(), BlobServiceClient) assert isinstance(hook.get_conn().credential, ClientSecretCredential) + @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") + @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential") + @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") + def test_active_directory_ID_used_as_host(self, mock_get_conn, mock_credential, mock_blob_service_client): + hook = WasbHook(wasb_conn_id="testconn") + mock_get_conn.return_value = Connection( + conn_id="testconn", + conn_type=self.connection_type, + login="testaccountname", + host="testaccountID", + ) + hook.get_conn() + assert mock_blob_service_client.call_args == mock.call( + account_url="https://testaccountname.blob.core.windows.net/", + credential=mock_credential.return_value, + ) + + @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") + @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") + def test_sas_token_provided_and_active_directory_ID_used_as_host( + self, mock_get_conn, mock_blob_service_client + ): + hook = WasbHook(wasb_conn_id="testconn") + mock_get_conn.return_value = Connection( + conn_id="testconn", + conn_type=self.connection_type, + login="testaccountname", + host="testaccountID", + extra=json.dumps({"sas_token": "SAStoken"}), + ) + hook.get_conn() + assert mock_blob_service_client.call_args == mock.call( + account_url="https://testaccountname.blob.core.windows.net/SAStoken", + sas_token="SAStoken", + ) + @pytest.mark.parametrize( argnames="conn_id_str", argvalues=[