Skip to content

Commit

Permalink
auth
Browse files Browse the repository at this point in the history
  • Loading branch information
jiasli committed Mar 3, 2025
1 parent 95ef126 commit a94fc1c
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 97 deletions.
39 changes: 19 additions & 20 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from azure.cli.core._session import ACCOUNT
from azure.cli.core.azclierror import AuthenticationError
from azure.cli.core.cloud import get_active_cloud, set_cloud_subscription
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
from azure.cli.core.util import in_cloud_console, can_launch_browser, is_github_codespaces
from knack.log import get_logger
from knack.util import CLIError
Expand Down Expand Up @@ -313,9 +314,10 @@ def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subsc
import jwt
identity_type = MsiAccountTypes.system_assigned
from .auth.msal_credentials import ManagedIdentityCredential
from .auth.constants import ACCESS_TOKEN

cred = ManagedIdentityCredential()
token = cred.get_token(*self._arm_scope).token
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
tenant = decode['tid']
Expand All @@ -339,9 +341,10 @@ def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subsc
def login_in_cloud_shell(self):
import jwt
from .auth.msal_credentials import CloudShellCredential
from .auth.constants import ACCESS_TOKEN

cred = CloudShellCredential()
token = cred.get_token(*self._arm_scope).token
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
logger.info('Cloud Shell token was retrieved. Now trying to initialize local accounts...')
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
tenant = decode['tid']
Expand Down Expand Up @@ -397,21 +400,19 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
# Cloud Shell
from .auth.msal_credentials import CloudShellCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(CloudShellCredential())
sdk_cred = CredentialAdaptor(CloudShellCredential())

elif managed_identity_type:
# managed identity
if _on_azure_arc():
from .auth.msal_credentials import ManagedIdentityCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(ManagedIdentityCredential())
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
else:
# The resource is merely used by msrestazure to get the first access token.
# It is not actually used in an API invocation.
cred = MsiAccountTypes.msi_auth_factory(
sdk_cred = MsiAccountTypes.msi_auth_factory(
managed_identity_type, managed_identity_id,
self.cli_ctx.cloud.endpoints.active_directory_resource_id)

Expand All @@ -431,10 +432,9 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
external_credentials = []
for external_tenant in external_tenants:
external_credentials.append(self._create_credential(account, tenant_id=external_tenant))
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)
sdk_cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)

return (cred,
return (sdk_cred,
str(account[_SUBSCRIPTION_ID]),
str(account[_TENANT_ID]))

Expand All @@ -460,24 +460,24 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
if tenant:
raise CLIError("Tenant shouldn't be specified for Cloud Shell account")
from .auth.msal_credentials import CloudShellCredential
cred = CloudShellCredential()
sdk_cred = CredentialAdaptor(CloudShellCredential())

elif managed_identity_type:
# managed identity
if tenant:
raise CLIError("Tenant shouldn't be specified for managed identity account")
if _on_azure_arc():
from .auth.msal_credentials import ManagedIdentityCredential
cred = ManagedIdentityCredential()
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
else:
from .auth.util import scopes_to_resource
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))
sdk_cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))

else:
cred = self._create_credential(account, tenant_id=tenant)
sdk_cred = CredentialAdaptor(self._create_credential(account, tenant_id=tenant))

sdk_token = cred.get_token(*scopes)
sdk_token = sdk_cred.get_token(*scopes)
# Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility
# WARNING: expiresOn is deprecated and will be removed in future release.
import datetime
Expand Down Expand Up @@ -856,7 +856,6 @@ def find_using_common_tenant(self, username, credential=None):
specific_tenant_credential = identity.get_user_credential(username)

try:

subscriptions = self.find_using_specific_tenant(tenant_id, specific_tenant_credential,
tenant_id_description=t)
except AuthenticationError as ex:
Expand Down Expand Up @@ -927,9 +926,9 @@ def _create_subscription_client(self, credential):
raise CLIInternalError("Unable to get '{}' in profile '{}'"
.format(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, self.cli_ctx.cloud.profile))
api_version = get_api_version(self.cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS)
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, credential)

client = client_type(credential, api_version=api_version,
sdk_cred = CredentialAdaptor(credential)
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, sdk_cred)
client = client_type(sdk_cred, api_version=api_version,
base_url=self.cli_ctx.cloud.endpoints.resource_manager,
**client_kwargs)
return client
Expand Down
3 changes: 3 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
# --------------------------------------------------------------------------------------------

AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'

ACCESS_TOKEN = 'access_token'
EXPIRES_IN = "expires_in"
15 changes: 8 additions & 7 deletions src/azure-cli-core/azure/cli/core/auth/credential_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
# --------------------------------------------------------------------------------------------

from knack.log import get_logger
from .util import build_sdk_access_token

logger = get_logger(__name__)


class CredentialAdaptor:
def __init__(self, credential, auxiliary_credentials=None):
"""Cross-tenant credential adaptor. It takes a main credential and auxiliary credentials.
"""Credential adaptor between MSAL credential and SDK credential.
It implements Track 2 SDK's azure.core.credentials.TokenCredential by exposing get_token.
:param credential: Main credential from .msal_authentication
:param auxiliary_credentials: Credentials from .msal_authentication for cross tenant authentication.
Details about cross tenant authentication:
:param credential: MSAL credential from ._msal_credentials
:param auxiliary_credentials: MSAL credentials for cross-tenant authentication.
Details about cross-tenant authentication:
https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant
"""

Expand All @@ -32,11 +32,12 @@ def get_token(self, *scopes, **kwargs):
if 'data' in kwargs:
filtered_kwargs['data'] = kwargs['data']

return self._credential.get_token(*scopes, **filtered_kwargs)
return build_sdk_access_token(self._credential.acquire_token(list(scopes), **filtered_kwargs))

def get_auxiliary_tokens(self, *scopes, **kwargs):
"""Get access tokens from auxiliary credentials."""
# To test cross-tenant authentication, see https://github.com/Azure/azure-cli/issues/16691
if self._auxiliary_credentials:
return [cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials]
return [build_sdk_access_token(cred.acquire_token(list(scopes), **kwargs))
for cred in self._auxiliary_credentials]
return None
5 changes: 2 additions & 3 deletions src/azure-cli-core/azure/cli/core/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,8 @@ def login_with_service_principal(self, client_id, credential, scopes):
"""
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)
client_credential = sp_auth.get_msal_client_credential()
cca = ConfidentialClientApplication(client_id, client_credential=client_credential, **self._msal_app_kwargs)
result = cca.acquire_token_for_client(scopes)
check_result(result)
cred = ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)
cred.acquire_token(scopes)

# Only persist the service principal after a successful login
entry = sp_auth.get_entry_to_persist()
Expand Down
57 changes: 24 additions & 33 deletions src/azure-cli-core/azure/cli/core/auth/msal_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,7 @@
# --------------------------------------------------------------------------------------------

"""
Credentials defined in this module are alternative implementations of credentials provided by Azure Identity.
These credentials implement azure.core.credentials.TokenCredential by exposing `get_token` method for Track 2
SDK invocation.
If you want to implement your own credential, the credential must also expose `get_token` method.
`get_token` method takes `scopes` as positional arguments and other optional `kwargs`, such as `claims`, `data`.
The return value should be a named tuple containing two elements: token (str), expires_on (int). You may simply use
azure.cli.core.auth.util.AccessToken to build the return value. See below credentials as examples.
Credentials to acquire tokens from MSAL.
"""

from knack.log import get_logger
Expand All @@ -22,15 +13,15 @@
ManagedIdentityClient, SystemAssignedManagedIdentity)

from .constants import AZURE_CLI_CLIENT_ID
from .util import check_result, build_sdk_access_token
from .util import check_result

logger = get_logger(__name__)


class UserCredential: # pylint: disable=too-few-public-methods

def __init__(self, client_id, username, **kwargs):
"""User credential implementing get_token interface.
"""User credential wrapping msal.application.PublicClientApplication
:param client_id: Client ID of the CLI.
:param username: The username for user credential.
Expand All @@ -52,14 +43,16 @@ def __init__(self, client_id, username, **kwargs):

self._account = accounts[0]

def get_token(self, *scopes, claims=None, **kwargs):
# scopes = ['https://pas.windows.net/CheckMyAccess/Linux/.default']
logger.debug("UserCredential.get_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)
def acquire_token(self, scopes, claims=None, **kwargs):
# scopes must be a list.
# For acquiring SSH certificate, scopes is ['https://pas.windows.net/CheckMyAccess/Linux/.default']
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
logger.debug("UserCredential.acquire_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)

if claims:
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
self._msal_app.authority.tenant, claims)
result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims,
result = self._msal_app.acquire_token_silent_with_error(scopes, self._account, claims_challenge=claims,
**kwargs)

from azure.cli.core.azclierror import AuthenticationError
Expand All @@ -82,7 +75,7 @@ def get_token(self, *scopes, claims=None, **kwargs):
success_template, error_template = read_response_templates()

result = self._msal_app.acquire_token_interactive(
list(scopes), login_hint=self._account['username'],
scopes, login_hint=self._account['username'],
port=8400 if self._msal_app.authority.is_adfs else None,
success_template=success_template, error_template=error_template, **kwargs)
check_result(result)
Expand All @@ -91,25 +84,24 @@ def get_token(self, *scopes, claims=None, **kwargs):
# launch browser, but show the error message and `az login` command instead.
else:
raise
return build_sdk_access_token(result)
return result


class ServicePrincipalCredential: # pylint: disable=too-few-public-methods

def __init__(self, client_id, client_credential, **kwargs):
"""Service principal credential implementing get_token interface.
"""Service principal credential wrapping msal.application.ConfidentialClientApplication.
:param client_id: The service principal's client ID.
:param client_credential: client_credential that will be passed to MSAL.
"""
self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs)

def get_token(self, *scopes, **kwargs):
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
self._msal_app = ConfidentialClientApplication(client_id, client_credential=client_credential, **kwargs)

result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
def acquire_token(self, scopes, **kwargs):
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
result = self._msal_app.acquire_token_for_client(scopes, **kwargs)
check_result(result)
return build_sdk_access_token(result)
return result


class CloudShellCredential: # pylint: disable=too-few-public-methods
Expand All @@ -126,12 +118,11 @@ def __init__(self):
# token_cache=...
)

def get_token(self, *scopes, **kwargs):
logger.debug("CloudShellCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs)
def acquire_token(self, scopes, **kwargs):
logger.debug("CloudShellCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
result = self._msal_app.acquire_token_interactive(scopes, prompt="none", **kwargs)
check_result(result, scopes=scopes)
return build_sdk_access_token(result)
return result


class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
Expand All @@ -143,10 +134,10 @@ def __init__(self):
import requests
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())

def get_token(self, *scopes, **kwargs):
logger.debug("ManagedIdentityCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
def acquire_token(self, scopes, **kwargs):
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)

from .util import scopes_to_resource
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
check_result(result)
return build_sdk_access_token(result)
return result
11 changes: 7 additions & 4 deletions src/azure-cli-core/azure/cli/core/auth/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,6 @@ def check_result(result, **kwargs):


def build_sdk_access_token(token_entry):
import time
request_time = int(time.time())

# MSAL token entry sample:
# {
# 'access_token': 'eyJ0eXAiOiJKV...',
Expand All @@ -153,7 +150,8 @@ def build_sdk_access_token(token_entry):
# Importing azure.core.credentials.AccessToken is expensive.
# This can slow down commands that doesn't need azure.core, like `az account get-access-token`.
# So We define our own AccessToken.
return AccessToken(token_entry["access_token"], request_time + token_entry["expires_in"])
from .constants import ACCESS_TOKEN, EXPIRES_IN
return AccessToken(token_entry[ACCESS_TOKEN], _now_timestamp() + token_entry[EXPIRES_IN])


def decode_access_token(access_token):
Expand All @@ -177,3 +175,8 @@ def read_response_templates():
error_template = f.read()

return success_template, error_template


def _now_timestamp():
import time
return int(time.time())
Loading

0 comments on commit a94fc1c

Please sign in to comment.