Skip to content

Commit

Permalink
credential wrapping MSAL's ConfidentialClientApplication
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Jul 15, 2019
1 parent 91b42cc commit 63b4d81
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from .confidential_client_credential import ConfidentialClientCredential
from .msal_transport_adapter import MsalTransportResponse, MsalTransportAdapter
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""A credential which wraps an MSAL ConfidentialClientApplication and delegates token acquisition and caching to it.
This entails monkeypatching MSAL's OAuth client with an adapter substituting an azure-core pipeline for Requests.
"""

import time

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Mapping, Optional, Union

from azure.core.credentials import AccessToken
import msal

from .msal_transport_adapter import MsalTransportAdapter


class ConfidentialClientCredential(MsalTransportAdapter):
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API"""

def __init__(self, client_id, client_credential, authority, **kwargs):
# type: (str, str, Union[str, Mapping[str, str]], Any) -> None
super(ConfidentialClientCredential, self).__init__(**kwargs)

self._client_id = client_id
self._client_credential = client_credential
self._authority = authority

# postpone creating the wrapped application because its initializer uses the network
self._app = None # type: Optional[msal.ConfidentialClientApplication]

def get_token(self, *scopes):
# type: (str) -> AccessToken

if not self._app:
self._app = self._create_msal_application()

# MSAL requires scopes be a list
scopes = list(scopes) # type: ignore
now = int(time.time())

# First try to get a cached access token or if a refresh token is cached, redeem it for an access token.
# Failing that, acquire a new token.
result = self._app.acquire_token_silent(scopes, account=None) or self._app.acquire_token_for_client(scopes)
return AccessToken(result["access_token"], now + int(result["expires_in"]))

def _create_msal_application(self):
# ConfidentialClientApplication's initializer uses msal.authority to send requests to AAD
with mock.patch("msal.authority.requests", self):
app = msal.ConfidentialClientApplication(
client_id=self._client_id, client_credential=self._client_credential, authority=self._authority
)
# replace the client's requests.Session with adapter
app.client.session = self
return app
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Adapter to substitute an azure-core pipeline for Requests in MSAL application token acquisition methods."""

import json

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Dict, Mapping, Optional
from azure.core.pipeline import PipelineResponse

from azure.core.configuration import Configuration
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy
from azure.core.pipeline.transport import HttpRequest, RequestsTransport


class MsalTransportResponse:
"""Wraps an azure-core PipelineResponse with the shape of requests.Response"""

def __init__(self, pipeline_response):
# type: (PipelineResponse) -> None
self._response = pipeline_response.http_response
self.status_code = self._response.status_code
self.text = self._response.text()

def json(self, **kwargs):
# type: (Any) -> Mapping[str, Any]
return json.loads(self.text, **kwargs)

def raise_for_status(self):
# type: () -> None
raise ClientAuthenticationError("authentication failed", self._response)


class MsalTransportAdapter:
"""Wraps an azure-core pipeline with the shape of requests.Session"""

def __init__(self, **kwargs):
# type: (Any) -> None
self._pipeline = self._build_pipeline(**kwargs)

@staticmethod
def create_config(**kwargs):
# type: (Any) -> Configuration
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = RetryPolicy(**kwargs)
return config

def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
config = config or self.create_config(**kwargs)
policies = policies or [ContentDecodePolicy(), config.retry_policy, config.logging_policy]
if not transport:
transport = RequestsTransport(configuration=config)
return Pipeline(transport=transport, policies=policies)

def get(self, url, headers=None, params=None, timeout=None, verify=None, **kwargs):
# type: (str, Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse
request = HttpRequest("GET", url, headers=headers)
if params:
request.format_parameters(params)
response = self._pipeline.run(
request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs
)
return MsalTransportResponse(response)

def post(self, url, data=None, headers=None, params=None, timeout=None, verify=None, **kwargs):
# type: (str, Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse
request = HttpRequest("POST", url, headers=headers)
if params:
request.format_parameters(params)
if data:
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request.set_formdata_body(data)
response = self._pipeline.run(
request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs
)
return MsalTransportResponse(response)
22 changes: 13 additions & 9 deletions sdk/identity/azure-identity/tests/test_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os

try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore

from azure.identity import DefaultAzureCredential, CertificateCredential, ClientSecretCredential
from azure.identity.constants import EnvironmentVariables
import pytest
from azure.identity._internal import ConfidentialClientCredential

ARM_SCOPE = "https://management.azure.com/.default"

Expand Down Expand Up @@ -46,3 +38,15 @@ def test_default_credential(live_identity_settings):
assert token
assert token.token
assert token.expires_on


def test_confidential_client_credential(live_identity_settings):
credential = ConfidentialClientCredential(
client_id=live_identity_settings["client_id"],
client_credential=live_identity_settings["client_secret"],
authority="https://login.microsoftonline.com/" + live_identity_settings["tenant_id"],
)
token = credential.get_token(ARM_SCOPE)
assert token
assert token.token
assert token.expires_on

0 comments on commit 63b4d81

Please sign in to comment.