Skip to content

Commit

Permalink
[Identity] Enable CAE toggle per token request
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
  • Loading branch information
pvaneck committed Jun 28, 2023
1 parent dc395e5 commit 9a8c432
Show file tree
Hide file tree
Showing 37 changed files with 578 additions and 282 deletions.
2 changes: 2 additions & 0 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def get_token(
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
Expand Down
2 changes: 2 additions & 0 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ async def get_token(
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import time
from typing import TYPE_CHECKING, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Dict, Optional, TypeVar, Any

from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError
Expand All @@ -30,13 +30,16 @@ class _BearerTokenCredentialPolicyBase:
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
"""

def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs) -> None: # pylint:disable=unused-argument
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessToken"] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

@staticmethod
def _enforce_https(request: "PipelineRequest") -> None:
Expand Down Expand Up @@ -74,6 +77,8 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

Expand All @@ -87,7 +92,10 @@ def on_request(self, request: "PipelineRequest") -> None:
self._enforce_https(request)

if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
kwargs: Dict[str, Any] = {}
if self._enable_cae:
kwargs["enable_cae"] = self._enable_cae
self._token = self._credential.get_token(*self._scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs) -> None:
Expand All @@ -99,6 +107,8 @@ def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs)
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
if self._enable_cae:
kwargs["enable_cae"] = kwargs.get("enable_cae", self._enable_cae)
self._token = self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy):
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
"""

def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any) -> None:
Expand All @@ -35,6 +37,7 @@ def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: A
self._lock = asyncio.Lock()
self._scopes = scopes
self._token: Optional["AccessToken"] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.
Expand All @@ -49,7 +52,10 @@ async def on_request(self, request: "PipelineRequest") -> None: # pylint:disabl
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token():
self._token = await await_result(self._credential.get_token, *self._scopes)
get_token_args = {}
if self._enable_cae:
get_token_args["enable_cae"] = self._enable_cae
self._token = await await_result(self._credential.get_token, *self._scopes, **get_token_args)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: Any) -> None:
Expand All @@ -61,6 +67,8 @@ async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kw
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
if self._enable_cae:
kwargs["enable_cae"] = kwargs.get("enable_cae", self._enable_cae)
async with self._lock:
self._token = await await_result(self._credential.get_token, *scopes, **kwargs)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
# --------------------------------------------------------------------------
import base64
import time
from typing import Optional, TypeVar
from typing import Optional, TypeVar, TYPE_CHECKING

from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.exceptions import ServiceRequestError

if TYPE_CHECKING:
# pylint:disable=unused-import
from azure.core.credentials import TokenCredential


HTTPRequestType = TypeVar("HTTPRequestType")
HTTPResponseType = TypeVar("HTTPResponseType")
Expand All @@ -46,6 +50,10 @@ class ARMChallengeAuthenticationPolicy(BearerTokenCredentialPolicy):
:param str scopes: required authentication scopes
"""

def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs) -> None: # pylint:disable=unused-argument
kwargs.setdefault("enable_cae", True) # ARM supports Continuous Access Evaluation (CAE).
super().__init__(credential, *scopes, **kwargs)

def on_challenge(
self,
request: PipelineRequest[HTTPRequestType],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import TypeVar, Awaitable, Optional
from typing import Any, TypeVar, Awaitable, Optional, TYPE_CHECKING
import inspect

from azure.core.pipeline.policies import (
Expand All @@ -34,6 +34,9 @@

from ._authentication import _parse_claims_challenge, _AuxiliaryAuthenticationPolicyBase

if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential


HTTPRequestType = TypeVar("HTTPRequestType")
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
Expand All @@ -57,6 +60,10 @@ class AsyncARMChallengeAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
:param str scopes: required authentication scopes
"""

def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any) -> None:
kwargs.setdefault("enable_cae", True) # ARM supports Continuous Access Evaluation (CAE).
super().__init__(credential, *scopes, **kwargs)

# pylint:disable=unused-argument
async def on_challenge(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def get_token(*scopes, **kwargs):
assert response.http_response.status_code == 200
assert transport.send.call_count == 2
assert credential.get_token.call_count == 2
credential.get_token.assert_called_with(expected_scope, claims=expected_claims)
credential.get_token.assert_called_with(expected_scope, claims=expected_claims, enable_cae=True)
with pytest.raises(StopIteration):
next(tokens)
with pytest.raises(StopIteration):
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-mgmt-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_token(*scopes, **kwargs):
assert response.http_response.status_code == 200
assert transport.send.call_count == 2
assert credential.get_token.call_count == 2
credential.get_token.assert_called_with(expected_scope, claims=expected_claims)
credential.get_token.assert_called_with(expected_scope, claims=expected_claims, enable_cae=True)
with pytest.raises(StopIteration):
next(tokens)
with pytest.raises(StopIteration):
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### Breaking Changes

- CP1 client capabilities (CAE) is no longer always-on by default for user credentials. This capability will now be configured as-needed in each `get_token` request by each SDK.

### Bugs Fixed

### Other Changes
Expand Down
5 changes: 3 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
DEFAULT_REFRESH_OFFSET = 300
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30

CACHE_PRIMARY_SUFFIX = ".main"
CACHE_CAE_SUFFIX = ".cae"


class AzureAuthorityHosts:
AZURE_CHINA = "login.chinacloudapi.cn"
Expand Down Expand Up @@ -50,5 +53,3 @@ class EnvironmentVariables:

AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE"
WORKLOAD_IDENTITY_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)

AZURE_IDENTITY_DISABLE_CP1 = "AZURE_IDENTITY_DISABLE_CP1"
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:keyword bool enable_cae: enables configuring "CP1" client capabilities to support Continuous
Access Evaluation (CAE). Defaults to False.
:rtype: :class:`azure.core.credentials.AccessToken`
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
Expand Down Expand Up @@ -96,20 +98,28 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()
if not self._client_initialized:
self._initialize_client()

if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")
is_cae = bool(kwargs.get("enable_cae", False))
token_cache = self._cae_cache if is_cae else self._cache

account = self._get_account(self._username, self._tenant_id)
# Try to load the cache if it is None.
if not token_cache:
token_cache = self._initialize_cache(is_cae=is_cae)

token = self._get_cached_access_token(scopes, account)
# If the cache is still None, raise an error.
if not token_cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

account = self._get_account(self._username, self._tenant_id, is_cae=is_cae)

token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
if token:
return token

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account):
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
return token

Expand Down
72 changes: 48 additions & 24 deletions sdk/identity/azure-identity/azure/identity/_credentials/silent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import platform
import time
from typing import Dict, Optional, Any

from msal import PublicClientApplication
from msal import PublicClientApplication, TokenCache

from azure.core.credentials import AccessToken

Expand All @@ -17,7 +16,7 @@
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
from .._constants import EnvironmentVariables
from .._constants import CACHE_CAE_SUFFIX, CACHE_PRIMARY_SUFFIX
from .. import AuthenticationRecord


Expand All @@ -37,11 +36,15 @@ def __init__(
self._tenant_id = tenant_id or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._cache = kwargs.pop("_cache", None)
self._cae_cache = kwargs.pop("_cae_cache", None)

self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)

self._client_applications: Dict[str, PublicClientApplication] = {}
self._cae_client_applications: Dict[str, PublicClientApplication] = {}

self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])
self._client = MsalClient(**kwargs)
self._initialized = False

def __enter__(self):
self._client.__enter__()
Expand All @@ -54,46 +57,67 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
if not scopes:
raise ValueError('"get_token" requires at least one scope')

if not self._initialized:
self._initialize()
token_cache = self._cae_cache if kwargs.get("enable_cae") else self._cache

if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")
# Try to load the cache if it is None.
if not token_cache:
token_cache = self._initialize_cache(is_cae=bool(kwargs.get("enable_cae")))

# If the cache is still None, raise an error.
if not token_cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

return self._acquire_token_silent(*scopes, **kwargs)

def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:
def _initialize_cache(self, is_cae: bool = False) -> TokenCache:

# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
cache_options = self._cache_persistence_options or TokenCachePersistenceOptions(allow_unencrypted_storage=True)
is_platform_supported = platform.system() in {"Darwin", "Linux", "Windows"}

if not self._cache and is_platform_supported and not is_cae:
try:
# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
options = self._cache_persistence_options or \
TokenCachePersistenceOptions(allow_unencrypted_storage=True)
self._cache = _load_persistent_cache(options)
self._cache = _load_persistent_cache(cache_options, cache_suffix=CACHE_PRIMARY_SUFFIX)
except Exception: # pylint:disable=broad-except
pass
return None

self._initialized = True
if not self._cae_cache and is_platform_supported and is_cae:
try:
self._cae_cache = _load_persistent_cache(cache_options, cache_suffix=CACHE_CAE_SUFFIX)
except Exception: # pylint:disable=broad-except
return None

return self._cae_cache if is_cae else self._cache

def _get_client_application(self, **kwargs: Any):
tenant_id = resolve_tenant(
self._tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs
)
if tenant_id not in self._client_applications:

client_applications_map = self._client_applications
capabilities = None
token_cache = self._cache

if kwargs.get("enable_cae"):
client_applications_map = self._cae_client_applications
# CP1 = can handle claims challenges (CAE)
capabilities = None if EnvironmentVariables.AZURE_IDENTITY_DISABLE_CP1 in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
capabilities = ["CP1"]
token_cache = self._cae_cache

if tenant_id not in client_applications_map:
client_applications_map[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
token_cache=token_cache,
http_client=self._client,
client_capabilities=capabilities
)
return self._client_applications[tenant_id]
return client_applications_map[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessToken:
kwargs.pop("claims", None)
now = int(time.time())
response = self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return self._process_response(response, now, **kwargs)
Loading

0 comments on commit 9a8c432

Please sign in to comment.