Skip to content

Commit

Permalink
[Identity] Enable CAE toggle per token request
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
  • Loading branch information
pvaneck committed Jun 20, 2023
1 parent dc395e5 commit 8297f0a
Show file tree
Hide file tree
Showing 20 changed files with 193 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ class _BearerTokenCredentialPolicyBase:
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
"""

def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs) -> None: # pylint:disable=unused-argument
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessToken"] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

@staticmethod
def _enforce_https(request: "PipelineRequest") -> None:
Expand Down Expand Up @@ -74,6 +77,8 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

Expand All @@ -87,7 +92,7 @@ def on_request(self, request: "PipelineRequest") -> None:
self._enforce_https(request)

if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._token = self._credential.get_token(*self._scopes, enable_cae=self._enable_cae)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs) -> None:
Expand All @@ -99,6 +104,7 @@ def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs)
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
kwargs.setdefault("enable_cae", self._enable_cae)
self._token = self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy):
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
"""

def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any) -> None:
Expand All @@ -35,6 +37,7 @@ def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: A
self._lock = asyncio.Lock()
self._scopes = scopes
self._token: Optional["AccessToken"] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.
Expand All @@ -49,7 +52,7 @@ async def on_request(self, request: "PipelineRequest") -> None: # pylint:disabl
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token():
self._token = await await_result(self._credential.get_token, *self._scopes)
self._token = await await_result(self._credential.get_token, *self._scopes, enable_cae=self._enable_cae)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: Any) -> None:
Expand All @@ -61,6 +64,7 @@ async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kw
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
kwargs.setdefault("enable_cae", self._enable_cae)
async with self._lock:
self._token = await await_result(self._credential.get_token, *scopes, **kwargs)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
# --------------------------------------------------------------------------
import base64
import time
from typing import Optional, TypeVar
from typing import Optional, TypeVar, TYPE_CHECKING

from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.exceptions import ServiceRequestError

if TYPE_CHECKING:
# pylint:disable=unused-import
from azure.core.credentials import TokenCredential


HTTPRequestType = TypeVar("HTTPRequestType")
HTTPResponseType = TypeVar("HTTPResponseType")
Expand All @@ -46,6 +50,10 @@ class ARMChallengeAuthenticationPolicy(BearerTokenCredentialPolicy):
:param str scopes: required authentication scopes
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs) -> None: # pylint:disable=unused-argument
kwargs.setdefault("enable_cae", True) # ARM supports Continuous Access Evaluation (CAE).
super().__init__(credential, *scopes, **kwargs)

def on_challenge(
self,
request: PipelineRequest[HTTPRequestType],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import TypeVar, Awaitable, Optional
from typing import Any, TypeVar, Awaitable, Optional, TYPE_CHECKING
import inspect

from azure.core.pipeline.policies import (
Expand All @@ -34,6 +34,9 @@

from ._authentication import _parse_claims_challenge, _AuxiliaryAuthenticationPolicyBase

if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential


HTTPRequestType = TypeVar("HTTPRequestType")
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
Expand All @@ -57,6 +60,10 @@ class AsyncARMChallengeAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
:param str scopes: required authentication scopes
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
kwargs.setdefault("enable_cae", True) # ARM supports Continuous Access Evaluation (CAE).
super().__init__(credential, *scopes, **kwargs)

# pylint:disable=unused-argument
async def on_challenge(
self,
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### Breaking Changes

- CP1 client capabilities (CAE) is no longer always-on by default for user credentials. This capability will now be configured as-needed in each `get_token` request by each SDK.

### Bugs Fixed

### Other Changes
Expand Down
5 changes: 3 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
DEFAULT_REFRESH_OFFSET = 300
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30

CACHE_PRIMARY_SUFFIX = ""
CACHE_CAE_SUFFIX = ".cae"


class AzureAuthorityHosts:
AZURE_CHINA = "login.chinacloudapi.cn"
Expand Down Expand Up @@ -50,5 +53,3 @@ class EnvironmentVariables:

AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE"
WORKLOAD_IDENTITY_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)

AZURE_IDENTITY_DISABLE_CP1 = "AZURE_IDENTITY_DISABLE_CP1"
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:keyword bool enable_cae: enables configuring "CP1" client capabilities to support Continuous
Access Evaluation (CAE). Defaults to False.
:rtype: :class:`azure.core.credentials.AccessToken`
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
Expand Down
51 changes: 36 additions & 15 deletions sdk/identity/azure-identity/azure/identity/_credentials/silent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import platform
import time
from typing import Dict, Optional, Any
Expand All @@ -17,7 +16,7 @@
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
from .._constants import EnvironmentVariables
from .._constants import CACHE_CAE_SUFFIX, CACHE_PRIMARY_SUFFIX
from .. import AuthenticationRecord


Expand All @@ -37,8 +36,13 @@ def __init__(
self._tenant_id = tenant_id or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._cache = kwargs.pop("_cache", None)
self._cae_cache = kwargs.pop("_cae_cache", None)

self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)

self._client_applications: Dict[str, PublicClientApplication] = {}
self._cae_client_applications: Dict[str, PublicClientApplication] = {}

self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])
self._client = MsalClient(**kwargs)
self._initialized = False
Expand All @@ -63,15 +67,23 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
return self._acquire_token_silent(*scopes, **kwargs)

def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:

# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
cache_options = self._cache_persistence_options or TokenCachePersistenceOptions(allow_unencrypted_storage=True)
is_platform_supported = platform.system() in {"Darwin", "Linux", "Windows"}

if not self._cache and is_platform_supported:
try:
self._cache = _load_persistent_cache(cache_options, cache_suffix=CACHE_PRIMARY_SUFFIX)
except Exception: # pylint:disable=broad-except
pass

if not self._cae_cache and is_platform_supported:
try:
# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
options = self._cache_persistence_options or \
TokenCachePersistenceOptions(allow_unencrypted_storage=True)
self._cache = _load_persistent_cache(options)
self._cae_cache = _load_persistent_cache(cache_options, cache_suffix=CACHE_CAE_SUFFIX)
except Exception: # pylint:disable=broad-except
pass

Expand All @@ -83,17 +95,26 @@ def _get_client_application(self, **kwargs: Any):
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs
)
if tenant_id not in self._client_applications:

client_applications_map = self._client_applications
capabilities = None
token_cache = self._cache

if kwargs.get("enable_cae"):
client_applications_map = self._cae_client_applications
# CP1 = can handle claims challenges (CAE)
capabilities = None if EnvironmentVariables.AZURE_IDENTITY_DISABLE_CP1 in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
capabilities = ["CP1"]
token_cache = self._cae_cache

if tenant_id not in client_applications_map:
client_applications_map[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
token_cache=token_cache,
http_client=self._client,
client_capabilities=capabilities
)
return self._client_applications[tenant_id]
return client_applications_map[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessToken:
kwargs.pop("claims", None)
now = int(time.time())
response = self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return self._process_response(response, now, **kwargs)
Loading

0 comments on commit 8297f0a

Please sign in to comment.