Skip to content

Commit

Permalink
Add support for regional STS (Azure#19392)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored and rakshith91 committed Jul 16, 2021
1 parent 44b1ef9 commit d120b15
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 22 deletions.
5 changes: 5 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
- `InteractiveBrowserCredential` keyword argument `login_hint` enables
pre-filling the username/email address field on the login page
([#19225](https://github.com/Azure/azure-sdk-for-python/issues/19225))
- `CertificateCredential` and `ClientSecretCredential` support regional STS
on Azure VMs by either keyword argument `regional_authority` or environment
variable `AZURE_REGIONAL_AUTHORITY_NAME`. See `azure.identity.RegionalAuthority`
for possible values.
([#19301](https://github.com/Azure/azure-sdk-for-python/issues/19301))

## 1.7.0b1 (2021-06-08)
Beginning with this release, this library requires Python 2.7 or 3.6+.
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Credentials for Azure SDK clients."""

from ._auth_record import AuthenticationRecord
from ._enums import RegionalAuthority
from ._exceptions import AuthenticationRequiredError, CredentialUnavailableError
from ._constants import AzureAuthorityHosts, KnownAuthorities
from ._credentials import (
Expand Down Expand Up @@ -42,6 +43,7 @@
"EnvironmentCredential",
"InteractiveBrowserCredential",
"KnownAuthorities",
"RegionalAuthority",
"ManagedIdentityCredential",
"SharedTokenCacheCredential",
"TokenCachePersistenceOptions",
Expand Down
1 change: 1 addition & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ class EnvironmentVariables:
MSI_SECRET = "MSI_SECRET"

AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME"
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
class CertificateCredential(ClientCredentialBase):
"""Authenticates as a service principal using a certificate.
The certificate must have an RSA private key, because this credential signs assertions using RS256.
See Azure Active Directory documentation for more information on configuring certificate authentication:
https://docs.microsoft.com/azure/active-directory/develop/active-directory-certificate-credentials#register-your-certificate-with-microsoft-identity-platform
The certificate must have an RSA private key, because this credential signs assertions using RS256. See
`Azure Active Directory documentation
<https://docs.microsoft.com/azure/active-directory/develop/active-directory-certificate-credentials#register-your-certificate-with-microsoft-identity-platform>`_
for more information on configuring certificate authentication.
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
:param str client_id: the service principal's client ID
Expand All @@ -44,6 +44,9 @@ class CertificateCredential(ClientCredentialBase):
:keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential
will cache tokens in memory.
:paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions
:keyword ~azure.identity.RegionalAuthority regional_authority: a :class:`~azure.identity.RegionalAuthority` to
which the credential will authenticate. This argument should be used only by applications deployed to Azure
VMs.
"""

def __init__(self, tenant_id, client_id, certificate_path=None, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class ClientSecretCredential(ClientCredentialBase):
:keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential
will cache tokens in memory.
:paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions
:keyword ~azure.identity.RegionalAuthority regional_authority: a :class:`~azure.identity.RegionalAuthority` to
which the credential will authenticate. This argument should be used only by applications deployed to Azure
VMs.
"""

def __init__(self, tenant_id, client_id, client_secret, **kwargs):
Expand Down
71 changes: 71 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
# pylint:skip-file (avoids crash due to six.with_metaclass https://github.com/PyCQA/astroid/issues/713)
from enum import Enum
from six import with_metaclass

from azure.core import CaseInsensitiveEnumMeta
from msal import ConfidentialClientApplication


class RegionalAuthority(with_metaclass(CaseInsensitiveEnumMeta, str, Enum)):
"""Identifies a regional authority for authentication"""

#: Attempt to discover the appropriate authority. This works on some Azure hosts, such as VMs and
#: Azure Functions. The non-regional authority is used when discovery fails.
AUTO_DISCOVER_REGION = ConfidentialClientApplication.ATTEMPT_REGION_DISCOVERY

ASIA_EAST = "eastasia"
ASIA_SOUTHEAST = "southeastasia"
AUSTRALIA_CENTRAL = "australiacentral"
AUSTRALIA_CENTRAL_2 = "australiacentral2"
AUSTRALIA_EAST = "australiaeast"
AUSTRALIA_SOUTHEAST = "australiasoutheast"
BRAZIL_SOUTH = "brazilsouth"
CANADA_CENTRAL = "canadacentral"
CANADA_EAST = "canadaeast"
CHINA_EAST = "chinaeast"
CHINA_EAST_2 = "chinaeast2"
CHINA_NORTH = "chinanorth"
CHINA_NORTH_2 = "chinanorth2"
EUROPE_NORTH = "northeurope"
EUROPE_WEST = "westeurope"
FRANCE_CENTRAL = "francecentral"
FRANCE_SOUTH = "francesouth"
GERMANY_CENTRAL = "germanycentral"
GERMANY_NORTH = "germanynorth"
GERMANY_NORTHEAST = "germanynortheast"
GERMANY_WEST_CENTRAL = "germanywestcentral"
GOVERNMENT_US_ARIZONA = "usgovarizona"
GOVERNMENT_US_DOD_CENTRAL = "usdodcentral"
GOVERNMENT_US_DOD_EAST = "usdodeast"
GOVERNMENT_US_IOWA = "usgoviowa"
GOVERNMENT_US_TEXAS = "usgovtexas"
GOVERNMENT_US_VIRGINIA = "usgovvirginia"
INDIA_CENTRAL = "centralindia"
INDIA_SOUTH = "southindia"
INDIA_WEST = "westindia"
JAPAN_EAST = "japaneast"
JAPAN_WEST = "japanwest"
KOREA_CENTRAL = "koreacentral"
KOREA_SOUTH = "koreasouth"
NORWAY_EAST = "norwayeast"
NORWAY_WEST = "norwaywest"
SOUTH_AFRICA_NORTH = "southafricanorth"
SOUTH_AFRICA_WEST = "southafricawest"
SWITZERLAND_NORTH = "switzerlandnorth"
SWITZERLAND_WEST = "switzerlandwest"
UAE_CENTRAL = "uaecentral"
UAE_NORTH = "uaenorth"
UK_SOUTH = "uksouth"
UK_WEST = "ukwest"
US_CENTRAL = "centralus"
US_EAST = "eastus"
US_EAST_2 = "eastus2"
US_NORTH_CENTRAL = "northcentralus"
US_SOUTH_CENTRAL = "southcentralus"
US_WEST = "westus"
US_WEST_2 = "westus2"
US_WEST_CENTRAL = "westcentralus"
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import os

import msal

from .msal_client import MsalClient
from .._constants import EnvironmentVariables
from .._internal import get_default_authority, normalize_authority, validate_tenant_id
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
from .._persistent_cache import _load_persistent_cache

try:
ABC = abc.ABC
Expand All @@ -22,7 +24,7 @@

if TYPE_CHECKING:
# pylint:disable=ungrouped-imports,unused-import
from typing import Any, Mapping, Optional, Type, Union
from typing import Any, Optional, Type, Union


class MsalCredential(ABC):
Expand All @@ -32,6 +34,9 @@ def __init__(self, client_id, client_credential=None, **kwargs):
# type: (str, Optional[Union[str, dict]], **Any) -> None
authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()
self._regional_authority = kwargs.pop(
"regional_authority", os.environ.get(EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME)
)
self._tenant_id = kwargs.pop("tenant_id", None) or "organizations"
validate_tenant_id(self._tenant_id)

Expand Down Expand Up @@ -63,6 +68,7 @@ def _create_app(self, cls, **kwargs):
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
azure_region=self._regional_authority,
token_cache=self._cache,
http_client=self._client,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
class CertificateCredential(AsyncContextManager, GetTokenMixin):
"""Authenticates as a service principal using a certificate.
The certificate must have an RSA private key, because this credential signs assertions using RS256.
See Azure Active Directory documentation for more information on configuring certificate authentication:
https://docs.microsoft.com/azure/active-directory/develop/active-directory-certificate-credentials#register-your-certificate-with-microsoft-identity-platform
The certificate must have an RSA private key, because this credential signs assertions using RS256. See
`Azure Active Directory documentation
<https://docs.microsoft.com/azure/active-directory/develop/active-directory-certificate-credentials#register-your-certificate-with-microsoft-identity-platform>`_
for more information on configuring certificate authentication.
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
:param str client_id: the service principal's client ID
Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/azure-identity/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@
]
),
install_requires=[
"azure-core<2.0.0,>=1.0.0",
"azure-core<2.0.0,>=1.11.0",
"cryptography>=2.1.4",
"msal<2.0.0,>=1.7.0",
"msal<2.0.0,>=1.12.0",
"msal-extensions~=0.3.0",
"six>=1.12.0",
],
Expand Down
37 changes: 34 additions & 3 deletions sdk/identity/azure-identity/tests/test_certificate_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os

from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import CertificateCredential, TokenCachePersistenceOptions
from azure.identity import CertificateCredential, RegionalAuthority, TokenCachePersistenceOptions
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.user_agent import USER_AGENT
from cryptography import x509
Expand All @@ -25,7 +25,6 @@
mock_response,
msal_validating_transport,
Request,
validating_transport,
)

try:
Expand Down Expand Up @@ -78,7 +77,7 @@ def test_policies_configurable():
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

transport = msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))],
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))]
)

credential = CertificateCredential(
Expand Down Expand Up @@ -135,6 +134,38 @@ def test_authority(authority):
assert kwargs["authority"] == expected_authority


def test_regional_authority():
"""the credential should configure MSAL with a regional authority specified via kwarg or environment variable"""

mock_confidential_client = Mock(
return_value=Mock(acquire_token_silent_with_error=lambda *_, **__: {"access_token": "**", "expires_in": 3600}),
)

for region in RegionalAuthority:
mock_confidential_client.reset_mock()

with patch.dict("os.environ", {}, clear=True):
credential = CertificateCredential("tenant", "client-id", CERT_PATH, regional_authority=region)
with patch("msal.ConfidentialClientApplication", mock_confidential_client):
# must call get_token because the credential constructs the MSAL application lazily
credential.get_token("scope")

assert mock_confidential_client.call_count == 1
_, kwargs = mock_confidential_client.call_args
assert kwargs["azure_region"] == region
mock_confidential_client.reset_mock()

# region can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: region}, clear=True):
credential = CertificateCredential("tenant", "client-id", CERT_PATH)
with patch("msal.ConfidentialClientApplication", mock_confidential_client):
credential.get_token("scope")

assert mock_confidential_client.call_count == 1
_, kwargs = mock_confidential_client.call_args
assert kwargs["azure_region"] == region


def test_requires_certificate():
"""the credential should raise ValueError when not given a certificate"""

Expand Down
40 changes: 36 additions & 4 deletions sdk/identity/azure-identity/tests/test_client_secret_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
# Licensed under the MIT License.
# ------------------------------------
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import ClientSecretCredential, TokenCachePersistenceOptions
from azure.identity import ClientSecretCredential, RegionalAuthority, TokenCachePersistenceOptions
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.user_agent import USER_AGENT
from msal import TokenCache
import pytest
from six.moves.urllib_parse import urlparse

from helpers import build_aad_response, mock_response, msal_validating_transport, Request, validating_transport
from helpers import build_aad_response, mock_response, msal_validating_transport, Request

try:
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -43,7 +43,7 @@ def test_policies_configurable():
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

transport = msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))],
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))]
)

credential = ClientSecretCredential(
Expand Down Expand Up @@ -117,6 +117,38 @@ def test_authority(authority):
assert kwargs["authority"] == expected_authority


def test_regional_authority():
"""the credential should configure MSAL with a regional authority specified via kwarg or environment variable"""

mock_confidential_client = Mock(
return_value=Mock(acquire_token_silent_with_error=lambda *_, **__: {"access_token": "**", "expires_in": 3600})
)

for region in RegionalAuthority:
mock_confidential_client.reset_mock()

with patch.dict("os.environ", {}, clear=True):
credential = ClientSecretCredential("tenant", "client-id", "secret", regional_authority=region)
with patch("msal.ConfidentialClientApplication", mock_confidential_client):
# must call get_token because the credential constructs the MSAL application lazily
credential.get_token("scope")

assert mock_confidential_client.call_count == 1
_, kwargs = mock_confidential_client.call_args
assert kwargs["azure_region"] == region
mock_confidential_client.reset_mock()

# region can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: region}, clear=True):
credential = ClientSecretCredential("tenant", "client-id", "secret")
with patch("msal.ConfidentialClientApplication", mock_confidential_client):
credential.get_token("scope")

assert mock_confidential_client.call_count == 1
_, kwargs = mock_confidential_client.call_args
assert kwargs["azure_region"] == region


def test_token_cache():
"""the credential should default to an in memory cache, and optionally use a persistent cache"""

Expand All @@ -126,7 +158,7 @@ def test_token_cache():
assert isinstance(credential._cache, TokenCache)

ClientSecretCredential(
"tenant", "client-id", "secret", cache_persistence_options=TokenCachePersistenceOptions(),
"tenant", "client-id", "secret", cache_persistence_options=TokenCachePersistenceOptions()
)
assert mock_msal_extensions.PersistedTokenCache.call_count == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def test_token_cache():
assert not mock_msal_extensions.PersistedTokenCache.called

ClientSecretCredential(
"tenant", "client-id", "secret", cache_persistence_options=TokenCachePersistenceOptions(),
"tenant", "client-id", "secret", cache_persistence_options=TokenCachePersistenceOptions()
)
assert mock_msal_extensions.PersistedTokenCache.call_count == 1

Expand Down
4 changes: 2 additions & 2 deletions shared_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ futures
mock
typing
typing-extensions
msal<2.0.0,>=1.7.0
msal<2.0.0,>=1.12.0
msal-extensions~=0.3.0
msrest>=0.5.0
msrestazure<2.0.0,>=0.4.32
Expand All @@ -134,7 +134,7 @@ pyjwt>=1.7.1
#override azure-cosmos azure-core<2.0.0,>=1.0.0
#override azure-data-tables azure-core<2.0.0,>=1.14.0
#override azure-eventhub azure-core<2.0.0,>=1.14.0
#override azure-identity azure-core<2.0.0,>=1.0.0
#override azure-identity azure-core<2.0.0,>=1.11.0
#override azure-keyvault-administration msrest>=0.6.21
#override azure-keyvault-administration azure-core<2.0.0,>=1.11.0
#override azure-keyvault-certificates msrest>=0.6.21
Expand Down

0 comments on commit d120b15

Please sign in to comment.