Skip to content

Commit

Permalink
ManagedIdentityClient handles unexpected content-type (#18137)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Apr 28, 2021
1 parent e85087d commit 8b1c366
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from typing import TYPE_CHECKING

from msal import TokenCache
import six

from azure.core.configuration import Configuration
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.core.exceptions import ClientAuthenticationError, DecodeError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import (
ContentDecodePolicy,
Expand Down Expand Up @@ -58,10 +59,19 @@ def __init__(self, request_factory, client_id=None, **kwargs):
def _process_response(self, response, request_time):
# type: (PipelineResponse, int) -> AccessToken

# ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) # type: dict
if not content:
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
try:
content = ContentDecodePolicy.deserialize_from_text(
response.http_response.text(), mime_type="application/json"
)
if not content:
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
except DecodeError as ex:
if response.http_response.content_type.startswith("application/json"):
message = "Failed to deserialize JSON from response"
else:
message = 'Unexpected content type "{}"'.format(response.http_response.content_type)
six.raise_from(ClientAuthenticationError(message=message, response=response.http_response), ex)

if "access_token" not in content or not ("expires_in" in content or "expires_on" in content):
if content and "access_token" in content:
content["access_token"] = "****"
Expand All @@ -79,7 +89,8 @@ def _process_response(self, response, request_time):

# caching is the final step because TokenCache.add mutates its "event"
self._cache.add(
event={"response": content, "scope": [content["resource"]]}, now=request_time,
event={"response": content, "scope": [content["resource"]]},
now=request_time,
)

return token
Expand Down
61 changes: 61 additions & 0 deletions sdk/identity/azure-identity/tests/test_managed_identity_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import time

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.transport import HttpRequest
from azure.identity._internal.managed_identity_client import ManagedIdentityClient
import pytest

from helpers import mock, mock_response, Request, validating_transport

Expand Down Expand Up @@ -44,3 +47,61 @@ def test_caching():
token = client.get_cached_token(scope)
assert token.expires_on == expected_expires_on
assert token.token == expected_token


def test_deserializes_json_from_text():
"""The client should gracefully handle a response with a JSON body and content-type text/plain"""

scope = "scope"
now = int(time.time())
expected_expires_on = now + 3600
expected_token = "*"

def send(request, **_):
body = json.dumps(
{
"access_token": expected_token,
"expires_in": 3600,
"expires_on": expected_expires_on,
"resource": scope,
"token_type": "Bearer",
}
)
return mock.Mock(
status_code=200,
headers={"Content-Type": "text/plain"},
content_type="text/plain",
text=lambda encoding=None: body,
)

client = ManagedIdentityClient(
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=mock.Mock(send=send)
)

token = client.request_token(scope)
assert token.expires_on == expected_expires_on
assert token.token == expected_token


@pytest.mark.parametrize("content_type", ("text/html","application/json"))
def test_unexpected_content(content_type):
content = "<html><body>not JSON</body></html>"

def send(request, **_):
return mock.Mock(
status_code=200,
headers={"Content-Type": content_type},
content_type=content_type,
text=lambda encoding=None: content,
)

client = ManagedIdentityClient(
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=mock.Mock(send=send)
)

with pytest.raises(ClientAuthenticationError) as ex:
client.request_token("scope")
assert ex.value.response.text() == content

if "json" not in content_type:
assert content_type in ex.value.message
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import time
from unittest.mock import patch
from unittest.mock import Mock, patch

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.transport import HttpRequest
from azure.identity.aio._internal.managed_identity_client import AsyncManagedIdentityClient
import pytest

from helpers import mock_response, Request
from helpers_async import async_validating_transport

pytestmark = pytest.mark.asyncio


@pytest.mark.asyncio
async def test_caching():
scope = "scope"
now = int(time.time())
Expand Down Expand Up @@ -48,3 +51,61 @@ async def test_caching():
token = client.get_cached_token(scope)
assert token.expires_on == expected_expires_on
assert token.token == expected_token


async def test_deserializes_json_from_text():
"""The client should gracefully handle a response with a JSON body and content-type text/plain"""

scope = "scope"
now = int(time.time())
expected_expires_on = now + 3600
expected_token = "*"

async def send(request, **_):
body = json.dumps(
{
"access_token": expected_token,
"expires_in": 3600,
"expires_on": expected_expires_on,
"resource": scope,
"token_type": "Bearer",
}
)
return Mock(
status_code=200,
headers={"Content-Type": "text/plain"},
content_type="text/plain",
text=lambda encoding=None: body,
)

client = AsyncManagedIdentityClient(
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=Mock(send=send)
)

token = await client.request_token(scope)
assert token.expires_on == expected_expires_on
assert token.token == expected_token


@pytest.mark.parametrize("content_type", ("text/html", "application/json"))
async def test_unexpected_content(content_type):
content = "<html><body>not JSON</body></html>"

async def send(request, **_):
return Mock(
status_code=200,
headers={"Content-Type": content_type},
content_type=content_type,
text=lambda encoding=None: content,
)

client = AsyncManagedIdentityClient(
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=Mock(send=send)
)

with pytest.raises(ClientAuthenticationError) as ex:
await client.request_token("scope")
assert ex.value.response.text() == content

if "json" not in content_type:
assert content_type in ex.value.message

0 comments on commit 8b1c366

Please sign in to comment.