Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not require AZURE_USERNAME for shared cache #8095

Merged
merged 4 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 43 additions & 15 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
NetworkTraceLoggingPolicy,
ProxyPolicy,
RetryPolicy,
DistributedTracingPolicy
DistributedTracingPolicy,
)
from azure.core.pipeline.transport import RequestsTransport, HttpRequest
from azure.identity._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities
from ._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities

try:
ABC = abc.ABC
Expand Down Expand Up @@ -98,6 +98,7 @@ def request_token(self, scopes, method, headers, form_data, params, **kwargs):

@abc.abstractmethod
def obtain_token_by_refresh_token(self, scopes, username):
# type: (Iterable[str], Optional[str]) -> AccessToken
pass

def _deserialize_and_cache_token(self, response, scopes, request_time):
Expand Down Expand Up @@ -214,22 +215,45 @@ def request_token(
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
return token

def obtain_token_by_refresh_token(self, scopes, username):
# type: (Iterable[str], str) -> Optional[AccessToken]
"""Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no
refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else."""
def obtain_token_by_refresh_token(self, scopes, username=None):
# type: (Iterable[str], Optional[str]) -> AccessToken
"""Acquire an access token using a cached refresh token. Raises ClientAuthenticationError if that fails.
This is only used by SharedTokenCacheCredential and isn't robust enough for anything else."""

# if an username is provided, restrict our search to accounts that have that username
query = {"username": username} if username else {}
accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query=query)

# if more than one account was returned, ensure that that they all have the same home_account_id. If so,
# we'll treat them as equal, otherwise we can't know which one to pick, so we'll raise an error.
if len(accounts) > 1 and len({account.get("home_account_id") for account in accounts}) != 1:
if username:
message = (
"Multiple entries found for user '{}' were found in the shared token cache. "
"This is not currently supported by SharedTokenCacheCredential."
).format(username)
else:
# TODO: we could identify usernames associated with exactly one home account id
message = (
"Multiple users were discovered in the shared token cache. If using DefaultAzureCredential, set "
"the AZURE_USERNAME environment variable to the preferred username. Otherwise, specify it when "
"constructing SharedTokenCacheCredential."
"\nDiscovered accounts: {}"
).format(", ".join({account.get("username") for account in accounts}))
raise ClientAuthenticationError(message=message)

# find account matching username
accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query={"username": username})
for account in accounts:
# try each refresh token that might work, return the first access token acquired
for token in self.get_refresh_tokens(scopes, account):
# currently we only support login.microsoftonline.com, which has an alias login.windows.net
# TODO: this must change to support sovereign clouds
environment = account.get("environment")
if not environment or (environment not in self._auth_url and environment != "login.windows.net"):
# ensure the account is associated with the token authority we expect to use
# ('environment' is an authority e.g. 'login.microsoftonline.com')
environment = account.get("environment")
if not environment or environment not in self._auth_url:
# doubtful this account can get the access token we want but public cloud's a special case
# because its authority has an alias: for our purposes login.windows.net = login.microsoftonline.com
if not (environment == "login.windows.net" and KnownAuthorities.AZURE_PUBLIC_CLOUD in self._auth_url):
continue

# try each refresh token, returning the first access token acquired
for token in self.get_refresh_tokens(scopes, account):
request = self.get_refresh_token_grant_request(token, scopes)
request_time = int(time.time())
response = self._pipeline.run(request, stream=False)
Expand All @@ -240,7 +264,11 @@ def obtain_token_by_refresh_token(self, scopes, username):
except ClientAuthenticationError:
continue

return None
message = "No cached token found"
if username:
message += " for '{}'".format(username)

raise ClientAuthenticationError(message=message)

@staticmethod
def _create_config(**kwargs):
Expand Down
17 changes: 5 additions & 12 deletions sdk/identity/azure-identity/azure/identity/_credentials/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class DefaultAzureCredential(ChainedTokenCredential):
1. A service principal configured by environment variables. See :class:`~azure.identity.EnvironmentCredential` for
more details.
2. An Azure managed identity. See :class:`~azure.identity.ManagedIdentityCredential` for more details.
3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. This requires a
value for the environment variable ``AZURE_USERNAME``. See :class:`~azure.identity.SharedTokenCacheCredential`
for more details.
3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. If multiple
identities are in the cache, then the value of the environment variable ``AZURE_USERNAME`` is used to select
which identity to use. See :class:`~azure.identity.SharedTokenCacheCredential` for more details.

Keyword arguments
- **authority** (str): Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
Expand All @@ -34,15 +34,8 @@ def __init__(self, **kwargs):
authority = kwargs.pop("authority", None)
credentials = [EnvironmentCredential(authority=authority, **kwargs), ManagedIdentityCredential(**kwargs)]

# SharedTokenCacheCredential is part of the default only on supported platforms, when $AZURE_USERNAME has a
# value (because the cache may contain tokens for multiple identities and we can only choose one arbitrarily
# without more information from the user), and when $AZURE_PASSWORD has no value (because when $AZURE_USERNAME
# and $AZURE_PASSWORD are set, EnvironmentCredential will be used instead)
if (
SharedTokenCacheCredential.supported()
and EnvironmentVariables.AZURE_USERNAME in os.environ
and EnvironmentVariables.AZURE_PASSWORD not in os.environ
):
# SharedTokenCacheCredential is part of the default only on supported platforms.
if SharedTokenCacheCredential.supported():
credentials.append(
SharedTokenCacheCredential(
username=os.environ.get(EnvironmentVariables.AZURE_USERNAME), authority=authority, **kwargs
Expand Down
10 changes: 3 additions & 7 deletions sdk/identity/azure-identity/azure/identity/_credentials/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ class SharedTokenCacheCredential(object):
defines authorities for other clouds.
"""

def __init__(self, username, **kwargs): # pylint:disable=unused-argument
# type: (str, **Any) -> None
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], **Any) -> None

self._username = username

Expand Down Expand Up @@ -161,11 +161,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not self._client:
raise ClientAuthenticationError(message="Shared token cache unavailable")

token = self._client.obtain_token_by_refresh_token(scopes, self._username)
if not token:
raise ClientAuthenticationError(message="No cached token found for '{}'".format(self._username))

return token
return self._client.obtain_token_by_refresh_token(scopes, self._username)

@staticmethod
def supported():
Expand Down
54 changes: 42 additions & 12 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from azure.core.pipeline.transport import AioHttpTransport

from .._authn_client import AuthnClientBase
from .._constants import KnownAuthorities

if TYPE_CHECKING:
from typing import Any, Dict, Iterable, Mapping, Optional
Expand Down Expand Up @@ -67,21 +68,46 @@ async def request_token(
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
return token

async def obtain_token_by_refresh_token(self, scopes: "Iterable[str]", username: str) -> "Optional[AccessToken]":
"""Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no
refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else."""
async def obtain_token_by_refresh_token(
self, scopes: "Iterable[str]", username: "Optional[str]" = None
) -> "AccessToken":
"""Acquire an access token using a cached refresh token. Raises ClientAuthenticationError if that fails.
This is only used by SharedTokenCacheCredential and isn't robust enough for anything else."""

# if an username is provided, restrict our search to accounts that have that username
query = {"username": username} if username else {}
accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query=query)

# if more than one account was returned, ensure that that they all have the same home_account_id. If so,
# we'll treat them as equal, otherwise we can't know which one to pick, so we'll raise an error.
if len(accounts) > 1 and len({account.get("home_account_id") for account in accounts}) != 1:
if username:
message = (
"Multiple entries found for user '{}' were found in the shared token cache. "
"This is not currently supported by SharedTokenCacheCredential"
).format(username)
else:
# TODO: we could identify usernames associated with exactly one home account id
ellismg marked this conversation as resolved.
Show resolved Hide resolved
message = (
"Multiple users were discovered in the shared token cache. If using DefaultAzureCredential, set "
"the AZURE_USERNAME environment variable to the preferred username. Otherwise, specify it when "
"constructing SharedTokenCacheCredential."
"\nDiscovered accounts: {}"
).format(", ".join({account.get("username") for account in accounts}))
raise ClientAuthenticationError(message=message)

# find account matching username
accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query={"username": username})
for account in accounts:
# try each refresh token that might work, return the first access token acquired
for token in self.get_refresh_tokens(scopes, account):
# currently we only support login.microsoftonline.com, which has an alias login.windows.net
# TODO: this must change to support sovereign clouds
environment = account.get("environment")
if not environment or (environment not in self._auth_url and environment != "login.windows.net"):
# ensure the account is associated with the token authority we expect to use
# ('environment' is an authority e.g. 'login.microsoftonline.com')
environment = account.get("environment")
if not environment or environment not in self._auth_url:
# doubtful this account can get the access token we want but public cloud's a special case
# because its authority has an alias: for our purposes login.windows.net = login.microsoftonline.com
if not (environment == "login.windows.net" and KnownAuthorities.AZURE_PUBLIC_CLOUD in self._auth_url):
continue

# try each refresh token, returning the first access token acquired
for token in self.get_refresh_tokens(scopes, account):
request = self.get_refresh_token_grant_request(token, scopes)
request_time = int(time.time())
response = await self._pipeline.run(request, stream=False)
Expand All @@ -92,7 +118,11 @@ async def obtain_token_by_refresh_token(self, scopes: "Iterable[str]", username:
except ClientAuthenticationError:
continue

return None
message = "No cached token found"
if username:
message += " for '{}'".format(username)

raise ClientAuthenticationError(message=message)

@staticmethod
def _create_config(**kwargs: "Any") -> Configuration:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class DefaultAzureCredential(ChainedTokenCredential):
1. A service principal configured by environment variables. See :class:`~azure.identity.aio.EnvironmentCredential`
for more details.
2. An Azure managed identity. See :class:`~azure.identity.aio.ManagedIdentityCredential` for more details.
3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. This requires a
value for the environment variable ``AZURE_USERNAME``. See
:class:`~azure.identity.aio.SharedTokenCacheCredential` for more details.
3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. If multiple
identities are in the cache, then the value of the environment variable ``AZURE_USERNAME`` is used to select
which identity to use. See :class:`~azure.identity.aio.SharedTokenCacheCredential` for more details.

Keyword arguments
- **authority** (str): Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
Expand All @@ -34,15 +34,8 @@ def __init__(self, **kwargs):
authority = kwargs.pop("authority", None)
credentials = [EnvironmentCredential(authority=authority, **kwargs), ManagedIdentityCredential(**kwargs)]

# SharedTokenCacheCredential is part of the default only on supported platforms, when $AZURE_USERNAME has a
# value (because the cache may contain tokens for multiple identities and we can only choose one arbitrarily
# without more information from the user), and when $AZURE_PASSWORD has no value (because when $AZURE_USERNAME
# and $AZURE_PASSWORD are set, EnvironmentCredential will be used instead)
if (
SharedTokenCacheCredential.supported()
and EnvironmentVariables.AZURE_USERNAME in os.environ
and EnvironmentVariables.AZURE_PASSWORD not in os.environ
):
# SharedTokenCacheCredential is part of the default only on supported platforms.
if SharedTokenCacheCredential.supported():
credentials.append(
SharedTokenCacheCredential(
username=os.environ.get(EnvironmentVariables.AZURE_USERNAME), authority=authority, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
if not self._client:
raise ClientAuthenticationError(message="Shared token cache unavailable")

token = await self._client.obtain_token_by_refresh_token(scopes, self._username)
if not token:
raise ClientAuthenticationError(message="No cached token found for '{}'".format(self._username))

return token
return await self._client.obtain_token_by_refresh_token(scopes, self._username)

@staticmethod
def _get_auth_client(cache: "msal_extensions.FileTokenCache") -> "AuthnClientBase":
Expand Down
27 changes: 2 additions & 25 deletions sdk/identity/azure-identity/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,37 +247,14 @@ def test_default_credential_shared_cache_use(mock_credential):
assert mock_credential.supported.call_count == 1
mock_credential.supported.reset_mock()

# unsupported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential shouldn't use shared cache
credential = DefaultAzureCredential()
assert mock_credential.call_count == 0
assert mock_credential.supported.call_count == 1

mock_credential.supported = Mock(return_value=True)

# supported platform, $AZURE_USERNAME not set -> default credential shouldn't use shared cache
# supported platform -> default credential should use shared cache
credential = DefaultAzureCredential()
assert mock_credential.call_count == 0
assert mock_credential.call_count == 1
assert mock_credential.supported.call_count == 1
mock_credential.supported.reset_mock()

# supported platform, $AZURE_USERNAME and $AZURE_PASSWORD set -> default credential shouldn't use shared cache
# (EnvironmentCredential should be used when both variables are set)
with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com", "AZURE_PASSWORD": "***"}):
credential = DefaultAzureCredential()
assert mock_credential.call_count == 0

# supported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential should use shared cache
with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com"}):
expected_token = AccessToken("***", 42)
mock_credential.return_value = Mock(get_token=lambda *_: expected_token)

credential = DefaultAzureCredential()
assert mock_credential.call_count == 1

token = credential.get_token("scope")
assert token == expected_token


def test_device_code_credential():
expected_token = "access-token"
user_code = "user-code"
Expand Down
26 changes: 2 additions & 24 deletions sdk/identity/azure-identity/tests/test_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,32 +296,10 @@ async def test_default_credential_shared_cache_use():
assert mock_credential.supported.call_count == 1
mock_credential.supported.reset_mock()

# unsupported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential shouldn't use shared cache
credential = DefaultAzureCredential()
assert mock_credential.call_count == 0
assert mock_credential.supported.call_count == 1

mock_credential.supported = Mock(return_value=True)

# supported platform, $AZURE_USERNAME not set -> default credential shouldn't use shared cache
# supported platform -> default credential should use shared cache
credential = DefaultAzureCredential()
assert mock_credential.call_count == 0
assert mock_credential.call_count == 1
assert mock_credential.supported.call_count == 1
mock_credential.supported.reset_mock()

# supported platform, $AZURE_USERNAME and $AZURE_PASSWORD set -> default credential shouldn't use shared cache
# (EnvironmentCredential should be used when both variables are set)
with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com", "AZURE_PASSWORD": "***"}):
credential = DefaultAzureCredential()
assert mock_credential.call_count == 0

# supported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential should use shared cache
with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com"}):
expected_token = AccessToken("***", 42)
mock_credential.return_value = Mock(get_token=asyncio.coroutine(lambda *_: expected_token))

credential = DefaultAzureCredential()
assert mock_credential.call_count == 1

token = await credential.get_token("scope")
assert token == expected_token