-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Configurable authority for DefaultAzureCredential (#8154)
- Loading branch information
Showing
4 changed files
with
153 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# ------------------------------------ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# ------------------------------------ | ||
from azure.core.credentials import AccessToken | ||
from azure.identity import DefaultAzureCredential, KnownAuthorities, SharedTokenCacheCredential | ||
from azure.identity._constants import EnvironmentVariables | ||
from six.moves.urllib_parse import urlparse | ||
|
||
from helpers import mock_response | ||
|
||
try: | ||
from unittest.mock import Mock, patch | ||
except ImportError: # python < 3.3 | ||
from mock import Mock, patch # type: ignore | ||
|
||
|
||
def test_default_credential_authority(): | ||
# TODO need a mock cache to test SharedTokenCacheCredential | ||
tenant_id = "expected_tenant" | ||
expected_access_token = "***" | ||
response = mock_response( | ||
json_payload={ | ||
"access_token": expected_access_token, | ||
"expires_in": 0, | ||
"expires_on": 42, | ||
"not_before": 0, | ||
"resource": "scope", | ||
"token_type": "Bearer", | ||
} | ||
) | ||
|
||
def exercise_credentials(authority_kwarg, expected_authority=None): | ||
expected_authority = expected_authority or authority_kwarg | ||
def send(request, **_): | ||
scheme, netloc, path, _, _, _ = urlparse(request.url) | ||
assert scheme == "https" | ||
assert netloc == expected_authority | ||
assert path.startswith("/" + tenant_id) | ||
return response | ||
|
||
# environment credential configured with client secret should respect authority | ||
environment = { | ||
EnvironmentVariables.AZURE_CLIENT_ID: "client_id", | ||
EnvironmentVariables.AZURE_CLIENT_SECRET: "secret", | ||
EnvironmentVariables.AZURE_TENANT_ID: tenant_id, | ||
} | ||
with patch("os.environ", environment): | ||
transport=Mock(send=send) | ||
access_token, _ = DefaultAzureCredential(authority=authority_kwarg, transport=transport).get_token("scope") | ||
assert access_token == expected_access_token | ||
|
||
# managed identity credential should ignore authority | ||
with patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "https://some.url"}): | ||
transport = Mock(send=lambda *_, **__: response) | ||
access_token, _ = DefaultAzureCredential(authority=authority_kwarg, transport=transport).get_token("scope") | ||
assert access_token == expected_access_token | ||
|
||
# all credentials not representing managed identities should use a specified authority or default to public cloud | ||
exercise_credentials("authority.com") | ||
exercise_credentials(None, KnownAuthorities.AZURE_PUBLIC_CLOUD) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# ------------------------------------ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# ------------------------------------ | ||
import asyncio | ||
from unittest.mock import Mock, patch | ||
from urllib.parse import urlparse | ||
|
||
from azure.core.credentials import AccessToken | ||
from azure.identity import KnownAuthorities | ||
from azure.identity.aio import DefaultAzureCredential, SharedTokenCacheCredential | ||
from azure.identity._constants import EnvironmentVariables | ||
import pytest | ||
|
||
from helpers import mock_response | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_default_credential_authority(): | ||
# TODO need a mock cache to test SharedTokenCacheCredential | ||
you_shall_not_pass = "sentinel" | ||
authority = "authority.com" | ||
tenant_id = "expected_tenant" | ||
expected_access_token = "***" | ||
response = mock_response( | ||
json_payload={ | ||
"access_token": expected_access_token, | ||
"expires_in": 0, | ||
"expires_on": 42, | ||
"not_before": 0, | ||
"resource": "scope", | ||
"token_type": "Bearer", | ||
} | ||
) | ||
|
||
async def exercise_credentials(authority_kwarg, expected_authority=None): | ||
expected_authority = expected_authority or authority_kwarg | ||
async def send(request, **_): | ||
scheme, netloc, path, _, _, _ = urlparse(request.url) | ||
assert scheme == "https" | ||
assert netloc == expected_authority | ||
assert path.startswith("/" + tenant_id) | ||
return response | ||
|
||
# environment credential configured with client secret should respect authority | ||
environment = { | ||
EnvironmentVariables.AZURE_CLIENT_ID: "client_id", | ||
EnvironmentVariables.AZURE_CLIENT_SECRET: "secret", | ||
EnvironmentVariables.AZURE_TENANT_ID: tenant_id, | ||
} | ||
with patch("os.environ", environment): | ||
transport = Mock(send=send) | ||
if authority_kwarg: | ||
credential = DefaultAzureCredential(authority=authority_kwarg, transport=transport) | ||
else: | ||
credential = DefaultAzureCredential(transport=transport) | ||
access_token, _ = await credential.get_token("scope") | ||
assert access_token == expected_access_token | ||
|
||
# managed identity credential should ignore authority | ||
with patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "https://some.url"}): | ||
transport = Mock(send=asyncio.coroutine(lambda *_, **__: response)) | ||
if authority_kwarg: | ||
credential = DefaultAzureCredential(authority=authority_kwarg, transport=transport) | ||
else: | ||
credential = DefaultAzureCredential(transport=transport) | ||
access_token, _ = await credential.get_token("scope") | ||
assert access_token == expected_access_token | ||
|
||
# all credentials not representing managed identities should use a specified authority or default to public cloud | ||
await exercise_credentials("authority.com") | ||
await exercise_credentials(None, KnownAuthorities.AZURE_PUBLIC_CLOUD) |