diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py b/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py index baf64e6d5102..78e3e51e4cb5 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py @@ -14,7 +14,7 @@ from .azure_cli import AzureCliCredential from .device_code import DeviceCodeCredential from .user_password import UsernamePasswordCredential -from .vscode_credential import VSCodeCredential +from .vscode import VSCodeCredential __all__ = [ diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index a6354cd59da5..a2cadd674ca2 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -13,7 +13,7 @@ from .managed_identity import ManagedIdentityCredential from .shared_cache import SharedTokenCacheCredential from .azure_cli import AzureCliCredential -from .vscode_credential import VSCodeCredential +from .vscode import VSCodeCredential try: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py similarity index 100% rename from sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py rename to sdk/identity/azure-identity/azure/identity/_credentials/vscode.py diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index 60d776175e4f..236972381529 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -37,7 +37,7 @@ def get_default_authority(): from .certificate_credential_base import CertificateCredentialBase from .client_secret_credential_base import ClientSecretCredentialBase from .decorators import wrap_exceptions -from .msal_credentials import InteractiveCredential +from .interactive import InteractiveCredential def _scopes_to_resource(*scopes): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py new file mode 100644 index 000000000000..4e226bc0c357 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -0,0 +1,197 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Base class for credentials using MSAL for interactive user authentication""" + +import abc +import base64 +import json +import logging +import time +from typing import TYPE_CHECKING + +import msal +from six.moves.urllib_parse import urlparse +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError + +from .msal_credentials import MsalCredential +from .._auth_record import AuthenticationRecord +from .._constants import KnownAuthorities +from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError +from .._internal import wrap_exceptions + +if TYPE_CHECKING: + # pylint:disable=ungrouped-imports,unused-import + from typing import Any, Optional + +_LOGGER = logging.getLogger(__name__) + +_DEFAULT_AUTHENTICATE_SCOPES = { + "https://" + KnownAuthorities.AZURE_CHINA: ("https://management.core.chinacloudapi.cn//.default",), + "https://" + KnownAuthorities.AZURE_GERMANY: ("https://management.core.cloudapi.de//.default",), + "https://" + KnownAuthorities.AZURE_GOVERNMENT: ("https://management.core.usgovcloudapi.net//.default",), + "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD: ("https://management.core.windows.net//.default",), +} + + +def _decode_client_info(raw): + """Taken from msal.oauth2cli.oidc""" + + raw += "=" * (-len(raw) % 4) + raw = str(raw) # On Python 2.7, argument of urlsafe_b64decode must be str, not unicode. + return base64.urlsafe_b64decode(raw).decode("utf-8") + + +def _build_auth_record(response): + """Build an AuthenticationRecord from the result of an MSAL ClientApplication token request""" + + try: + id_token = response["id_token_claims"] + + if "client_info" in response: + client_info = json.loads(_decode_client_info(response["client_info"])) + home_account_id = "{uid}.{utid}".format(**client_info) + else: + # MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info + home_account_id = id_token["sub"] + + return AuthenticationRecord( + authority=urlparse(id_token["iss"]).netloc, # "iss" is the URL of the issuing tenant + client_id=id_token["aud"], + home_account_id=home_account_id, + tenant_id=id_token["tid"], # tenant which issued the token, not necessarily user's home tenant + username=id_token["preferred_username"], + ) + except (KeyError, ValueError): + # surprising: msal.ClientApplication always requests an id token, whose shape shouldn't change + return None + + +class InteractiveCredential(MsalCredential): + def __init__(self, **kwargs): + self._disable_automatic_authentication = kwargs.pop("disable_automatic_authentication", False) + self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] + if self._auth_record: + kwargs.pop("client_id", None) # authentication_record overrides client_id argument + tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id + super(InteractiveCredential, self).__init__( + client_id=self._auth_record.client_id, + authority=self._auth_record.authority, + tenant_id=tenant_id, + **kwargs + ) + else: + super(InteractiveCredential, self).__init__(**kwargs) + + def get_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + """Request an access token for `scopes`. + + .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is + configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. + """ + if not scopes: + message = "'get_token' requires at least one scope" + _LOGGER.warning("%s.get_token failed: %s", self.__class__.__name__, message) + raise ValueError(message) + + allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication) + try: + token = self._acquire_token_silent(*scopes, **kwargs) + _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) + return token + except Exception as ex: # pylint:disable=broad-except + if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt): + _LOGGER.warning( + "%s.get_token failed: %s", + self.__class__.__name__, + ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + raise + + # silent authentication failed -> authenticate interactively + now = int(time.time()) + + try: + result = self._request_token(*scopes, **kwargs) + if "access_token" not in result: + message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) + raise ClientAuthenticationError(message=message) + + # this may be the first authentication, or the user may have authenticated a different identity + self._auth_record = _build_auth_record(result) + except Exception as ex: # pylint:disable=broad-except + _LOGGER.warning( + "%s.get_token failed: %s", self.__class__.__name__, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + raise + + _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) + return AccessToken(result["access_token"], now + int(result["expires_in"])) + + def authenticate(self, **kwargs): + # type: (**Any) -> AuthenticationRecord + """Interactively authenticate a user. + + :keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by + :func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token + for these scopes. + :rtype: ~azure.identity.AuthenticationRecord + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + + scopes = kwargs.pop("scopes", None) + if not scopes: + if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES: + # the credential is configured to use a cloud whose ARM scope we can't determine + raise CredentialUnavailableError( + message="Authenticating in this environment requires a value for the 'scopes' keyword argument." + ) + + scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority] + + _ = self.get_token(*scopes, _allow_prompt=True, **kwargs) + return self._auth_record # type: ignore + + @wrap_exceptions + def _acquire_token_silent(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + result = None + if self._auth_record: + app = self._get_app() + for account in app.get_accounts(username=self._auth_record.username): + if account.get("home_account_id") != self._auth_record.home_account_id: + continue + + now = int(time.time()) + result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs) + if result and "access_token" in result and "expires_in" in result: + return AccessToken(result["access_token"], now + int(result["expires_in"])) + + # if we get this far, result is either None or the content of an AAD error response + if result: + details = result.get("error_description") or result.get("error") + raise AuthenticationRequiredError(scopes, error_details=details) + raise AuthenticationRequiredError(scopes) + + def _get_app(self): + # type: () -> msal.PublicClientApplication + if not self._msal_app: + self._msal_app = self._create_app(msal.PublicClientApplication) + return self._msal_app + + @abc.abstractmethod + def _request_token(self, *scopes, **kwargs): + pass diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index ac2966c93e1d..fd5034acd4bb 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -3,22 +3,13 @@ # Licensed under the MIT License. # ------------------------------------ import abc -import base64 -import json -import logging -import time import msal -from six.moves.urllib_parse import urlparse from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError from .msal_client import MsalClient from .persistent_cache import load_user_cache -from .._constants import KnownAuthorities -from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError -from .._internal import get_default_authority, normalize_authority, wrap_exceptions -from .._auth_record import AuthenticationRecord +from .._internal import get_default_authority, normalize_authority try: ABC = abc.ABC @@ -34,48 +25,6 @@ # pylint:disable=ungrouped-imports,unused-import from typing import Any, Mapping, Optional, Type, Union -_LOGGER = logging.getLogger(__name__) - -_DEFAULT_AUTHENTICATE_SCOPES = { - "https://" + KnownAuthorities.AZURE_CHINA: ("https://management.core.chinacloudapi.cn//.default",), - "https://" + KnownAuthorities.AZURE_GERMANY: ("https://management.core.cloudapi.de//.default",), - "https://" + KnownAuthorities.AZURE_GOVERNMENT: ("https://management.core.usgovcloudapi.net//.default",), - "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD: ("https://management.core.windows.net//.default",), -} - - -def _decode_client_info(raw): - """Taken from msal.oauth2cli.oidc""" - - raw += "=" * (-len(raw) % 4) - raw = str(raw) # On Python 2.7, argument of urlsafe_b64decode must be str, not unicode. - return base64.urlsafe_b64decode(raw).decode("utf-8") - - -def _build_auth_record(response): - """Build an AuthenticationRecord from the result of an MSAL ClientApplication token request""" - - try: - id_token = response["id_token_claims"] - - if "client_info" in response: - client_info = json.loads(_decode_client_info(response["client_info"])) - home_account_id = "{uid}.{utid}".format(**client_info) - else: - # MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info - home_account_id = id_token["sub"] - - return AuthenticationRecord( - authority=urlparse(id_token["iss"]).netloc, # "iss" is the URL of the issuing tenant - client_id=id_token["aud"], - home_account_id=home_account_id, - tenant_id=id_token["tid"], # tenant which issued the token, not necessarily user's home tenant - username=id_token["preferred_username"], - ) - except (KeyError, ValueError): - # surprising: msal.ClientApplication always requests an id token, whose shape shouldn't change - return None - class MsalCredential(ABC): """Base class for credentials wrapping MSAL applications""" @@ -123,132 +72,3 @@ def _create_app(self, cls): ) return app - - -class InteractiveCredential(MsalCredential): - def __init__(self, **kwargs): - self._disable_automatic_authentication = kwargs.pop("disable_automatic_authentication", False) - self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] - if self._auth_record: - kwargs.pop("client_id", None) # authentication_record overrides client_id argument - tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id - super(InteractiveCredential, self).__init__( - client_id=self._auth_record.client_id, - authority=self._auth_record.authority, - tenant_id=tenant_id, - **kwargs - ) - else: - super(InteractiveCredential, self).__init__(**kwargs) - - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - """Request an access token for `scopes`. - - .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code. - - :param str scopes: desired scopes for the access token. This method requires at least one scope. - :rtype: :class:`azure.core.credentials.AccessToken` - :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks - required data, state, or platform support - :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. - :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is - configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. - """ - if not scopes: - message = "'get_token' requires at least one scope" - _LOGGER.warning("%s.get_token failed: %s", self.__class__.__name__, message) - raise ValueError(message) - - allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication) - try: - token = self._acquire_token_silent(*scopes, **kwargs) - _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) - return token - except Exception as ex: # pylint:disable=broad-except - if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt): - _LOGGER.warning( - "%s.get_token failed: %s", - self.__class__.__name__, - ex, - exc_info=_LOGGER.isEnabledFor(logging.DEBUG), - ) - raise - - # silent authentication failed -> authenticate interactively - now = int(time.time()) - - try: - result = self._request_token(*scopes, **kwargs) - if "access_token" not in result: - message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) - raise ClientAuthenticationError(message=message) - - # this may be the first authentication, or the user may have authenticated a different identity - self._auth_record = _build_auth_record(result) - except Exception as ex: # pylint:disable=broad-except - _LOGGER.warning( - "%s.get_token failed: %s", self.__class__.__name__, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), - ) - raise - - _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) - return AccessToken(result["access_token"], now + int(result["expires_in"])) - - def authenticate(self, **kwargs): - # type: (**Any) -> AuthenticationRecord - """Interactively authenticate a user. - - :keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by - :func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token - for these scopes. - :rtype: ~azure.identity.AuthenticationRecord - :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. - """ - - scopes = kwargs.pop("scopes", None) - if not scopes: - if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES: - # the credential is configured to use a cloud whose ARM scope we can't determine - raise CredentialUnavailableError( - message="Authenticating in this environment requires a value for the 'scopes' keyword argument." - ) - - scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority] - - _ = self.get_token(*scopes, _allow_prompt=True, **kwargs) - return self._auth_record # type: ignore - - @wrap_exceptions - def _acquire_token_silent(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - result = None - if self._auth_record: - app = self._get_app() - for account in app.get_accounts(username=self._auth_record.username): - if account.get("home_account_id") != self._auth_record.home_account_id: - continue - - now = int(time.time()) - result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs) - if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) - - # if we get this far, result is either None or the content of an AAD error response - if result: - details = result.get("error_description") or result.get("error") - raise AuthenticationRequiredError(scopes, error_details=details) - raise AuthenticationRequiredError(scopes) - - @abc.abstractmethod - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> dict - """Request an access token via a non-silent MSAL token acquisition method, returning that method's result""" - - def _get_app(self): - # type: () -> msal.PublicClientApplication - if not self._msal_app: - self._msal_app = self._create_app(msal.PublicClientApplication) - return self._msal_app diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py index d9146708694d..c5553286c582 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py @@ -11,7 +11,7 @@ from .client_secret import ClientSecretCredential from .shared_cache import SharedTokenCacheCredential from .azure_cli import AzureCliCredential -from .vscode_credential import VSCodeCredential +from .vscode import VSCodeCredential __all__ = [ diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 2c9d37f68e50..8cbe49b30cae 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -13,7 +13,7 @@ from .environment import EnvironmentCredential from .managed_identity import ManagedIdentityCredential from .shared_cache import SharedTokenCacheCredential -from .vscode_credential import VSCodeCredential +from .vscode import VSCodeCredential if TYPE_CHECKING: from typing import Any diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py similarity index 98% rename from sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py rename to sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py index c27247aedce3..f49d10cbd152 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py @@ -9,7 +9,7 @@ from ..._constants import AZURE_VSCODE_CLIENT_ID from .._internal.aad_client import AadClient from .._internal.decorators import log_get_token_async -from ..._credentials.vscode_credential import get_credentials +from ..._credentials.vscode import get_credentials if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index 88c477dc593a..9191462da2a4 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -10,11 +10,11 @@ DefaultAzureCredential, InteractiveBrowserCredential, SharedTokenCacheCredential, + VSCodeCredential, ) from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_cli import AzureCliCredential from azure.identity._credentials.managed_identity import ManagedIdentityCredential -from azure.identity._credentials.vscode_credential import VSCodeCredential import pytest from six.moves.urllib_parse import urlparse diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index afd89b3f0b8a..a385a5b2a579 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -2,17 +2,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import asyncio import os from unittest.mock import Mock, patch from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.identity import CredentialUnavailableError -from azure.identity.aio import DefaultAzureCredential, SharedTokenCacheCredential -from azure.identity.aio._credentials.azure_cli import AzureCliCredential -from azure.identity.aio._credentials.managed_identity import ManagedIdentityCredential -from azure.identity.aio._credentials.vscode_credential import VSCodeCredential +from azure.identity.aio import ( + AzureCliCredential, + DefaultAzureCredential, + ManagedIdentityCredential, + SharedTokenCacheCredential, + VSCodeCredential, +) from azure.identity._constants import EnvironmentVariables import pytest diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index 8bfeaac041a4..645e74f21bd0 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -9,7 +9,7 @@ KnownAuthorities, CredentialUnavailableError, ) -from azure.identity._internal.msal_credentials import InteractiveCredential +from azure.identity._internal import InteractiveCredential from msal import TokenCache import pytest diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index ed5a0f5af235..ef43604017ae 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -9,7 +9,7 @@ from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT -from azure.identity._credentials.vscode_credential import get_credentials +from azure.identity._credentials.vscode import get_credentials import pytest from six.moves.urllib_parse import urlparse