Skip to content

Commit

Permalink
Use MSAL's custom transport API (#11892)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jun 26, 2020
1 parent b0d25bb commit fdf11b3
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 332 deletions.
9 changes: 5 additions & 4 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Release History

## 1.4.0b6 (Unreleased)
- Upgraded minimum `msal` version to 1.3.0
- The async `AzureCliCredential` correctly invokes `/bin/sh`
([#12048](https://github.com/Azure/azure-sdk-for-python/issues/12048))

Expand All @@ -18,14 +19,14 @@
identity by its client ID, continue using the `client_id` argument. To
specify an identity by any other ID, use the `identity_config` argument,
for example: `ManagedIdentityCredential(identity_config={"object_id": ".."})`
([#10989](https://github.com/Azure/azure-sdk-for-python/issues/10989))
([#10989](https://github.com/Azure/azure-sdk-for-python/issues/10989))
- `CertificateCredential` and `ClientSecretCredential` can optionally store
access tokens they acquire in a persistent cache. To enable this, construct
the credential with `enable_persistent_cache=True`. On Linux, the persistent
cache requires libsecret and `pygobject`. If these are unavailable or
unusable (e.g. in an SSH session), loading the persistent cache will raise an
error. You may optionally configure the credential to fall back to an
unencrypted cache by constructing it with keyword argument
unencrypted cache by constructing it with keyword argument
`allow_unencrypted_cache=True`.
([#11347](https://github.com/Azure/azure-sdk-for-python/issues/11347))
- `AzureCliCredential` raises `CredentialUnavailableError` when no user is
Expand Down Expand Up @@ -66,7 +67,7 @@

## 1.4.0b3 (2020-05-04)
- `EnvironmentCredential` correctly initializes `UsernamePasswordCredential`
with the value of `AZURE_TENANT_ID`
with the value of `AZURE_TENANT_ID`
([#11127](https://github.com/Azure/azure-sdk-for-python/pull/11127))
- Values for the constructor keyword argument `authority` and
`AZURE_AUTHORITY_HOST` may optionally specify an "https" scheme. For example,
Expand All @@ -86,7 +87,7 @@ with the value of `AZURE_TENANT_ID`
- `enable_persistent_cache=True` configures these credentials to use a
persistent cache on supported platforms (in this release, Windows only).
By default they cache in memory only.
- Now `DefaultAzureCredential` can authenticate with the identity signed in to
- Now `DefaultAzureCredential` can authenticate with the identity signed in to
Visual Studio Code's Azure extension.
([#10472](https://github.com/Azure/azure-sdk-for-python/issues/10472))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(self, client_id, username, password, **kwargs):
def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> dict
app = self._get_app()
with self._adapter:
return app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=list(scopes)
)
return app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=list(scopes)
)
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def get_default_authority():
from .certificate_credential_base import CertificateCredentialBase
from .client_secret_credential_base import ClientSecretCredentialBase
from .exception_wrapper import wrap_exceptions
from .msal_credentials import ConfidentialClientCredential, InteractiveCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
from .msal_credentials import InteractiveCredential, PublicClientCredential


def _scopes_to_resource(*scopes):
Expand All @@ -62,11 +61,8 @@ def _scopes_to_resource(*scopes):
"AadClientCertificate",
"CertificateCredentialBase",
"ClientSecretCredentialBase",
"ConfidentialClientCredential",
"get_default_authority",
"InteractiveCredential",
"MsalTransportAdapter",
"MsalTransportResponse",
"normalize_authority",
"PublicClientCredential",
"wrap_exceptions",
Expand Down
137 changes: 137 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_internal/msal_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import six

from azure.core.configuration import Configuration
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import (
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
NetworkTraceLoggingPolicy,
ProxyPolicy,
RetryPolicy,
UserAgentPolicy,
)
from azure.core.pipeline.transport import HttpRequest, RequestsTransport

from .user_agent import USER_AGENT

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict, List, Optional, Union
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpTransport

PolicyList = List[Union[HTTPPolicy, SansIOHTTPPolicy]]
RequestData = Union[Dict[str, str], str]


class MsalResponse(object):
"""Wraps HttpResponse according to msal.oauth2cli.http"""

def __init__(self, response):
# type: (PipelineResponse) -> None
self._response = response

@property
def status_code(self):
# type: () -> int
return self._response.http_response.status_code

@property
def text(self):
# type: () -> str
return self._response.http_response.text(encoding="utf-8")

def raise_for_status(self):
if self.status_code < 400:
return

if ContentDecodePolicy.CONTEXT_NAME in self._response.context:
content = self._response.context[ContentDecodePolicy.CONTEXT_NAME]
if "error" in content or "error_description" in content:
message = "Authentication failed: {}".format(content.get("error_description") or content.get("error"))
else:
for secret in ("access_token", "refresh_token"):
if secret in content:
content[secret] = "***"
message = 'Unexpected response from Azure Active Directory: "{}"'.format(content)
else:
message = "Unexpected response from Azure Active Directory"

raise ClientAuthenticationError(message=message, response=self._response.http_response)


class MsalClient(object):
"""Wraps Pipeline according to msal.oauth2cli.http"""

def __init__(self, **kwargs): # pylint:disable=missing-client-constructor-parameter-credential
# type: (**Any) -> None
self._pipeline = _build_pipeline(**kwargs)

def post(self, url, params=None, data=None, headers=None, **kwargs): # pylint:disable=unused-argument
# type: (str, Optional[Dict[str, str]], RequestData, Optional[Dict[str, str]], **Any) -> MsalResponse
request = HttpRequest("POST", url, headers=headers)
if params:
request.format_parameters(params)
if data:
if isinstance(data, dict):
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request.set_formdata_body(data)
elif isinstance(data, six.text_type):
body_bytes = six.ensure_binary(data)
request.set_bytes_body(body_bytes)
else:
raise ValueError('expected "data" to be text or a dict')

response = self._pipeline.run(request)
return MsalResponse(response)

def get(self, url, params=None, headers=None, **kwargs): # pylint:disable=unused-argument
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], **Any) -> MsalResponse
request = HttpRequest("GET", url, headers=headers)
if params:
request.format_parameters(params)
response = self._pipeline.run(request)
return MsalResponse(response)


def _create_config(**kwargs):
# type: (Any) -> Configuration
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = RetryPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
return config


def _build_pipeline(config=None, policies=None, transport=None, **kwargs):
# type: (Optional[Configuration], Optional[PolicyList], Optional[HttpTransport], **Any) -> Pipeline
config = config or _create_config(**kwargs)

if policies is None: # [] is a valid policy list
policies = [
ContentDecodePolicy(),
config.user_agent_policy,
config.proxy_policy,
config.retry_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs),
]

if not transport:
transport = RequestsTransport(**kwargs)

return Pipeline(transport=transport, policies=policies)
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Credentials wrapping MSAL applications and delegating token acquisition and caching to them.
This entails monkeypatching MSAL's OAuth client with an adapter substituting an azure-core pipeline for Requests.
"""
import abc
import base64
import json
Expand All @@ -17,7 +14,7 @@
from azure.core.exceptions import ClientAuthenticationError

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter
from .msal_client import MsalClient
from .persistent_cache import load_user_cache
from .._constants import KnownAuthorities
from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError
Expand Down Expand Up @@ -102,7 +99,7 @@ def __init__(self, client_id, client_credential=None, **kwargs):
else:
self._cache = msal.TokenCache()

self._adapter = kwargs.pop("msal_adapter", None) or MsalTransportAdapter(**kwargs)
self._client = MsalClient(**kwargs)

# postpone creating the wrapped application because its initializer uses the network
self._msal_app = None # type: Optional[msal.ClientApplication]
Expand All @@ -119,53 +116,17 @@ def _get_app(self):

def _create_app(self, cls):
# type: (Type[msal.ClientApplication]) -> msal.ClientApplication
"""Creates an MSAL application, patching msal.authority to use an azure-core pipeline during tenant discovery"""

# MSAL application initializers use msal.authority to send AAD tenant discovery requests
with self._adapter:
# MSAL's "authority" is a URL e.g. https://login.microsoftonline.com/common
app = cls(
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
token_cache=self._cache,
)

# monkeypatch the app to replace requests.Session with MsalTransportAdapter
app.client.session.close()
app.client.session = self._adapter
app = cls(
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
token_cache=self._cache,
http_client=self._client,
)

return app


class ConfidentialClientCredential(MsalCredential):
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API"""

@wrap_exceptions
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken

# MSAL requires scopes be a list
scopes = list(scopes) # type: ignore
now = int(time.time())

# First try to get a cached access token or if a refresh token is cached, redeem it for an access token.
# Failing that, acquire a new token.
app = self._get_app()
result = app.acquire_token_silent(scopes, account=None) or app.acquire_token_for_client(scopes)

if "access_token" not in result:
raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description")))

return AccessToken(result["access_token"], now + int(result["expires_in"]))

def _get_app(self):
# type: () -> msal.ConfidentialClientApplication
if not self._msal_app:
self._msal_app = self._create_app(msal.ConfidentialClientApplication)
return self._msal_app


class PublicClientCredential(MsalCredential):
"""Wraps an MSAL PublicClientApplication with the TokenCredential API"""

Expand Down
Loading

0 comments on commit fdf11b3

Please sign in to comment.