Skip to content

Commit

Permalink
[Core] Conditional Access: Show --scope for az login message when…
Browse files Browse the repository at this point in the history
… failed to refresh the access token (#17738)
  • Loading branch information
jiasli authored Aug 17, 2021
1 parent 08db2f8 commit 14cc787
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 39 deletions.
18 changes: 7 additions & 11 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from azure.cli.core._environment import get_config_dir
from azure.cli.core._session import ACCOUNT
from azure.cli.core.util import get_file_json, in_cloud_console, open_page_in_browser, can_launch_browser,\
is_windows, is_wsl, scopes_to_resource
is_windows, is_wsl, scopes_to_resource, resource_to_scopes
from azure.cli.core.cloud import get_active_cloud, set_cloud_subscription

logger = get_logger(__name__)
Expand Down Expand Up @@ -574,11 +574,7 @@ def get_login_credentials(self, resource=None, subscription_id=None, aux_subscri
"Please run `az login` with a user account or a service principal.")

if identity_type is None:
def _retrieve_token(sdk_resource=None):
# When called by
# - Track 1 SDK, use `resource` specified by CLI
# - Track 2 SDK, use `sdk_resource` specified by SDK and ignore `resource` specified by CLI
token_resource = sdk_resource or resource
def _retrieve_token(token_resource):
logger.debug("Retrieving token from ADAL for resource %r", token_resource)

if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
Expand All @@ -591,8 +587,7 @@ def _retrieve_token(sdk_resource=None):
account[_TENANT_ID],
use_cert_sn_issuer)

def _retrieve_tokens_from_external_tenants(sdk_resource=None):
token_resource = sdk_resource or resource
def _retrieve_tokens_from_external_tenants(token_resource):
logger.debug("Retrieving token from ADAL for external tenants and resource %r", token_resource)

external_tokens = []
Expand All @@ -607,7 +602,8 @@ def _retrieve_tokens_from_external_tenants(sdk_resource=None):

from azure.cli.core.adal_authentication import AdalAuthentication
auth_object = AdalAuthentication(_retrieve_token,
_retrieve_tokens_from_external_tenants if external_tenants_info else None)
_retrieve_tokens_from_external_tenants if external_tenants_info else None,
resource=resource)
else:
if self._msi_creds is None:
self._msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource)
Expand Down Expand Up @@ -675,7 +671,7 @@ def get_msal_token(self, scopes, data):
raise CLIError("Unknown identity type {}".format(identity_type))

if 'error' in result:
from azure.cli.core.adal_authentication import aad_error_handler
from azure.cli.core.auth.util import aad_error_handler
aad_error_handler(result)

return username_or_sp_id, result["access_token"]
Expand Down Expand Up @@ -721,7 +717,7 @@ def get_raw_token(self, resource=None, subscription=None, tenant=None):
use_cert_sn_issuer)
except adal.AdalError as ex:
from azure.cli.core.adal_authentication import adal_error_handler
adal_error_handler(ex)
adal_error_handler(ex, scopes=resource_to_scopes(resource))
return (creds,
None if tenant else str(account[_SUBSCRIPTION_ID]),
str(tenant if tenant else account[_TENANT_ID]))
Expand Down
36 changes: 15 additions & 21 deletions src/azure-cli-core/azure/cli/core/adal_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from msrest.authentication import Authentication
from msrestazure.azure_active_directory import MSIAuthentication
from azure.core.credentials import AccessToken
from azure.cli.core.util import in_cloud_console, scopes_to_resource
from azure.cli.core.util import in_cloud_console, scopes_to_resource, resource_to_scopes

from knack.util import CLIError
from knack.log import get_logger
Expand All @@ -19,7 +19,7 @@

class AdalAuthentication(Authentication): # pylint: disable=too-few-public-methods

def __init__(self, token_retriever, external_tenant_token_retriever=None):
def __init__(self, token_retriever, external_tenant_token_retriever=None, resource=None):
# DO NOT call _token_retriever from outside azure-cli-core. It is only available for user or
# Service Principal credential (AdalAuthentication), but not for Managed Identity credential
# (MSIAuthenticationWrapper).
Expand All @@ -28,24 +28,31 @@ def __init__(self, token_retriever, external_tenant_token_retriever=None):
# - AdalAuthentication.get_token, which is designed for Track 2 SDKs
self._token_retriever = token_retriever
self._external_tenant_token_retriever = external_tenant_token_retriever
self._resource = resource

def _get_token(self, sdk_resource=None):
"""
:param sdk_resource: `resource` converted from Track 2 SDK's `scopes`
"""

# When called by
# - Track 1 SDK, use `resource` specified by CLI
# - Track 2 SDK, use `sdk_resource` specified by SDK and ignore `resource` specified by CLI
token_resource = sdk_resource or self._resource

external_tenant_tokens = None
try:
scheme, token, token_entry = self._token_retriever(sdk_resource)
scheme, token, token_entry = self._token_retriever(token_resource)
if self._external_tenant_token_retriever:
external_tenant_tokens = self._external_tenant_token_retriever(sdk_resource)
external_tenant_tokens = self._external_tenant_token_retriever(token_resource)
except CLIError as err:
if in_cloud_console():
AdalAuthentication._log_hostname()
raise err
except adal.AdalError as err:
if in_cloud_console():
AdalAuthentication._log_hostname()
adal_error_handler(err)
adal_error_handler(err, scopes=resource_to_scopes(token_resource))
except requests.exceptions.SSLError as err:
from .util import SSLERROR_TEMPLATE
raise CLIError(SSLERROR_TEMPLATE.format(str(err)))
Expand Down Expand Up @@ -236,24 +243,11 @@ def _timestamp(dt):
return dt.timestamp()


def aad_error_handler(error: dict):
""" Handle the error from AAD server returned by ADAL or MSAL. """
login_message = ("To re-authenticate, please {}. If the problem persists, "
"please contact your tenant administrator."
.format("refresh Azure Portal" if in_cloud_console() else "run `az login`"))

# https://docs.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes
# Search for an error code at https://login.microsoftonline.com/error
msg = error.get('error_description')

from azure.cli.core.azclierror import AuthenticationError
raise AuthenticationError(msg, login_message)


def adal_error_handler(err: adal.AdalError):
def adal_error_handler(err: adal.AdalError, **kwargs):
""" Handle AdalError. """
try:
aad_error_handler(err.error_response)
from azure.cli.core.auth.util import aad_error_handler
aad_error_handler(err.error_response, **kwargs)
except AttributeError:
# In case of AdalError created as
# AdalError('More than one token matches the criteria. The result is ambiguous.')
Expand Down
45 changes: 45 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------


def aad_error_handler(error, **kwargs):
""" Handle the error from AAD server returned by ADAL or MSAL. """

# https://docs.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes
# Search for an error code at https://login.microsoftonline.com/error
msg = error.get('error_description')
login_message = _generate_login_message(**kwargs)

from azure.cli.core.azclierror import AuthenticationError
raise AuthenticationError(msg, recommendation=login_message)


def _generate_login_command(scopes=None):
login_command = ['az login']

if scopes:
login_command.append('--scope {}'.format(' '.join(scopes)))

return ' '.join(login_command)


def _generate_login_message(**kwargs):
from azure.cli.core.util import in_cloud_console
login_command = _generate_login_command(**kwargs)

msg = "To re-authenticate, please {}" .format(
"refresh Azure Portal." if in_cloud_console() else "run:\n{}".format(login_command))

return msg


def decode_access_token(access_token):
# Decode the access token. We can do the same with https://jwt.ms
from msal.oauth2cli.oidc import decode_part
import json

# Access token consists of headers.claims.signature. Decode the claim part
decoded_str = decode_part(access_token.split('.')[1])
return json.loads(decoded_str)
15 changes: 8 additions & 7 deletions src/azure-cli-core/azure/cli/core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def setUpClass(cls):
'e-lOym1sH5iOcxfIjXF0Tp2y0f3zM7qCq8Cp1ZxEwz6xYIgByoxjErNXrOME5Ld1WizcsaWxTXpwxJn_'
'Q8U2g9kXHrbYFeY2gJxF_hnfLvNKxUKUBnftmyYxZwKi0GDS0BvdJnJnsqSRSpxUx__Ra9QJkG1IaDzj'
'ZcSZPHK45T6ohK9Hk9ktZo0crVl7Tmw')
cls.arm_resource = 'https://management.core.windows.net/'

def test_normalize(self):
cli = DummyCli()
Expand Down Expand Up @@ -551,7 +552,7 @@ def test_get_login_credentials(self, mock_get_token, mock_read_cred_file):
self.assertEqual(subscription_id, test_subscription_id)

# verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
token_type, token = cred._token_retriever(self.arm_resource)
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)
mock_get_token.assert_called_once_with(mock.ANY, self.user1, test_tenant_id,
Expand Down Expand Up @@ -595,11 +596,11 @@ def test_get_login_credentials_aux_subscriptions(self, mock_get_token, mock_read
self.assertEqual(subscription_id, test_subscription_id)

# verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
token_type, token = cred._token_retriever(self.arm_resource)
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)

token2 = cred._external_tenant_token_retriever()
token2 = cred._external_tenant_token_retriever(self.arm_resource)
self.assertEqual(len(token2), 1)
self.assertEqual(token2[0][1], raw_token2)

Expand Down Expand Up @@ -642,11 +643,11 @@ def test_get_login_credentials_aux_tenants(self, mock_get_token, mock_read_cred_
self.assertEqual(subscription_id, test_subscription_id)

# verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
token_type, token = cred._token_retriever(self.arm_resource)
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)

token2 = cred._external_tenant_token_retriever()
token2 = cred._external_tenant_token_retriever(self.arm_resource)
self.assertEqual(len(token2), 1)
self.assertEqual(token2[0][1], raw_token2)

Expand Down Expand Up @@ -949,7 +950,7 @@ def test_get_login_credentials_for_graph_client(self, mock_get_token, mock_read_
# action
cred, _, tenant_id = profile.get_login_credentials(
resource=cli.cloud.endpoints.active_directory_graph_resource_id)
_, _ = cred._token_retriever()
_, _ = cred._token_retriever('https://graph.windows.net/')
# verify
mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id,
'https://graph.windows.net/')
Expand All @@ -971,7 +972,7 @@ def test_get_login_credentials_for_data_lake_client(self, mock_get_token, mock_r
# action
cred, _, tenant_id = profile.get_login_credentials(
resource=cli.cloud.endpoints.active_directory_data_lake_resource_id)
_, _ = cred._token_retriever()
_, _ = cred._token_retriever('https://datalake.azure.net/')
# verify
mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id,
'https://datalake.azure.net/')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from knack.util import CLIError


@unittest.skip("Out of maintenance")
class TestProfile(unittest.TestCase):

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions src/azure-cli-testsdk/azure/cli/testsdk/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def __init__(self, method_name):
self.kwargs = {}
self.test_resources_count = 0

def setUp(self):
patch_main_exception_handler(self)

def cmd(self, command, checks=None, expect_failure=False):
command = self._apply_kwargs(command)
return execute(self.cli_ctx, command, expect_failure=expect_failure).assert_with_checks(checks)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from azure.cli.core.azclierror import AuthenticationError
from azure.cli.testsdk import LiveScenarioTest
from azure.cli.core.auth.util import decode_access_token

ARM_URL = "https://eastus2euap.management.azure.com/" # ARM canary
ARM_MAX_RETRY = 30
ARM_RETRY_INTERVAL = 10


class ConditionalAccessScenarioTest(LiveScenarioTest):

def setUp(self):
super().setUp()
# Clear MSAL cache to avoid unexpected tokens from cache
self.cmd('az account clear')

def test_conditional_access_mfa(self):
"""
This test should be run using a user account that
- doesn't require MFA for ARM
- requires MFA for data-plane resource
The result ATs are checked per https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens
Following claims are checked:
- aud (Audience): https://tools.ietf.org/html/rfc7519#section-4.1.3
- amr (Authentication Method Reference): https://tools.ietf.org/html/rfc8176
"""

resource = 'https://pas.windows.net/CheckMyAccess/Linux'
scope = resource + '/.default'

self.kwargs['scope'] = scope
self.kwargs['resource'] = resource

# region non-MFA session

# Login to ARM (MFA not required)
# In the browser, if the user already exists, make sure to logout first and re-login to clear browser cache
self.cmd('az login')

# Getting ARM AT and check claims
result = self.cmd('az account get-access-token').get_output_in_json()
decoded = decode_access_token(result['accessToken'])
assert decoded['aud'] == self.cli_ctx.cloud.endpoints.active_directory_resource_id
assert decoded['amr'] == ['pwd']

# Getting data-plane AT with ARM RT (step-up) fails
with self.assertRaises(AuthenticationError) as cm:
self.cmd('az account get-access-token --resource {resource}')

# Check re-login recommendation
re_login_command = 'az login --scope {scope}'.format(**self.kwargs)
assert 'AADSTS50076' in cm.exception.error_msg
assert re_login_command in cm.exception.recommendations[0]

# endregion

# region MFA session

# Re-login with data-plane scope (MFA required)
# Getting ARM AT with data-plane RT (step-down) succeeds
self.cmd(re_login_command)

# Getting ARM AT and check claims
result = self.cmd('az account get-access-token').get_output_in_json()
decoded = decode_access_token(result['accessToken'])
assert decoded['aud'] == self.cli_ctx.cloud.endpoints.active_directory_resource_id
assert decoded['amr'] == ['pwd']

# Getting data-plane AT and check claims
result = self.cmd('az account get-access-token --resource {resource}').get_output_in_json()
decoded = decode_access_token(result['accessToken'])
assert decoded['aud'] in scope
assert decoded['amr'] == ['pwd', 'mfa']

# endregion

0 comments on commit 14cc787

Please sign in to comment.