-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
credential wrapping MSAL's ConfidentialClientApplication
- Loading branch information
Showing
4 changed files
with
174 additions
and
9 deletions.
There are no files selected for viewing
6 changes: 6 additions & 0 deletions
6
sdk/identity/azure-identity/azure/identity/_internal/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# ------------------------------------ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# ------------------------------------ | ||
from .confidential_client_credential import ConfidentialClientCredential | ||
from .msal_transport_adapter import MsalTransportResponse, MsalTransportAdapter |
68 changes: 68 additions & 0 deletions
68
sdk/identity/azure-identity/azure/identity/_internal/confidential_client_credential.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# ------------------------------------ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# ------------------------------------ | ||
"""A credential which wraps an MSAL ConfidentialClientApplication and delegates token acquisition and caching to it. | ||
This entails monkeypatching MSAL's OAuth client with an adapter substituting an azure-core pipeline for Requests. | ||
""" | ||
|
||
import time | ||
|
||
try: | ||
from typing import TYPE_CHECKING | ||
except ImportError: | ||
TYPE_CHECKING = False | ||
|
||
try: | ||
from unittest import mock | ||
except ImportError: # python < 3.3 | ||
import mock # type: ignore | ||
|
||
if TYPE_CHECKING: | ||
# pylint:disable=unused-import | ||
from typing import Any, Mapping, Optional, Union | ||
|
||
from azure.core.credentials import AccessToken | ||
import msal | ||
|
||
from .msal_transport_adapter import MsalTransportAdapter | ||
|
||
|
||
class ConfidentialClientCredential(MsalTransportAdapter): | ||
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API""" | ||
|
||
def __init__(self, client_id, client_credential, authority, **kwargs): | ||
# type: (str, str, Union[str, Mapping[str, str]], Any) -> None | ||
super(ConfidentialClientCredential, self).__init__(**kwargs) | ||
|
||
self._client_id = client_id | ||
self._client_credential = client_credential | ||
self._authority = authority | ||
|
||
# postpone creating the wrapped application because its initializer uses the network | ||
self._app = None # type: Optional[msal.ConfidentialClientApplication] | ||
|
||
def get_token(self, *scopes): | ||
# type: (str) -> AccessToken | ||
|
||
if not self._app: | ||
self._app = self._create_msal_application() | ||
|
||
# MSAL requires scopes be a list | ||
scopes = list(scopes) # type: ignore | ||
now = int(time.time()) | ||
|
||
# First try to get a cached access token or if a refresh token is cached, redeem it for an access token. | ||
# Failing that, acquire a new token. | ||
result = self._app.acquire_token_silent(scopes, account=None) or self._app.acquire_token_for_client(scopes) | ||
return AccessToken(result["access_token"], now + int(result["expires_in"])) | ||
|
||
def _create_msal_application(self): | ||
# ConfidentialClientApplication's initializer uses msal.authority to send requests to AAD | ||
with mock.patch("msal.authority.requests", self): | ||
app = msal.ConfidentialClientApplication( | ||
client_id=self._client_id, client_credential=self._client_credential, authority=self._authority | ||
) | ||
# replace the client's requests.Session with adapter | ||
app.client.session = self | ||
return app |
87 changes: 87 additions & 0 deletions
87
sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# ------------------------------------ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# ------------------------------------ | ||
"""Adapter to substitute an azure-core pipeline for Requests in MSAL application token acquisition methods.""" | ||
|
||
import json | ||
|
||
try: | ||
from typing import TYPE_CHECKING | ||
except ImportError: | ||
TYPE_CHECKING = False | ||
|
||
if TYPE_CHECKING: | ||
# pylint:disable=unused-import | ||
from typing import Any, Dict, Mapping, Optional | ||
from azure.core.pipeline import PipelineResponse | ||
|
||
from azure.core.configuration import Configuration | ||
from azure.core.exceptions import ClientAuthenticationError | ||
from azure.core.pipeline import Pipeline | ||
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy | ||
from azure.core.pipeline.transport import HttpRequest, RequestsTransport | ||
|
||
|
||
class MsalTransportResponse: | ||
"""Wraps an azure-core PipelineResponse with the shape of requests.Response""" | ||
|
||
def __init__(self, pipeline_response): | ||
# type: (PipelineResponse) -> None | ||
self._response = pipeline_response.http_response | ||
self.status_code = self._response.status_code | ||
self.text = self._response.text() | ||
|
||
def json(self, **kwargs): | ||
# type: (Any) -> Mapping[str, Any] | ||
return json.loads(self.text, **kwargs) | ||
|
||
def raise_for_status(self): | ||
# type: () -> None | ||
raise ClientAuthenticationError("authentication failed", self._response) | ||
|
||
|
||
class MsalTransportAdapter: | ||
"""Wraps an azure-core pipeline with the shape of requests.Session""" | ||
|
||
def __init__(self, **kwargs): | ||
# type: (Any) -> None | ||
self._pipeline = self._build_pipeline(**kwargs) | ||
|
||
@staticmethod | ||
def create_config(**kwargs): | ||
# type: (Any) -> Configuration | ||
config = Configuration(**kwargs) | ||
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) | ||
config.retry_policy = RetryPolicy(**kwargs) | ||
return config | ||
|
||
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs): | ||
config = config or self.create_config(**kwargs) | ||
policies = policies or [ContentDecodePolicy(), config.retry_policy, config.logging_policy] | ||
if not transport: | ||
transport = RequestsTransport(configuration=config) | ||
return Pipeline(transport=transport, policies=policies) | ||
|
||
def get(self, url, headers=None, params=None, timeout=None, verify=None, **kwargs): | ||
# type: (str, Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse | ||
request = HttpRequest("GET", url, headers=headers) | ||
if params: | ||
request.format_parameters(params) | ||
response = self._pipeline.run( | ||
request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs | ||
) | ||
return MsalTransportResponse(response) | ||
|
||
def post(self, url, data=None, headers=None, params=None, timeout=None, verify=None, **kwargs): | ||
# type: (str, Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse | ||
request = HttpRequest("POST", url, headers=headers) | ||
if params: | ||
request.format_parameters(params) | ||
if data: | ||
request.headers["Content-Type"] = "application/x-www-form-urlencoded" | ||
request.set_formdata_body(data) | ||
response = self._pipeline.run( | ||
request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs | ||
) | ||
return MsalTransportResponse(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters